This commit is contained in:
Manfred Touron
2017-05-18 18:54:23 +02:00
parent dc386661ca
commit 5448f25fd6
645 changed files with 55908 additions and 33297 deletions

View File

@@ -1,50 +0,0 @@
package test
import (
"context"
"google.golang.org/grpc"
"github.com/go-kit/kit/endpoint"
grpctransport "github.com/go-kit/kit/transport/grpc"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)
type clientBinding struct {
test endpoint.Endpoint
}
func (c *clientBinding) Test(ctx context.Context, a string, b int64) (context.Context, string, error) {
response, err := c.test(ctx, TestRequest{A: a, B: b})
if err != nil {
return nil, "", err
}
r := response.(*TestResponse)
return r.Ctx, r.V, nil
}
func NewClient(cc *grpc.ClientConn) Service {
return &clientBinding{
test: grpctransport.NewClient(
cc,
"pb.Test",
"Test",
encodeRequest,
decodeResponse,
&pb.TestResponse{},
grpctransport.ClientBefore(
injectCorrelationID,
),
grpctransport.ClientBefore(
displayClientRequestHeaders,
),
grpctransport.ClientAfter(
displayClientResponseHeaders,
displayClientResponseTrailers,
),
grpctransport.ClientAfter(
extractConsumedCorrelationID,
),
).Endpoint(),
}
}

View File

