This commit is contained in:
Manfred Touron
2017-05-18 18:54:23 +02:00
parent dc386661ca
commit 5448f25fd6
645 changed files with 55908 additions and 33297 deletions

View File

@@ -1,50 +0,0 @@
package test
import (
"context"
"google.golang.org/grpc"
"github.com/go-kit/kit/endpoint"
grpctransport "github.com/go-kit/kit/transport/grpc"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)
type clientBinding struct {
test endpoint.Endpoint
}
func (c *clientBinding) Test(ctx context.Context, a string, b int64) (context.Context, string, error) {
response, err := c.test(ctx, TestRequest{A: a, B: b})
if err != nil {
return nil, "", err
}
r := response.(*TestResponse)
return r.Ctx, r.V, nil
}
func NewClient(cc *grpc.ClientConn) Service {
return &clientBinding{
test: grpctransport.NewClient(
cc,
"pb.Test",
"Test",
encodeRequest,
decodeResponse,
&pb.TestResponse{},
grpctransport.ClientBefore(
injectCorrelationID,
),
grpctransport.ClientBefore(
displayClientRequestHeaders,
),
grpctransport.ClientAfter(
displayClientResponseHeaders,
displayClientResponseTrailers,
),
grpctransport.ClientAfter(
extractConsumedCorrelationID,
),
).Endpoint(),
}
}

View File