@@ -1,141 +0,0 @@
package test
import (
"context"
"fmt"
"google.golang.org/grpc/metadata"
)
type metaContext string
const (
correlationID metaContext = "correlation-id"
responseHDR metaContext = "my-response-header"
responseTRLR metaContext = "my-response-trailer"
correlationIDTRLR metaContext = "correlation-id-consumed"
)
/* client before functions */
func injectCorrelationID(ctx context.Context, md *metadata.MD) context.Context {
if hdr, ok := ctx.Value(correlationID).(string); ok {
fmt.Printf("\tClient found correlationID %q in context, set metadata header\n", hdr)
(*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr)
}
return ctx
}
func displayClientRequestHeaders(ctx context.Context, md *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tClient >> Request Headers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
/* server before functions */
func extractCorrelationID(ctx context.Context, md metadata.MD) context.Context {
if hdr, ok := md[string(correlationID)]; ok {
cID := hdr[len(hdr)-1]
ctx = context.WithValue(ctx, correlationID, cID)
fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID)
}
return ctx
}
func displayServerRequestHeaders(ctx context.Context, md metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tServer << Request Headers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
/* server after functions */
func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
*md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value"))
return ctx
}
func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tServer >> Response Headers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
*md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too"))
return ctx
}
func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
if hdr, ok := ctx.Value(correlationID).(string); ok {
fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr)
*md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr))
}
return ctx
}
func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tServer >> Response Trailers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
/* client after functions */
func displayClientResponseHeaders(ctx context.Context, md metadata.MD, _ metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tClient << Response Headers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
func displayClientResponseTrailers(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tClient << Response Trailers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
func extractConsumedCorrelationID(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context {
if hdr, ok := md[string(correlationIDTRLR)]; ok {
fmt.Printf("\tClient received consumed correlationID %q in metadata trailer, set context\n", hdr[len(hdr)-1])
ctx = context.WithValue(ctx, correlationIDTRLR, hdr[len(hdr)-1])
}
return ctx
}
/* CorrelationID context handlers */
func SetCorrelationID(ctx context.Context, v string) context.Context {
return context.WithValue(ctx, correlationID, v)
}
func GetConsumedCorrelationID(ctx context.Context) string {
if trlr, ok := ctx.Value(correlationIDTRLR).(string); ok {
return trlr
}
return ""
}

View File

@@ -1,3 +0,0 @@
package pb
//go:generate protoc test.proto --go_out=plugins=grpc:.

View File

@@ -1,167 +0,0 @@
// Code generated by protoc-gen-go.
// source: test.proto
// DO NOT EDIT!
/*
Package pb is a generated protocol buffer package.
It is generated from these files:
test.proto
It has these top-level messages:
TestRequest
TestResponse
*/
package pb
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type TestRequest struct {
A string `protobuf:"bytes,1,opt,name=a" json:"a,omitempty"`
B int64 `protobuf:"varint,2,opt,name=b" json:"b,omitempty"`
}
func (m *TestRequest) Reset() { *m = TestRequest{} }
func (m *TestRequest) String() string { return proto.CompactTextString(m) }
func (*TestRequest) ProtoMessage() {}
func (*TestRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *TestRequest) GetA() string {
if m != nil {
return m.A
}
return ""
}
func (m *TestRequest) GetB() int64 {
if m != nil {
return m.B
}
return 0
}
type TestResponse struct {
V string `protobuf:"bytes,1,opt,name=v" json:"v,omitempty"`
}
func (m *TestResponse) Reset() { *m = TestResponse{} }
func (m *TestResponse) String() string { return proto.CompactTextString(m) }
func (*TestResponse) ProtoMessage() {}
func (*TestResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *TestResponse) GetV() string {
if m != nil {
return m.V
}
return ""
}
func init() {
proto.RegisterType((*TestRequest)(nil), "pb.TestRequest")
proto.RegisterType((*TestResponse)(nil), "pb.TestResponse")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for Test service
type TestClient interface {
Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error)
}
type testClient struct {
cc *grpc.ClientConn
}
func NewTestClient(cc *grpc.ClientConn) TestClient {
return &testClient{cc}
}
func (c *testClient) Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) {
out := new(TestResponse)
err := grpc.Invoke(ctx, "/pb.Test/Test", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Test service
type TestServer interface {
Test(context.Context, *TestRequest) (*TestResponse, error)
}
func RegisterTestServer(s *grpc.Server, srv TestServer) {
s.RegisterService(&_Test_serviceDesc, srv)
}
func _Test_Test_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(TestRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(TestServer).Test(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/pb.Test/Test",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(TestServer).Test(ctx, req.(*TestRequest))
}
return interceptor(ctx, in, info, handler)
}
var _Test_serviceDesc = grpc.ServiceDesc{
ServiceName: "pb.Test",
HandlerType: (*TestServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Test",
Handler: _Test_Test_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "test.proto",
}
func init() { proto.RegisterFile("test.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 129 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e,
0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xd2, 0xe4, 0xe2, 0x0e, 0x49,
0x2d, 0x2e, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0xe2, 0xe1, 0x62, 0x4c, 0x94, 0x60,
0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x62, 0x4c, 0x04, 0xf1, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x98,
0x83, 0x18, 0x93, 0x94, 0x64, 0xb8, 0x78, 0x20, 0x4a, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x41,
0xb2, 0x65, 0x30, 0xb5, 0x65, 0x46, 0xc6, 0x5c, 0x2c, 0x20, 0x59, 0x21, 0x6d, 0x28, 0xcd, 0xaf,
0x57, 0x90, 0xa4, 0x87, 0x64, 0xb4, 0x94, 0x00, 0x42, 0x00, 0x62, 0x80, 0x12, 0x43, 0x12, 0x1b,
0xd8, 0x21, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x49, 0xfc, 0xd8, 0xf1, 0x96, 0x00, 0x00,
0x00,
}

View File

@@ -1,16 +0,0 @@
syntax = "proto3";
package pb;
service Test {
rpc Test (TestRequest) returns (TestResponse) {}
}
message TestRequest {
string a = 1;
int64 b = 2;
}
message TestResponse {
string v = 1;
}

View File

@@ -1,27 +0,0 @@
package test
import (
"context"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)
func encodeRequest(ctx context.Context, req interface{}) (interface{}, error) {
r := req.(TestRequest)
return &pb.TestRequest{A: r.A, B: r.B}, nil
}
func decodeRequest(ctx context.Context, req interface{}) (interface{}, error) {
r := req.(*pb.TestRequest)
return TestRequest{A: r.A, B: r.B}, nil
}
func encodeResponse(ctx context.Context, resp interface{}) (interface{}, error) {
r := resp.(*TestResponse)
return &pb.TestResponse{V: r.V}, nil
}
func decodeResponse(ctx context.Context, resp interface{}) (interface{}, error) {
r := resp.(*pb.TestResponse)
return &TestResponse{V: r.V, Ctx: ctx}, nil
}

View File

@@ -1,70 +0,0 @@
package test
import (
"context"
"fmt"
oldcontext "golang.org/x/net/context"
"github.com/go-kit/kit/endpoint"
grpctransport "github.com/go-kit/kit/transport/grpc"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)
type service struct{}
func (service) Test(ctx context.Context, a string, b int64) (context.Context, string, error) {
return nil, fmt.Sprintf("%s = %d", a, b), nil
}
func NewService() Service {
return service{}
}
func makeTestEndpoint(svc Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(TestRequest)
newCtx, v, err := svc.Test(ctx, req.A, req.B)
return &TestResponse{
V: v,
Ctx: newCtx,
}, err
}
}
type serverBinding struct {
test grpctransport.Handler
}
func (b *serverBinding) Test(ctx oldcontext.Context, req *pb.TestRequest) (*pb.TestResponse, error) {
_, response, err := b.test.ServeGRPC(ctx, req)
if err != nil {
return nil, err
}
return response.(*pb.TestResponse), nil
}
func NewBinding(svc Service) *serverBinding {
return &serverBinding{
test: grpctransport.NewServer(
makeTestEndpoint(svc),
decodeRequest,
encodeResponse,
grpctransport.ServerBefore(
extractCorrelationID,
),
grpctransport.ServerBefore(
displayServerRequestHeaders,
),
grpctransport.ServerAfter(
injectResponseHeader,
injectResponseTrailer,
injectConsumedCorrelationID,
),
grpctransport.ServerAfter(
displayServerResponseHeaders,
displayServerResponseTrailers,
),
),
}
}

View File

@@ -1,17 +0,0 @@
package test
import "context"
type Service interface {
Test(ctx context.Context, a string, b int64) (context.Context, string, error)
}
type TestRequest struct {
A string
B int64
}
type TestResponse struct {
Ctx context.Context
V string
}

View File

@@ -1,11 +1,11 @@
package grpc
import (
"context"
"fmt"
"reflect"
"strings"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
@@ -21,8 +21,7 @@ type Client struct {
enc EncodeRequestFunc
dec DecodeResponseFunc
grpcReply reflect.Type
before []ClientRequestFunc
after []ClientResponseFunc
before []RequestFunc
}
// NewClient constructs a usable Client for a single remote endpoint.
@@ -54,8 +53,7 @@ func NewClient(
reflect.ValueOf(grpcReply),
).Interface(),
),
before: []ClientRequestFunc{},
after: []ClientResponseFunc{},
before: []RequestFunc{},
}
for _, option := range options {
option(c)
@@ -68,15 +66,8 @@ type ClientOption func(*Client)
// ClientBefore sets the RequestFuncs that are applied to the outgoing gRPC
// request before it's invoked.
func ClientBefore(before ...ClientRequestFunc) ClientOption {
return func(c *Client) { c.before = append(c.before, before...) }
}
// ClientAfter sets the ClientResponseFuncs that are applied to the incoming
// gRPC response prior to it being decoded. This is useful for obtaining
// response metadata and adding onto the context prior to decoding.
func ClientAfter(after ...ClientResponseFunc) ClientOption {
return func(c *Client) { c.after = append(c.after, after...) }
func ClientBefore(before ...RequestFunc) ClientOption {
return func(c *Client) { c.before = before }
}
// Endpoint returns a usable endpoint that will invoke the gRPC specified by the
@@ -88,7 +79,7 @@ func (c Client) Endpoint() endpoint.Endpoint {
req, err := c.enc(ctx, request)
if err != nil {
return nil, err
return nil, fmt.Errorf("Encode: %v", err)
}
md := &metadata.MD{}
@@ -97,22 +88,14 @@ func (c Client) Endpoint() endpoint.Endpoint {
}
ctx = metadata.NewContext(ctx, *md)
var header, trailer metadata.MD
grpcReply := reflect.New(c.grpcReply).Interface()
if err = grpc.Invoke(
ctx, c.method, req, grpcReply, c.client,
grpc.Header(&header), grpc.Trailer(&trailer),
); err != nil {
return nil, err
}
for _, f := range c.after {
ctx = f(ctx, header, trailer)
if err = grpc.Invoke(ctx, c.method, req, grpcReply, c.client); err != nil {
return nil, fmt.Errorf("Invoke: %v", err)
}
response, err := c.dec(ctx, grpcReply)
if err != nil {
return nil, err
return nil, fmt.Errorf("Decode: %v", err)
}
return response, nil
}

View File

@@ -1,59 +0,0 @@
package grpc_test
import (
"context"
"fmt"
"net"
"testing"
"google.golang.org/grpc"
test "github.com/go-kit/kit/transport/grpc/_grpc_test"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)
const (
hostPort string = "localhost:8002"
)
func TestGRPCClient(t *testing.T) {
var (
server = grpc.NewServer()
service = test.NewService()
)
sc, err := net.Listen("tcp", hostPort)
if err != nil {
t.Fatalf("unable to listen: %+v", err)
}
defer server.GracefulStop()
go func() {
pb.RegisterTestServer(server, test.NewBinding(service))
_ = server.Serve(sc)
}()
cc, err := grpc.Dial(hostPort, grpc.WithInsecure())
if err != nil {
t.Fatalf("unable to Dial: %+v", err)
}
client := test.NewClient(cc)
var (
a = "the answer to life the universe and everything"
b = int64(42)
cID = "request-1"
ctx = test.SetCorrelationID(context.Background(), cID)
)
responseCTX, v, err := client.Test(ctx, a, b)
if want, have := fmt.Sprintf("%s = %d", a, b), v; want != have {
t.Fatalf("want %q, have %q", want, have)
}
if want, have := cID, test.GetConsumedCorrelationID(responseCTX); want != have {
t.Fatalf("want %q, have %q", want, have)
}
}

View File

@@ -1,25 +1,23 @@
package grpc
import (
"context"
)
import "golang.org/x/net/context"
// DecodeRequestFunc extracts a user-domain request object from a gRPC request.
// It's designed to be used in gRPC servers, for server-side endpoints. One
// straightforward DecodeRequestFunc could be something that decodes from the
// gRPC request message to the concrete request type.
// straightforward DecodeRequestFunc could be something that
// decodes from the gRPC request message to the concrete request type.
type DecodeRequestFunc func(context.Context, interface{}) (request interface{}, err error)
// EncodeRequestFunc encodes the passed request object into the gRPC request
// object. It's designed to be used in gRPC clients, for client-side endpoints.
// One straightforward EncodeRequestFunc could something that encodes the object
// directly to the gRPC request message.
// object. It's designed to be used in gRPC clients, for client-side
// endpoints. One straightforward EncodeRequestFunc could something that
// encodes the object directly to the gRPC request message.
type EncodeRequestFunc func(context.Context, interface{}) (request interface{}, err error)
// EncodeResponseFunc encodes the passed response object to the gRPC response
// message. It's designed to be used in gRPC servers, for server-side endpoints.
// One straightforward EncodeResponseFunc could be something that encodes the
// object directly to the gRPC response message.
// message. It's designed to be used in gRPC servers, for server-side
// endpoints. One straightforward EncodeResponseFunc could be something that
// encodes the object directly to the gRPC response message.
type EncodeResponseFunc func(context.Context, interface{}) (response interface{}, err error)
// DecodeResponseFunc extracts a user-domain response object from a gRPC

View File

@@ -1,10 +1,10 @@
package grpc
import (
"context"
"encoding/base64"
"strings"
"golang.org/x/net/context"
"google.golang.org/grpc/metadata"
)
@@ -12,53 +12,30 @@ const (
binHdrSuffix = "-bin"
)
// ClientRequestFunc may take information from context and use it to construct
// metadata headers to be transported to the server. ClientRequestFuncs are
// executed after creating the request but prior to sending the gRPC request to
// the server.
type ClientRequestFunc func(context.Context, *metadata.MD) context.Context
// RequestFunc may take information from an gRPC request and put it into a
// request context. In Servers, BeforeFuncs are executed prior to invoking the
// endpoint. In Clients, BeforeFuncs are executed after creating the request
// but prior to invoking the gRPC client.
type RequestFunc func(context.Context, *metadata.MD) context.Context
// ServerRequestFunc may take information from the received metadata header and
// use it to place items in the request scoped context. ServerRequestFuncs are
// executed prior to invoking the endpoint.
type ServerRequestFunc func(context.Context, metadata.MD) context.Context
// ServerResponseFunc may take information from a request context and use it to
// manipulate the gRPC response metadata headers and trailers. ResponseFuncs are
// only executed in servers, after invoking the endpoint but prior to writing a
// response.
type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context
// ClientResponseFunc may take information from a gRPC metadata header and/or
// trailer and make the responses available for consumption. ClientResponseFuncs
// are only executed in clients, after a request has been made, but prior to it
// being decoded.
type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context
// SetRequestHeader returns a ClientRequestFunc that sets the specified metadata
// key-value pair.
func SetRequestHeader(key, val string) ClientRequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
key, val := EncodeKeyValue(key, val)
(*md)[key] = append((*md)[key], val)
return ctx
}
}
// ResponseFunc may take information from a request context and use it to
// manipulate the gRPC metadata header. ResponseFuncs are only executed in
// servers, after invoking the endpoint but prior to writing a response.
type ResponseFunc func(context.Context, *metadata.MD)
// SetResponseHeader returns a ResponseFunc that sets the specified metadata
// key-value pair.
func SetResponseHeader(key, val string) ServerResponseFunc {
return func(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
func SetResponseHeader(key, val string) ResponseFunc {
return func(_ context.Context, md *metadata.MD) {
key, val := EncodeKeyValue(key, val)
(*md)[key] = append((*md)[key], val)
return ctx
}
}
// SetResponseTrailer returns a ResponseFunc that sets the specified metadata
// SetRequestHeader returns a RequestFunc that sets the specified metadata
// key-value pair.
func SetResponseTrailer(key, val string) ServerResponseFunc {
return func(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
func SetRequestHeader(key, val string) RequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
key, val := EncodeKeyValue(key, val)
(*md)[key] = append((*md)[key], val)
return ctx

View File

@@ -1,28 +1,28 @@
package grpc
import (
oldcontext "golang.org/x/net/context"
"google.golang.org/grpc"
"golang.org/x/net/context"
"google.golang.org/grpc/metadata"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
)
// Handler which should be called from the gRPC binding of the service
// Handler which should be called from the grpc binding of the service
// implementation. The incoming request parameter, and returned response
// parameter, are both gRPC types, not user-domain.
type Handler interface {
ServeGRPC(ctx oldcontext.Context, request interface{}) (oldcontext.Context, interface{}, error)
ServeGRPC(ctx context.Context, request interface{}) (context.Context, interface{}, error)
}
// Server wraps an endpoint and implements grpc.Handler.
type Server struct {
ctx context.Context
e endpoint.Endpoint
dec DecodeRequestFunc
enc EncodeResponseFunc
before []ServerRequestFunc
after []ServerResponseFunc
before []RequestFunc
after []ResponseFunc
logger log.Logger
}
@@ -32,12 +32,14 @@ type Server struct {
// definitions to individual handlers. Request and response objects are from the
// caller business domain, not gRPC request and reply types.
func NewServer(
ctx context.Context,
e endpoint.Endpoint,
dec DecodeRequestFunc,
enc EncodeResponseFunc,
options ...ServerOption,
) *Server {
s := &Server{
ctx: ctx,
e: e,
dec: dec,
enc: enc,
@@ -54,14 +56,14 @@ type ServerOption func(*Server)
// ServerBefore functions are executed on the HTTP request object before the
// request is decoded.
func ServerBefore(before ...ServerRequestFunc) ServerOption {
return func(s *Server) { s.before = append(s.before, before...) }
func ServerBefore(before ...RequestFunc) ServerOption {
return func(s *Server) { s.before = before }
}
// ServerAfter functions are executed on the HTTP response writer after the
// endpoint is invoked, but before anything is written to the client.
func ServerAfter(after ...ServerResponseFunc) ServerOption {
return func(s *Server) { s.after = append(s.after, after...) }
func ServerAfter(after ...ResponseFunc) ServerOption {
return func(s *Server) { s.after = after }
}
// ServerErrorLogger is used to log non-terminal errors. By default, no errors
@@ -71,53 +73,56 @@ func ServerErrorLogger(logger log.Logger) ServerOption {
}
// ServeGRPC implements the Handler interface.
func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) {
func (s Server) ServeGRPC(grpcCtx context.Context, req interface{}) (context.Context, interface{}, error) {
ctx := s.ctx
// Retrieve gRPC metadata.
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromContext(grpcCtx)
if !ok {
md = metadata.MD{}
}
for _, f := range s.before {
ctx = f(ctx, md)
ctx = f(ctx, &md)
}
request, err := s.dec(ctx, req)
// Store potentially updated metadata in the gRPC context.
grpcCtx = metadata.NewContext(grpcCtx, md)
request, err := s.dec(grpcCtx, req)
if err != nil {
s.logger.Log("err", err)
return ctx, nil, err
return grpcCtx, nil, BadRequestError{err}
}
response, err := s.e(ctx, request)
if err != nil {
s.logger.Log("err", err)
return ctx, nil, err
return grpcCtx, nil, err
}
var mdHeader, mdTrailer metadata.MD
for _, f := range s.after {
ctx = f(ctx, &mdHeader, &mdTrailer)
f(ctx, &md)
}
grpcResp, err := s.enc(ctx, response)
// Store potentially updated metadata in the gRPC context.
grpcCtx = metadata.NewContext(grpcCtx, md)
grpcResp, err := s.enc(grpcCtx, response)
if err != nil {
s.logger.Log("err", err)
return ctx, nil, err
return grpcCtx, nil, err
}
if len(mdHeader) > 0 {
if err = grpc.SendHeader(ctx, mdHeader); err != nil {
s.logger.Log("err", err)
return ctx, nil, err
}
}
if len(mdTrailer) > 0 {
if err = grpc.SetTrailer(ctx, mdTrailer); err != nil {
s.logger.Log("err", err)
return ctx, nil, err
}
}
return ctx, grpcResp, nil
return grpcCtx, grpcResp, nil
}
// BadRequestError is an error in decoding the request.
type BadRequestError struct {
Err error
}
// Error implements the error interface.
func (err BadRequestError) Error() string {
return err.Err.Error()
}