@@ -1,141 +0,0 @@
package test
import (
"context"
"fmt"
"google.golang.org/grpc/metadata"
)
type metaContext string
const (
correlationID metaContext = "correlation-id"
responseHDR metaContext = "my-response-header"
responseTRLR metaContext = "my-response-trailer"
correlationIDTRLR metaContext = "correlation-id-consumed"
)
/* client before functions */
func injectCorrelationID(ctx context.Context, md *metadata.MD) context.Context {
if hdr, ok := ctx.Value(correlationID).(string); ok {
fmt.Printf("\tClient found correlationID %q in context, set metadata header\n", hdr)
(*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr)
}
return ctx
}
func displayClientRequestHeaders(ctx context.Context, md *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tClient >> Request Headers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
/* server before functions */
func extractCorrelationID(ctx context.Context, md metadata.MD) context.Context {
if hdr, ok := md[string(correlationID)]; ok {
cID := hdr[len(hdr)-1]
ctx = context.WithValue(ctx, correlationID, cID)
fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID)
}
return ctx
}
func displayServerRequestHeaders(ctx context.Context, md metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tServer << Request Headers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
/* server after functions */
func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
*md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value"))
return ctx
}
func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tServer >> Response Headers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
*md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too"))
return ctx
}
func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
if hdr, ok := ctx.Value(correlationID).(string); ok {
fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr)
*md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr))
}
return ctx
}
func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tServer >> Response Trailers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
/* client after functions */
func displayClientResponseHeaders(ctx context.Context, md metadata.MD, _ metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tClient << Response Headers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
func displayClientResponseTrailers(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tClient << Response Trailers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}
func extractConsumedCorrelationID(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context {
if hdr, ok := md[string(correlationIDTRLR)]; ok {
fmt.Printf("\tClient received consumed correlationID %q in metadata trailer, set context\n", hdr[len(hdr)-1])
ctx = context.WithValue(ctx, correlationIDTRLR, hdr[len(hdr)-1])
}
return ctx
}
/* CorrelationID context handlers */
func SetCorrelationID(ctx context.Context, v string) context.Context {
return context.WithValue(ctx, correlationID, v)
}
func GetConsumedCorrelationID(ctx context.Context) string {
if trlr, ok := ctx.Value(correlationIDTRLR).(string); ok {
return trlr
}
return ""
}

View File

@@ -1,3 +0,0 @@
package pb
//go:generate protoc test.proto --go_out=plugins=grpc:.

View File

@@ -1,167 +0,0 @@
// Code generated by protoc-gen-go.
// source: test.proto
// DO NOT EDIT!
/*
Package pb is a generated protocol buffer package.
It is generated from these files:
test.proto
It has these top-level messages:
TestRequest
TestResponse
*/
package pb
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type TestRequest struct {
A string `protobuf:"bytes,1,opt,name=a" json:"a,omitempty"`
B int64 `protobuf:"varint,2,opt,name=b" json:"b,omitempty"`
}
func (m *TestRequest) Reset() { *m = TestRequest{} }
func (m *TestRequest) String() string { return proto.CompactTextString(m) }
func (*TestRequest) ProtoMessage() {}
func (*TestRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *TestRequest) GetA() string {
if m != nil {
return m.A
}
return ""
}
func (m *TestRequest) GetB() int64 {
if m != nil {
return m.B
}
return 0
}
type TestResponse struct {
V string `protobuf:"bytes,1,opt,name=v" json:"v,omitempty"`
}
func (m *TestResponse) Reset() { *m = TestResponse{} }
func (m *TestResponse) String() string { return proto.CompactTextString(m) }
func (*TestResponse) ProtoMessage() {}
func (*TestResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *TestResponse) GetV() string {
if m != nil {
return m.V
}
return ""
}
func init() {
proto.RegisterType((*TestRequest)(nil), "pb.TestRequest")
proto.RegisterType((*TestResponse)(nil), "pb.TestResponse")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for Test service
type TestClient interface {
Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error)
}
type testClient struct {
cc *grpc.ClientConn
}
func NewTestClient(cc *grpc.ClientConn) TestClient {
return &testClient{cc}
}
func (c *testClient) Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) {
out := new(TestResponse)
err := grpc.Invoke(ctx, "/pb.Test/Test", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Test service
type TestServer interface {
Test(context.Context, *TestRequest) (*TestResponse, error)
}
func RegisterTestServer(s *grpc.Server, srv TestServer) {
s.RegisterService(&_Test_serviceDesc, srv)
}
func _Test_Test_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(TestRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(TestServer).Test(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/pb.Test/Test",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(TestServer).Test(ctx, req.(*TestRequest))
}
return interceptor(ctx, in, info, handler)
}
var _Test_serviceDesc = grpc.ServiceDesc{
ServiceName: "pb.Test",
HandlerType: (*TestServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Test",
Handler: _Test_Test_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "test.proto",
}
func init() { proto.RegisterFile("test.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 129 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e,
0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xd2, 0xe4, 0xe2, 0x0e, 0x49,
0x2d, 0x2e, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0xe2, 0xe1, 0x62, 0x4c, 0x94, 0x60,
0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x62, 0x4c, 0x04, 0xf1, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x98,
0x83, 0x18, 0x93, 0x94, 0x64, 0xb8, 0x78, 0x20, 0x4a, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x41,
0xb2, 0x65, 0x30, 0xb5, 0x65, 0x46, 0xc6, 0x5c, 0x2c, 0x20, 0x59, 0x21, 0x6d, 0x28, 0xcd, 0xaf,
0x57, 0x90, 0xa4, 0x87, 0x64, 0xb4, 0x94, 0x00, 0x42, 0x00, 0x62, 0x80, 0x12, 0x43, 0x12, 0x1b,
0xd8, 0x21, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x49, 0xfc, 0xd8, 0xf1, 0x96, 0x00, 0x00,
0x00,
}

View File

@@ -1,16 +0,0 @@
syntax = "proto3";
package pb;
service Test {
rpc Test (TestRequest) returns (TestResponse) {}
}
message TestRequest {
string a = 1;
int64 b = 2;
}
message TestResponse {
string v = 1;
}

View File

@@ -1,27 +0,0 @@
package test
import (
"context"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)
func encodeRequest(ctx context.Context, req interface{}) (interface{}, error) {
r := req.(TestRequest)
return &pb.TestRequest{A: r.A, B: r.B}, nil
}
func decodeRequest(ctx context.Context, req interface{}) (interface{}, error) {
r := req.(*pb.TestRequest)
return TestRequest{A: r.A, B: r.B}, nil
}
func encodeResponse(ctx context.Context, resp interface{}) (interface{}, error) {
r := resp.(*TestResponse)
return &pb.TestResponse{V: r.V}, nil
}
func decodeResponse(ctx context.Context, resp interface{}) (interface{}, error) {
r := resp.(*pb.TestResponse)
return &TestResponse{V: r.V, Ctx: ctx}, nil
}

View File

@@ -1,70 +0,0 @@
package test
import (
"context"
"fmt"
oldcontext "golang.org/x/net/context"
"github.com/go-kit/kit/endpoint"
grpctransport "github.com/go-kit/kit/transport/grpc"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)
type service struct{}
func (service) Test(ctx context.Context, a string, b int64) (context.Context, string, error) {
return nil, fmt.Sprintf("%s = %d", a, b), nil
}
func NewService() Service {
return service{}
}
func makeTestEndpoint(svc Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(TestRequest)
newCtx, v, err := svc.Test(ctx, req.A, req.B)
return &TestResponse{
V: v,
Ctx: newCtx,
}, err
}
}
type serverBinding struct {
test grpctransport.Handler
}
func (b *serverBinding) Test(ctx oldcontext.Context, req *pb.TestRequest) (*pb.TestResponse, error) {
_, response, err := b.test.ServeGRPC(ctx, req)
if err != nil {
return nil, err
}
return response.(*pb.TestResponse), nil
}
func NewBinding(svc Service) *serverBinding {
return &serverBinding{
test: grpctransport.NewServer(
makeTestEndpoint(svc),
decodeRequest,
encodeResponse,
grpctransport.ServerBefore(
extractCorrelationID,
),
grpctransport.ServerBefore(
displayServerRequestHeaders,
),
grpctransport.ServerAfter(
injectResponseHeader,
injectResponseTrailer,
injectConsumedCorrelationID,
),
grpctransport.ServerAfter(
displayServerResponseHeaders,
displayServerResponseTrailers,
),
),
}
}

View File

@@ -1,17 +0,0 @@
package test
import "context"
type Service interface {
Test(ctx context.Context, a string, b int64) (context.Context, string, error)
}
type TestRequest struct {
A string
B int64
}
type TestResponse struct {
Ctx context.Context
V string
}

View File

@@ -1,11 +1,11 @@
package grpc
import (
"context"
"fmt"
"reflect"
"strings"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
@@ -21,8 +21,7 @@ type Client struct {
enc EncodeRequestFunc
dec DecodeResponseFunc
grpcReply reflect.Type
before []ClientRequestFunc
after []ClientResponseFunc
before []RequestFunc
}
// NewClient constructs a usable Client for a single remote endpoint.
@@ -54,8 +53,7 @@ func NewClient(
reflect.ValueOf(grpcReply),
).Interface(),
),
before: []ClientRequestFunc{},
after: []ClientResponseFunc{},
before: []RequestFunc{},
}
for _, option := range options {
option(c)
@@ -68,15 +66,8 @@ type ClientOption func(*Client)
// ClientBefore sets the RequestFuncs that are applied to the outgoing gRPC
// request before it's invoked.
func ClientBefore(before ...ClientRequestFunc) ClientOption {
return func(c *Client) { c.before = append(c.before, before...) }
}
// ClientAfter sets the ClientResponseFuncs that are applied to the incoming
// gRPC response prior to it being decoded. This is useful for obtaining
// response metadata and adding onto the context prior to decoding.
func ClientAfter(after ...ClientResponseFunc) ClientOption {
return func(c *Client) { c.after = append(c.after, after...) }
func ClientBefore(before ...RequestFunc) ClientOption {
return func(c *Client) { c.before = before }
}
// Endpoint returns a usable endpoint that will invoke the gRPC specified by the
@@ -88,7 +79,7 @@ func (c Client) Endpoint() endpoint.Endpoint {
req, err := c.enc(ctx, request)
if err != nil {
return nil, err
return nil, fmt.Errorf("Encode: %v", err)
}
md := &metadata.MD{}
@@ -97,22 +88,14 @@ func (c Client) Endpoint() endpoint.Endpoint {
}
ctx = metadata.NewContext(ctx, *md)
var header, trailer metadata.MD
grpcReply := reflect.New(c.grpcReply).Interface()
if err = grpc.Invoke(
ctx, c.method, req, grpcReply, c.client,
grpc.Header(&header), grpc.Trailer(&trailer),
); err != nil {
return nil, err
}
for _, f := range c.after {
ctx = f(ctx, header, trailer)
if err = grpc.Invoke(ctx, c.method, req, grpcReply, c.client); err != nil {
return nil, fmt.Errorf("Invoke: %v", err)
}
response, err := c.dec(ctx, grpcReply)
if err != nil {
return nil, err
return nil, fmt.Errorf("Decode: %v", err)
}
return response, nil
}

View File

@@ -1,59 +0,0 @@
package grpc_test
import (
"context"
"fmt"
"net"
"testing"
"google.golang.org/grpc"
test "github.com/go-kit/kit/transport/grpc/_grpc_test"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)
const (
hostPort string = "localhost:8002"
)
func TestGRPCClient(t *testing.T) {
var (
server = grpc.NewServer()
service = test.NewService()
)
sc, err := net.Listen("tcp", hostPort)
if err != nil {
t.Fatalf("unable to listen: %+v", err)
}
defer server.GracefulStop()
go func() {
pb.RegisterTestServer(server, test.NewBinding(service))
_ = server.Serve(sc)
}()
cc, err := grpc.Dial(hostPort, grpc.WithInsecure())
if err != nil {
t.Fatalf("unable to Dial: %+v", err)
}
client := test.NewClient(cc)
var (
a = "the answer to life the universe and everything"
b = int64(42)
cID = "request-1"
ctx = test.SetCorrelationID(context.Background(), cID)
)
responseCTX, v, err := client.Test(ctx, a, b)
if want, have := fmt.Sprintf("%s = %d", a, b), v; want != have {
t.Fatalf("want %q, have %q", want, have)
}
if want, have := cID, test.GetConsumedCorrelationID(responseCTX); want != have {
t.Fatalf("want %q, have %q", want, have)
}
}

View File

@@ -1,25 +1,23 @@
package grpc
import (
"context"
)
import "golang.org/x/net/context"
// DecodeRequestFunc extracts a user-domain request object from a gRPC request.
// It's designed to be used in gRPC servers, for server-side endpoints. One
// straightforward DecodeRequestFunc could be something that decodes from the
// gRPC request message to the concrete request type.
// straightforward DecodeRequestFunc could be something that
// decodes from the gRPC request message to the concrete request type.
type DecodeRequestFunc func(context.Context, interface{}) (request interface{}, err error)
// EncodeRequestFunc encodes the passed request object into the gRPC request
// object. It's designed to be used in gRPC clients, for client-side endpoints.
// One straightforward EncodeRequestFunc could something that encodes the object
// directly to the gRPC request message.
// object. It's designed to be used in gRPC clients, for client-side
// endpoints. One straightforward EncodeRequestFunc could something that
// encodes the object directly to the gRPC request message.
type EncodeRequestFunc func(context.Context, interface{}) (request interface{}, err error)
// EncodeResponseFunc encodes the passed response object to the gRPC response
// message. It's designed to be used in gRPC servers, for server-side endpoints.
// One straightforward EncodeResponseFunc could be something that encodes the
// object directly to the gRPC response message.
// message. It's designed to be used in gRPC servers, for server-side
// endpoints. One straightforward EncodeResponseFunc could be something that
// encodes the object directly to the gRPC response message.
type EncodeResponseFunc func(context.Context, interface{}) (response interface{}, err error)
// DecodeResponseFunc extracts a user-domain response object from a gRPC

View File

@@ -1,10 +1,10 @@
package grpc
import (
"context"
"encoding/base64"
"strings"
"golang.org/x/net/context"
"google.golang.org/grpc/metadata"
)
@@ -12,53 +12,30 @@ const (
binHdrSuffix = "-bin"
)
// ClientRequestFunc may take information from context and use it to construct
// metadata headers to be transported to the server. ClientRequestFuncs are
// executed after creating the request but prior to sending the gRPC request to
// the server.
type ClientRequestFunc func(context.Context, *metadata.MD) context.Context
// RequestFunc may take information from an gRPC request and put it into a
// request context. In Servers, BeforeFuncs are executed prior to invoking the
// endpoint. In Clients, BeforeFuncs are executed after creating the request
// but prior to invoking the gRPC client.
type RequestFunc func(context.Context, *metadata.MD) context.Context
// ServerRequestFunc may take information from the received metadata header and
// use it to place items in the request scoped context. ServerRequestFuncs are
// executed prior to invoking the endpoint.
type ServerRequestFunc func(context.Context, metadata.MD) context.Context
// ServerResponseFunc may take information from a request context and use it to
// manipulate the gRPC response metadata headers and trailers. ResponseFuncs are
// only executed in servers, after invoking the endpoint but prior to writing a
// response.
type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context
// ClientResponseFunc may take information from a gRPC metadata header and/or
// trailer and make the responses available for consumption. ClientResponseFuncs
// are only executed in clients, after a request has been made, but prior to it
// being decoded.
type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context
// SetRequestHeader returns a ClientRequestFunc that sets the specified metadata
// key-value pair.
func SetRequestHeader(key, val string) ClientRequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
key, val := EncodeKeyValue(key, val)
(*md)[key] = append((*md)[key], val)
return ctx
}
}
// ResponseFunc may take information from a request context and use it to
// manipulate the gRPC metadata header. ResponseFuncs are only executed in
// servers, after invoking the endpoint but prior to writing a response.
type ResponseFunc func(context.Context, *metadata.MD)
// SetResponseHeader returns a ResponseFunc that sets the specified metadata
// key-value pair.
func SetResponseHeader(key, val string) ServerResponseFunc {
return func(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
func SetResponseHeader(key, val string) ResponseFunc {
return func(_ context.Context, md *metadata.MD) {
key, val := EncodeKeyValue(key, val)
(*md)[key] = append((*md)[key], val)
return ctx
}
}
// SetResponseTrailer returns a ResponseFunc that sets the specified metadata
// SetRequestHeader returns a RequestFunc that sets the specified metadata
// key-value pair.
func SetResponseTrailer(key, val string) ServerResponseFunc {
return func(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
func SetRequestHeader(key, val string) RequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
key, val := EncodeKeyValue(key, val)
(*md)[key] = append((*md)[key], val)
return ctx

View File

@@ -1,28 +1,28 @@
package grpc
import (
oldcontext "golang.org/x/net/context"
"google.golang.org/grpc"
"golang.org/x/net/context"
"google.golang.org/grpc/metadata"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
)
// Handler which should be called from the gRPC binding of the service
// Handler which should be called from the grpc binding of the service
// implementation. The incoming request parameter, and returned response
// parameter, are both gRPC types, not user-domain.
type Handler interface {
ServeGRPC(ctx oldcontext.Context, request interface{}) (oldcontext.Context, interface{}, error)
ServeGRPC(ctx context.Context, request interface{}) (context.Context, interface{}, error)
}
// Server wraps an endpoint and implements grpc.Handler.
type Server struct {
ctx context.Context
e endpoint.Endpoint
dec DecodeRequestFunc
enc EncodeResponseFunc
before []ServerRequestFunc
after []ServerResponseFunc
before []RequestFunc
after []ResponseFunc
logger log.Logger
}
@@ -32,12 +32,14 @@ type Server struct {
// definitions to individual handlers. Request and response objects are from the
// caller business domain, not gRPC request and reply types.
func NewServer(
ctx context.Context,
e endpoint.Endpoint,
dec DecodeRequestFunc,
enc EncodeResponseFunc,
options ...ServerOption,
) *Server {
s := &Server{
ctx: ctx,
e: e,
dec: dec,
enc: enc,
@@ -54,14 +56,14 @@ type ServerOption func(*Server)
// ServerBefore functions are executed on the HTTP request object before the
// request is decoded.
func ServerBefore(before ...ServerRequestFunc) ServerOption {
return func(s *Server) { s.before = append(s.before, before...) }
func ServerBefore(before ...RequestFunc) ServerOption {
return func(s *Server) { s.before = before }
}
// ServerAfter functions are executed on the HTTP response writer after the
// endpoint is invoked, but before anything is written to the client.
func ServerAfter(after ...ServerResponseFunc) ServerOption {
return func(s *Server) { s.after = append(s.after, after...) }
func ServerAfter(after ...ResponseFunc) ServerOption {
return func(s *Server) { s.after = after }
}
// ServerErrorLogger is used to log non-terminal errors. By default, no errors
@@ -71,53 +73,56 @@ func ServerErrorLogger(logger log.Logger) ServerOption {
}
// ServeGRPC implements the Handler interface.
func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) {
func (s Server) ServeGRPC(grpcCtx context.Context, req interface{}) (context.Context, interface{}, error) {
ctx := s.ctx
// Retrieve gRPC metadata.
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromContext(grpcCtx)
if !ok {
md = metadata.MD{}
}
for _, f := range s.before {
ctx = f(ctx, md)
ctx = f(ctx, &md)
}
request, err := s.dec(ctx, req)
// Store potentially updated metadata in the gRPC context.
grpcCtx = metadata.NewContext(grpcCtx, md)
request, err := s.dec(grpcCtx, req)
if err != nil {
s.logger.Log("err", err)
return ctx, nil, err
return grpcCtx, nil, BadRequestError{err}
}
response, err := s.e(ctx, request)
if err != nil {
s.logger.Log("err", err)
return ctx, nil, err
return grpcCtx, nil, err
}
var mdHeader, mdTrailer metadata.MD
for _, f := range s.after {
ctx = f(ctx, &mdHeader, &mdTrailer)
f(ctx, &md)
}
grpcResp, err := s.enc(ctx, response)
// Store potentially updated metadata in the gRPC context.
grpcCtx = metadata.NewContext(grpcCtx, md)
grpcResp, err := s.enc(grpcCtx, response)
if err != nil {
s.logger.Log("err", err)
return ctx, nil, err
return grpcCtx, nil, err
}
if len(mdHeader) > 0 {
if err = grpc.SendHeader(ctx, mdHeader); err != nil {
s.logger.Log("err", err)
return ctx, nil, err
}
}
if len(mdTrailer) > 0 {
if err = grpc.SetTrailer(ctx, mdTrailer); err != nil {
s.logger.Log("err", err)
return ctx, nil, err
}
}
return ctx, grpcResp, nil
return grpcCtx, grpcResp, nil
}
// BadRequestError is an error in decoding the request.
type BadRequestError struct {
Err error
}
// Error implements the error interface.
func (err BadRequestError) Error() string {
return err.Err.Error()
}

View File

@@ -1,14 +1,10 @@
package http
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"io/ioutil"
"net/http"
"net/url"
"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
"github.com/go-kit/kit/endpoint"
@@ -62,14 +58,14 @@ func SetClient(client *http.Client) ClientOption {
// ClientBefore sets the RequestFuncs that are applied to the outgoing HTTP
// request before it's invoked.
func ClientBefore(before ...RequestFunc) ClientOption {
return func(c *Client) { c.before = append(c.before, before...) }
return func(c *Client) { c.before = before }
}
// ClientAfter sets the ClientResponseFuncs applied to the incoming HTTP
// request prior to it being decoded. This is useful for obtaining anything off
// of the response and adding onto the context prior to decoding.
func ClientAfter(after ...ClientResponseFunc) ClientOption {
return func(c *Client) { c.after = append(c.after, after...) }
return func(c *Client) { c.after = after }
}
// BufferedStream sets whether the Response.Body is left open, allowing it
@@ -86,11 +82,11 @@ func (c Client) Endpoint() endpoint.Endpoint {
req, err := http.NewRequest(c.method, c.tgt.String(), nil)
if err != nil {
return nil, err
return nil, Error{Domain: DomainNewRequest, Err: err}
}
if err = c.enc(ctx, req, request); err != nil {
return nil, err
return nil, Error{Domain: DomainEncode, Err: err}
}
for _, f := range c.before {
@@ -99,7 +95,7 @@ func (c Client) Endpoint() endpoint.Endpoint {
resp, err := ctxhttp.Do(ctx, c.client, req)
if err != nil {
return nil, err
return nil, Error{Domain: DomainDo, Err: err}
}
if !c.bufferedStream {
defer resp.Body.Close()
@@ -111,40 +107,9 @@ func (c Client) Endpoint() endpoint.Endpoint {
response, err := c.dec(ctx, resp)
if err != nil {
return nil, err
return nil, Error{Domain: DomainDecode, Err: err}
}
return response, nil
}
}
// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a
// JSON object to the Request body. Many JSON-over-HTTP services can use it as
// a sensible default. If the request implements Headerer, the provided headers
// will be applied to the request.
func EncodeJSONRequest(c context.Context, r *http.Request, request interface{}) error {
r.Header.Set("Content-Type", "application/json; charset=utf-8")
if headerer, ok := request.(Headerer); ok {
for k := range headerer.Headers() {
r.Header.Set(k, headerer.Headers().Get(k))
}
}
var b bytes.Buffer
r.Body = ioutil.NopCloser(&b)
return json.NewEncoder(&b).Encode(request)
}
// EncodeXMLRequest is an EncodeRequestFunc that serializes the request as a
// XML object to the Request body. If the request implements Headerer,
// the provided headers will be applied to the request.
func EncodeXMLRequest(c context.Context, r *http.Request, request interface{}) error {
r.Header.Set("Content-Type", "text/xml; charset=utf-8")
if headerer, ok := request.(Headerer); ok {
for k := range headerer.Headers() {
r.Header.Set(k, headerer.Headers().Get(k))
}
}
var b bytes.Buffer
r.Body = ioutil.NopCloser(&b)
return xml.NewEncoder(&b).Encode(request)
}

View File

@@ -1,15 +1,15 @@
package http_test
import (
"context"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"golang.org/x/net/context"
httptransport "github.com/go-kit/kit/transport/http"
)
@@ -140,68 +140,6 @@ func TestHTTPClientBufferedStream(t *testing.T) {
}
}
func TestEncodeJSONRequest(t *testing.T) {
var header http.Header
var body string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
if err != nil && err != io.EOF {
t.Fatal(err)
}
header = r.Header
body = string(b)
}))
defer server.Close()
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
client := httptransport.NewClient(
"POST",
serverURL,
httptransport.EncodeJSONRequest,
func(context.Context, *http.Response) (interface{}, error) { return nil, nil },
).Endpoint()
for _, test := range []struct {
value interface{}
body string
}{
{nil, "null\n"},
{12, "12\n"},
{1.2, "1.2\n"},
{true, "true\n"},
{"test", "\"test\"\n"},
{enhancedRequest{Foo: "foo"}, "{\"foo\":\"foo\"}\n"},
} {
if _, err := client(context.Background(), test.value); err != nil {
t.Error(err)
continue
}
if body != test.body {
t.Errorf("%v: actual %#v, expected %#v", test.value, body, test.body)
}
}
if _, err := client(context.Background(), enhancedRequest{Foo: "foo"}); err != nil {
t.Fatal(err)
}
if _, ok := header["X-Edward"]; !ok {
t.Fatalf("X-Edward value: actual %v, expected %v", nil, []string{"Snowden"})
}
if v := header.Get("X-Edward"); v != "Snowden" {
t.Errorf("X-Edward string: actual %v, expected %v", v, "Snowden")
}
}
func mustParse(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
@@ -209,9 +147,3 @@ func mustParse(s string) *url.URL {
}
return u
}
type enhancedRequest struct {
Foo string `json:"foo"`
}
func (e enhancedRequest) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }

View File

@@ -1,8 +1,9 @@
package http
import (
"context"
"net/http"
"golang.org/x/net/context"
)
// DecodeRequestFunc extracts a user-domain request object from an HTTP

33
vendor/github.com/go-kit/kit/transport/http/err.go generated vendored Normal file
View File

@@ -0,0 +1,33 @@
package http
import (
"fmt"
)
const (
// DomainNewRequest is an error during request generation.
DomainNewRequest = "NewRequest"
// DomainEncode is an error during request or response encoding.
DomainEncode = "Encode"
// DomainDo is an error during the execution phase of the request.
DomainDo = "Do"
// DomainDecode is an error during request or response decoding.
DomainDecode = "Decode"
)
// Error is an error that occurred at some phase within the transport.
type Error struct {
// Domain is the phase in which the error was generated.
Domain string
// Err is the concrete error.
Err error
}
// Error implements the error interface.
func (e Error) Error() string {
return fmt.Sprintf("%s: %s", e.Domain, e.Err)
}

View File

@@ -0,0 +1,56 @@
package http_test
import (
"errors"
"fmt"
"net/http"
"net/url"
"testing"
"golang.org/x/net/context"
httptransport "github.com/go-kit/kit/transport/http"
)
func TestClientEndpointEncodeError(t *testing.T) {
var (
sampleErr = errors.New("Oh no, an error")
enc = func(context.Context, *http.Request, interface{}) error { return sampleErr }
dec = func(context.Context, *http.Response) (interface{}, error) { return nil, nil }
)
u := &url.URL{
Scheme: "https",
Host: "localhost",
Path: "/does/not/matter",
}
c := httptransport.NewClient(
"GET",
u,
enc,
dec,
)
_, err := c.Endpoint()(context.Background(), nil)
if err == nil {
t.Fatal("err == nil")
}
e, ok := err.(httptransport.Error)
if !ok {
t.Fatal("err is not of type github.com/go-kit/kit/transport/http.Error")
}
if want, have := sampleErr, e.Err; want != have {
t.Fatalf("want %v, have %v", want, have)
}
}
func ExampleErrorOutput() {
sampleErr := errors.New("oh no, an error")
err := httptransport.Error{Domain: httptransport.DomainDo, Err: sampleErr}
fmt.Println(err)
// Output:
// Do: oh no, an error
}

View File

@@ -1,36 +0,0 @@
package http
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
)
func ExamplePopulateRequestContext() {
handler := NewServer(
func(ctx context.Context, request interface{}) (response interface{}, err error) {
fmt.Println("Method", ctx.Value(ContextKeyRequestMethod).(string))
fmt.Println("RequestPath", ctx.Value(ContextKeyRequestPath).(string))
fmt.Println("RequestURI", ctx.Value(ContextKeyRequestURI).(string))
fmt.Println("X-Request-ID", ctx.Value(ContextKeyRequestXRequestID).(string))
return struct{}{}, nil
},
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(context.Context, http.ResponseWriter, interface{}) error { return nil },
ServerBefore(PopulateRequestContext),
)
server := httptest.NewServer(handler)
defer server.Close()
req, _ := http.NewRequest("PATCH", fmt.Sprintf("%s/search?q=sympatico", server.URL), nil)
req.Header.Set("X-Request-Id", "a1b2c3d4e5")
http.DefaultClient.Do(req)
// Output:
// Method PATCH
// RequestPath /search
// RequestURI /search?q=sympatico
// X-Request-ID a1b2c3d4e5
}

View File

@@ -1,8 +1,9 @@
package http
import (
"context"
"net/http"
"golang.org/x/net/context"
)
// RequestFunc may take information from an HTTP request and put it into a
@@ -21,13 +22,13 @@ type ServerResponseFunc func(context.Context, http.ResponseWriter) context.Conte
// clients, after a request has been made, but prior to it being decoded.
type ClientResponseFunc func(context.Context, *http.Response) context.Context
// SetContentType returns a ServerResponseFunc that sets the Content-Type header
// to the provided value.
// SetContentType returns a ResponseFunc that sets the Content-Type header to
// the provided value.
func SetContentType(contentType string) ServerResponseFunc {
return SetResponseHeader("Content-Type", contentType)
}
// SetResponseHeader returns a ServerResponseFunc that sets the given header.
// SetResponseHeader returns a ResponseFunc that sets the specified header.
func SetResponseHeader(key, val string) ServerResponseFunc {
return func(ctx context.Context, w http.ResponseWriter) context.Context {
w.Header().Set(key, val)
@@ -35,94 +36,10 @@ func SetResponseHeader(key, val string) ServerResponseFunc {
}
}
// SetRequestHeader returns a RequestFunc that sets the given header.
// SetRequestHeader returns a RequestFunc that sets the specified header.
func SetRequestHeader(key, val string) RequestFunc {
return func(ctx context.Context, r *http.Request) context.Context {
r.Header.Set(key, val)
return ctx
}
}
// PopulateRequestContext is a RequestFunc that populates several values into
// the context from the HTTP request. Those values may be extracted using the
// corresponding ContextKey type in this package.
func PopulateRequestContext(ctx context.Context, r *http.Request) context.Context {
for k, v := range map[contextKey]string{
ContextKeyRequestMethod: r.Method,
ContextKeyRequestURI: r.RequestURI,
ContextKeyRequestPath: r.URL.Path,
ContextKeyRequestProto: r.Proto,
ContextKeyRequestHost: r.Host,
ContextKeyRequestRemoteAddr: r.RemoteAddr,
ContextKeyRequestXForwardedFor: r.Header.Get("X-Forwarded-For"),
ContextKeyRequestXForwardedProto: r.Header.Get("X-Forwarded-Proto"),
ContextKeyRequestAuthorization: r.Header.Get("Authorization"),
ContextKeyRequestReferer: r.Header.Get("Referer"),
ContextKeyRequestUserAgent: r.Header.Get("User-Agent"),
ContextKeyRequestXRequestID: r.Header.Get("X-Request-Id"),
} {
ctx = context.WithValue(ctx, k, v)
}
return ctx
}
type contextKey int
const (
// ContextKeyRequestMethod is populated in the context by
// PopulateRequestContext. Its value is r.Method.
ContextKeyRequestMethod contextKey = iota
// ContextKeyRequestURI is populated in the context by
// PopulateRequestContext. Its value is r.RequestURI.
ContextKeyRequestURI
// ContextKeyRequestPath is populated in the context by
// PopulateRequestContext. Its value is r.URL.Path.
ContextKeyRequestPath
// ContextKeyRequestProto is populated in the context by
// PopulateRequestContext. Its value is r.Proto.
ContextKeyRequestProto
// ContextKeyRequestHost is populated in the context by
// PopulateRequestContext. Its value is r.Host.
ContextKeyRequestHost
// ContextKeyRequestRemoteAddr is populated in the context by
// PopulateRequestContext. Its value is r.RemoteAddr.
ContextKeyRequestRemoteAddr
// ContextKeyRequestXForwardedFor is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-For").
ContextKeyRequestXForwardedFor
// ContextKeyRequestXForwardedProto is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-Proto").
ContextKeyRequestXForwardedProto
// ContextKeyRequestAuthorization is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("Authorization").
ContextKeyRequestAuthorization
// ContextKeyRequestReferer is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("Referer").
ContextKeyRequestReferer
// ContextKeyRequestUserAgent is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("User-Agent").
ContextKeyRequestUserAgent
// ContextKeyRequestXRequestID is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("X-Request-Id").
ContextKeyRequestXRequestID
// ContextKeyResponseHeaders is populated in the context whenever a
// ServerFinalizerFunc is specified. Its value is of type http.Header, and
// is captured only once the entire response has been written.
ContextKeyResponseHeaders
// ContextKeyResponseSize is populated in the context whenever a
// ServerFinalizerFunc is specified. Its value is of type int64.
ContextKeyResponseSize
)

View File

@@ -1,10 +1,11 @@
package http_test
import (
"context"
"net/http/httptest"
"testing"
"golang.org/x/net/context"
httptransport "github.com/go-kit/kit/transport/http"
)

View File

@@ -1,39 +1,41 @@
package http
import (
"context"
"encoding/json"
"net/http"
"golang.org/x/net/context"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
)
// Server wraps an endpoint and implements http.Handler.
type Server struct {
ctx context.Context
e endpoint.Endpoint
dec DecodeRequestFunc
enc EncodeResponseFunc
before []RequestFunc
after []ServerResponseFunc
errorEncoder ErrorEncoder
finalizer ServerFinalizerFunc
logger log.Logger
}
// NewServer constructs a new server, which implements http.Server and wraps
// the provided endpoint.
func NewServer(
ctx context.Context,
e endpoint.Endpoint,
dec DecodeRequestFunc,
enc EncodeResponseFunc,
options ...ServerOption,
) *Server {
s := &Server{
ctx: ctx,
e: e,
dec: dec,
enc: enc,
errorEncoder: DefaultErrorEncoder,
errorEncoder: defaultErrorEncoder,
logger: log.NewNopLogger(),
}
for _, option := range options {
@@ -48,51 +50,33 @@ type ServerOption func(*Server)
// ServerBefore functions are executed on the HTTP request object before the
// request is decoded.
func ServerBefore(before ...RequestFunc) ServerOption {
return func(s *Server) { s.before = append(s.before, before...) }
return func(s *Server) { s.before = before }
}
// ServerAfter functions are executed on the HTTP response writer after the
// endpoint is invoked, but before anything is written to the client.
func ServerAfter(after ...ServerResponseFunc) ServerOption {
return func(s *Server) { s.after = append(s.after, after...) }
return func(s *Server) { s.after = after }
}
// ServerErrorEncoder is used to encode errors to the http.ResponseWriter
// whenever they're encountered in the processing of a request. Clients can
// use this to provide custom error formatting and response codes. By default,
// errors will be written with the DefaultErrorEncoder.
// errors will be written as plain text with an appropriate, if generic,
// status code.
func ServerErrorEncoder(ee ErrorEncoder) ServerOption {
return func(s *Server) { s.errorEncoder = ee }
}
// ServerErrorLogger is used to log non-terminal errors. By default, no errors
// are logged. This is intended as a diagnostic measure. Finer-grained control
// of error handling, including logging in more detail, should be performed in a
// custom ServerErrorEncoder or ServerFinalizer, both of which have access to
// the context.
// are logged.
func ServerErrorLogger(logger log.Logger) ServerOption {
return func(s *Server) { s.logger = logger }
}
// ServerFinalizer is executed at the end of every HTTP request.
// By default, no finalizer is registered.
func ServerFinalizer(f ServerFinalizerFunc) ServerOption {
return func(s *Server) { s.finalizer = f }
}
// ServeHTTP implements http.Handler.
func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.finalizer != nil {
iw := &interceptingWriter{w, http.StatusOK, 0}
defer func() {
ctx = context.WithValue(ctx, ContextKeyResponseHeaders, iw.Header())
ctx = context.WithValue(ctx, ContextKeyResponseSize, iw.written)
s.finalizer(ctx, iw.code, r)
}()
w = iw
}
ctx := s.ctx
for _, f := range s.before {
ctx = f(ctx, r)
@@ -101,14 +85,14 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
request, err := s.dec(ctx, r)
if err != nil {
s.logger.Log("err", err)
s.errorEncoder(ctx, err, w)
s.errorEncoder(ctx, Error{Domain: DomainDecode, Err: err}, w)
return
}
response, err := s.e(ctx, request)
if err != nil {
s.logger.Log("err", err)
s.errorEncoder(ctx, err, w)
s.errorEncoder(ctx, Error{Domain: DomainDo, Err: err}, w)
return
}
@@ -118,104 +102,32 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := s.enc(ctx, w, response); err != nil {
s.logger.Log("err", err)
s.errorEncoder(ctx, err, w)
s.errorEncoder(ctx, Error{Domain: DomainEncode, Err: err}, w)
return
}
}
// ErrorEncoder is responsible for encoding an error to the ResponseWriter.
// Users are encouraged to use custom ErrorEncoders to encode HTTP errors to
// their clients, and will likely want to pass and check for their own error
// types. See the example shipping/handling service.
//
// In the server implementation, only kit/transport/http.Error values are ever
// passed to this function, so you might be tempted to have this function take
// one of those directly. But, users are encouraged to use custom ErrorEncoders
// to encode all HTTP errors to their clients, and so may want to pass and check
// for their own error types. See the example shipping/handling service.
type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter)
// ServerFinalizerFunc can be used to perform work at the end of an HTTP
// request, after the response has been written to the client. The principal
// intended use is for request logging. In addition to the response code
// provided in the function signature, additional response parameters are
// provided in the context under keys with the ContextKeyResponse prefix.
type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request)
// EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a
// JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as
// a sensible default. If the response implements Headerer, the provided headers
// will be applied to the response. If the response implements StatusCoder, the
// provided StatusCode will be used instead of 200.
func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if headerer, ok := response.(Headerer); ok {
for k := range headerer.Headers() {
w.Header().Set(k, headerer.Headers().Get(k))
func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
switch e := err.(type) {
case Error:
switch e.Domain {
case DomainDecode:
http.Error(w, err.Error(), http.StatusBadRequest)
case DomainDo:
http.Error(w, err.Error(), http.StatusServiceUnavailable) // too aggressive?
default:
http.Error(w, err.Error(), http.StatusInternalServerError)
}
default:
http.Error(w, err.Error(), http.StatusInternalServerError)
}
code := http.StatusOK
if sc, ok := response.(StatusCoder); ok {
code = sc.StatusCode()
}
w.WriteHeader(code)
if code == http.StatusNoContent {
return nil
}
return json.NewEncoder(w).Encode(response)
}
// DefaultErrorEncoder writes the error to the ResponseWriter, by default a
// content type of text/plain, a body of the plain text of the error, and a
// status code of 500. If the error implements Headerer, the provided headers
// will be applied to the response. If the error implements json.Marshaler, and
// the marshaling succeeds, a content type of application/json and the JSON
// encoded form of the error will be used. If the error implements StatusCoder,
// the provided StatusCode will be used instead of 500.
func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
contentType, body := "text/plain; charset=utf-8", []byte(err.Error())
if marshaler, ok := err.(json.Marshaler); ok {
if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil {
contentType, body = "application/json; charset=utf-8", jsonBody
}
}
w.Header().Set("Content-Type", contentType)
if headerer, ok := err.(Headerer); ok {
for k := range headerer.Headers() {
w.Header().Set(k, headerer.Headers().Get(k))
}
}
code := http.StatusInternalServerError
if sc, ok := err.(StatusCoder); ok {
code = sc.StatusCode()
}
w.WriteHeader(code)
w.Write(body)
}
// StatusCoder is checked by DefaultErrorEncoder. If an error value implements
// StatusCoder, the StatusCode will be used when encoding the error. By default,
// StatusInternalServerError (500) is used.
type StatusCoder interface {
StatusCode() int
}
// Headerer is checked by DefaultErrorEncoder. If an error value implements
// Headerer, the provided headers will be applied to the response writer, after
// the Content-Type is set.
type Headerer interface {
Headers() http.Header
}
type interceptingWriter struct {
http.ResponseWriter
code int
written int64
}
// WriteHeader may not be explicitly called, so care must be taken to
// initialize w.code to its default value of http.StatusOK.
func (w *interceptingWriter) WriteHeader(code int) {
w.code = code
w.ResponseWriter.WriteHeader(code)
}
func (w *interceptingWriter) Write(p []byte) (int, error) {
n, err := w.ResponseWriter.Write(p)
w.written += int64(n)
return n, err
}

View File

@@ -1,21 +1,20 @@
package http_test
import (
"context"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-kit/kit/endpoint"
"golang.org/x/net/context"
httptransport "github.com/go-kit/kit/transport/http"
)
func TestServerBadDecode(t *testing.T) {
handler := httptransport.NewServer(
context.Background(),
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") },
func(context.Context, http.ResponseWriter, interface{}) error { return nil },
@@ -23,13 +22,14 @@ func TestServerBadDecode(t *testing.T) {
server := httptest.NewServer(handler)
defer server.Close()
resp, _ := http.Get(server.URL)
if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
if want, have := http.StatusBadRequest, resp.StatusCode; want != have {
t.Errorf("want %d, have %d", want, have)
}
}
func TestServerBadEndpoint(t *testing.T) {
handler := httptransport.NewServer(
context.Background(),
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(context.Context, http.ResponseWriter, interface{}) error { return nil },
@@ -37,13 +37,14 @@ func TestServerBadEndpoint(t *testing.T) {
server := httptest.NewServer(handler)
defer server.Close()
resp, _ := http.Get(server.URL)
if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
if want, have := http.StatusServiceUnavailable, resp.StatusCode; want != have {
t.Errorf("want %d, have %d", want, have)
}
}
func TestServerBadEncode(t *testing.T) {
handler := httptransport.NewServer(
context.Background(),
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dang") },
@@ -59,12 +60,13 @@ func TestServerBadEncode(t *testing.T) {
func TestServerErrorEncoder(t *testing.T) {
errTeapot := errors.New("teapot")
code := func(err error) int {
if err == errTeapot {
if e, ok := err.(httptransport.Error); ok && e.Err == errTeapot {
return http.StatusTeapot
}
return http.StatusInternalServerError
}
handler := httptransport.NewServer(
context.Background(),
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(context.Context, http.ResponseWriter, interface{}) error { return nil },
@@ -79,7 +81,7 @@ func TestServerErrorEncoder(t *testing.T) {
}
func TestServerHappyPath(t *testing.T) {
step, response := testServer(t)
_, step, response := testServer(t)
step()
resp := <-response
defer resp.Body.Close()
@@ -89,245 +91,14 @@ func TestServerHappyPath(t *testing.T) {
}
}
func TestMultipleServerBefore(t *testing.T) {
func testServer(t *testing.T) (cancel, step func(), resp <-chan *http.Response) {
var (
headerKey = "X-Henlo-Lizer"
headerVal = "Helllo you stinky lizard"
statusCode = http.StatusTeapot
responseBody = "go eat a fly ugly\n"
done = make(chan struct{})
)
handler := httptransport.NewServer(
endpoint.Nop,
func(context.Context, *http.Request) (interface{}, error) {
return struct{}{}, nil
},
func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
w.Header().Set(headerKey, headerVal)
w.WriteHeader(statusCode)
w.Write([]byte(responseBody))
return nil
},
httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
ctx = context.WithValue(ctx, "one", 1)
return ctx
}),
httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
if _, ok := ctx.Value("one").(int); !ok {
t.Error("Value was not set properly when multiple ServerBefores are used")
}
close(done)
return ctx
}),
)
server := httptest.NewServer(handler)
defer server.Close()
go http.Get(server.URL)
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timeout waiting for finalizer")
}
}
func TestMultipleServerAfter(t *testing.T) {
var (
headerKey = "X-Henlo-Lizer"
headerVal = "Helllo you stinky lizard"
statusCode = http.StatusTeapot
responseBody = "go eat a fly ugly\n"
done = make(chan struct{})
)
handler := httptransport.NewServer(
endpoint.Nop,
func(context.Context, *http.Request) (interface{}, error) {
return struct{}{}, nil
},
func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
w.Header().Set(headerKey, headerVal)
w.WriteHeader(statusCode)
w.Write([]byte(responseBody))
return nil
},
httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context {
ctx = context.WithValue(ctx, "one", 1)
return ctx
}),
httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context {
if _, ok := ctx.Value("one").(int); !ok {
t.Error("Value was not set properly when multiple ServerAfters are used")
}
close(done)
return ctx
}),
)
server := httptest.NewServer(handler)
defer server.Close()
go http.Get(server.URL)
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timeout waiting for finalizer")
}
}
func TestServerFinalizer(t *testing.T) {
var (
headerKey = "X-Henlo-Lizer"
headerVal = "Helllo you stinky lizard"
statusCode = http.StatusTeapot
responseBody = "go eat a fly ugly\n"
done = make(chan struct{})
)
handler := httptransport.NewServer(
endpoint.Nop,
func(context.Context, *http.Request) (interface{}, error) {
return struct{}{}, nil
},
func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
w.Header().Set(headerKey, headerVal)
w.WriteHeader(statusCode)
w.Write([]byte(responseBody))
return nil
},
httptransport.ServerFinalizer(func(ctx context.Context, code int, _ *http.Request) {
if want, have := statusCode, code; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header)
if want, have := headerVal, responseHeader.Get(headerKey); want != have {
t.Errorf("%s: want %q, have %q", headerKey, want, have)
}
responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64)
if want, have := int64(len(responseBody)), responseSize; want != have {
t.Errorf("response size: want %d, have %d", want, have)
}
close(done)
}),
)
server := httptest.NewServer(handler)
defer server.Close()
go http.Get(server.URL)
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timeout waiting for finalizer")
}
}
type enhancedResponse struct {
Foo string `json:"foo"`
}
func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired }
func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }
func TestEncodeJSONResponse(t *testing.T) {
handler := httptransport.NewServer(
func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
httptransport.EncodeJSONResponse,
)
server := httptest.NewServer(handler)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have {
t.Errorf("X-Edward: want %q, have %q", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have {
t.Errorf("Body: want %s, have %s", want, have)
}
}
type noContentResponse struct{}
func (e noContentResponse) StatusCode() int { return http.StatusNoContent }
func TestEncodeNoContent(t *testing.T) {
handler := httptransport.NewServer(
func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
httptransport.EncodeJSONResponse,
)
server := httptest.NewServer(handler)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if want, have := http.StatusNoContent, resp.StatusCode; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := 0, len(buf); want != have {
t.Errorf("Body: want no content, have %d bytes", have)
}
}
type enhancedError struct{}
func (e enhancedError) Error() string { return "enhanced error" }
func (e enhancedError) StatusCode() int { return http.StatusTeapot }
func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil }
func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} }
func TestEnhancedError(t *testing.T) {
handler := httptransport.NewServer(
func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil },
)
server := httptest.NewServer(handler)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if want, have := http.StatusTeapot, resp.StatusCode; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
if want, have := "1", resp.Header.Get("X-Enhanced"); want != have {
t.Errorf("X-Enhanced: want %q, have %q", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have {
t.Errorf("Body: want %s, have %s", want, have)
}
}
func testServer(t *testing.T) (step func(), resp <-chan *http.Response) {
var (
stepch = make(chan bool)
endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil }
response = make(chan *http.Response)
handler = httptransport.NewServer(
ctx, cancelfn = context.WithCancel(context.Background())
stepch = make(chan bool)
endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil }
response = make(chan *http.Response)
handler = httptransport.NewServer(
ctx,
endpoint,
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(context.Context, http.ResponseWriter, interface{}) error { return nil },
@@ -345,5 +116,5 @@ func testServer(t *testing.T) (step func(), resp <-chan *http.Response) {
}
response <- resp
}()
return func() { stepch <- true }, response
return cancelfn, func() { stepch <- true }, response
}

View File

@@ -1,10 +1,11 @@
package httprp
import (
"context"
"net/http"
"net/http/httputil"
"net/url"
"golang.org/x/net/context"
)
// RequestFunc may take information from an HTTP request and put it into a
@@ -14,6 +15,7 @@ type RequestFunc func(context.Context, *http.Request) context.Context
// Server is a proxying request handler.
type Server struct {
ctx context.Context
proxy http.Handler
before []RequestFunc
errorEncoder func(w http.ResponseWriter, err error)
@@ -24,10 +26,12 @@ type Server struct {
// If the target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
func NewServer(
ctx context.Context,
baseURL *url.URL,
options ...ServerOption,
) *Server {
s := &Server{
ctx: ctx,
proxy: httputil.NewSingleHostReverseProxy(baseURL),
}
for _, option := range options {
@@ -42,12 +46,12 @@ type ServerOption func(*Server)
// ServerBefore functions are executed on the HTTP request object before the
// request is decoded.
func ServerBefore(before ...RequestFunc) ServerOption {
return func(s *Server) { s.before = append(s.before, before...) }
return func(s *Server) { s.before = before }
}
// ServeHTTP implements http.Handler.
func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx := s.ctx
for _, f := range s.before {
ctx = f(ctx, r)

View File

@@ -1,13 +1,14 @@
package httprp_test
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"golang.org/x/net/context"
httptransport "github.com/go-kit/kit/transport/httprp"
)
@@ -20,6 +21,7 @@ func TestServerHappyPathSingleServer(t *testing.T) {
originURL, _ := url.Parse(originServer.URL)
handler := httptransport.NewServer(
context.Background(),
originURL,
)
proxyServer := httptest.NewServer(handler)
@@ -54,6 +56,7 @@ func TestServerHappyPathSingleServerWithServerOptions(t *testing.T) {
originURL, _ := url.Parse(originServer.URL)
handler := httptransport.NewServer(
context.Background(),
originURL,
httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
r.Header.Add(headerKey, headerVal)
@@ -80,6 +83,7 @@ func TestServerOriginServerNotFoundResponse(t *testing.T) {
originURL, _ := url.Parse(originServer.URL)
handler := httptransport.NewServer(
context.Background(),
originURL,
)
proxyServer := httptest.NewServer(handler)
@@ -100,6 +104,7 @@ func TestServerOriginServerUnreachable(t *testing.T) {
originServer.Close()
handler := httptransport.NewServer(
context.Background(),
originURL,
)
proxyServer := httptest.NewServer(handler)
@@ -115,44 +120,3 @@ func TestServerOriginServerUnreachable(t *testing.T) {
t.Errorf("want %d or %d, have %d", http.StatusBadGateway, http.StatusInternalServerError, resp.StatusCode)
}
}
func TestMultipleServerBefore(t *testing.T) {
const (
headerKey = "X-TEST-HEADER"
headerVal = "go-kit-proxy"
)
originServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if want, have := headerVal, r.Header.Get(headerKey); want != have {
t.Errorf("want %q, have %q", want, have)
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("hey"))
}))
defer originServer.Close()
originURL, _ := url.Parse(originServer.URL)
handler := httptransport.NewServer(
originURL,
httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
r.Header.Add(headerKey, headerVal)
return ctx
}),
httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
return ctx
}),
)
proxyServer := httptest.NewServer(handler)
defer proxyServer.Close()
resp, _ := http.Get(proxyServer.URL)
if want, have := http.StatusOK, resp.StatusCode; want != have {
t.Errorf("want %d, have %d", want, have)
}
responseBody, _ := ioutil.ReadAll(resp.Body)
if want, have := "hey", string(responseBody); want != have {
t.Errorf("want %q, have %q", want, have)
}
}