Add pluggable codec support
This commit is contained in:
		
							
								
								
									
										75
									
								
								client/codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								client/codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | |||||||
|  | package client | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"io" | ||||||
|  | 	"net/rpc" | ||||||
|  |  | ||||||
|  | 	"github.com/youtube/vitess/go/rpcplus" | ||||||
|  | 	"github.com/youtube/vitess/go/rpcplus/jsonrpc" | ||||||
|  | 	"github.com/youtube/vitess/go/rpcplus/pbrpc" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	defaultContentType = "application/octet-stream" | ||||||
|  |  | ||||||
|  | 	defaultCodecs = map[string]codecFunc{ | ||||||
|  | 		"application/json":         jsonrpc.NewClientCodec, | ||||||
|  | 		"application/json-rpc":     jsonrpc.NewClientCodec, | ||||||
|  | 		"application/protobuf":     pbrpc.NewClientCodec, | ||||||
|  | 		"application/proto-rpc":    pbrpc.NewClientCodec, | ||||||
|  | 		"application/octet-stream": pbrpc.NewClientCodec, | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type CodecFunc func(io.ReadWriteCloser) rpc.ClientCodec | ||||||
|  |  | ||||||
|  | // only for internal use | ||||||
|  | type codecFunc func(io.ReadWriteCloser) rpcplus.ClientCodec | ||||||
|  |  | ||||||
|  | // wraps an net/rpc ClientCodec to provide an rpcplus.ClientCodec | ||||||
|  | // temporary until we strip out use of rpcplus | ||||||
|  | type rpcCodecWrap struct { | ||||||
|  | 	r rpc.ClientCodec | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cw *rpcCodecWrap) WriteRequest(r *rpcplus.Request, b interface{}) error { | ||||||
|  | 	rc := &rpc.Request{ | ||||||
|  | 		ServiceMethod: r.ServiceMethod, | ||||||
|  | 		Seq:           r.Seq, | ||||||
|  | 	} | ||||||
|  | 	err := cw.r.WriteRequest(rc, b) | ||||||
|  | 	r.ServiceMethod = rc.ServiceMethod | ||||||
|  | 	r.Seq = rc.Seq | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cw *rpcCodecWrap) ReadResponseHeader(r *rpcplus.Response) error { | ||||||
|  | 	rc := &rpc.Response{ | ||||||
|  | 		ServiceMethod: r.ServiceMethod, | ||||||
|  | 		Seq:           r.Seq, | ||||||
|  | 		Error:         r.Error, | ||||||
|  | 	} | ||||||
|  | 	err := cw.r.ReadResponseHeader(rc) | ||||||
|  | 	r.ServiceMethod = rc.ServiceMethod | ||||||
|  | 	r.Seq = rc.Seq | ||||||
|  | 	r.Error = r.Error | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cw *rpcCodecWrap) ReadResponseBody(b interface{}) error { | ||||||
|  | 	return cw.r.ReadResponseBody(b) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cw *rpcCodecWrap) Close() error { | ||||||
|  | 	return cw.r.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // wraps a CodecFunc to provide an internal codecFunc | ||||||
|  | // temporary until we strip rpcplus out | ||||||
|  | func codecWrap(cf CodecFunc) codecFunc { | ||||||
|  | 	return func(rwc io.ReadWriteCloser) rpcplus.ClientCodec { | ||||||
|  | 		return &rpcCodecWrap{ | ||||||
|  | 			r: cf(rwc), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -7,6 +7,8 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| type options struct { | type options struct { | ||||||
|  | 	contentType string | ||||||
|  | 	codecs      map[string]CodecFunc | ||||||
| 	broker      broker.Broker | 	broker      broker.Broker | ||||||
| 	registry    registry.Registry | 	registry    registry.Registry | ||||||
| 	transport   transport.Transport | 	transport   transport.Transport | ||||||
| @@ -18,6 +20,18 @@ func Broker(b broker.Broker) Option { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Codec(contentType string, cf CodecFunc) Option { | ||||||
|  | 	return func(o *options) { | ||||||
|  | 		o.codecs[contentType] = cf | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ContentType(ct string) Option { | ||||||
|  | 	return func(o *options) { | ||||||
|  | 		o.contentType = ct | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func Registry(r registry.Registry) Option { | func Registry(r registry.Registry) Option { | ||||||
| 	return func(o *options) { | 	return func(o *options) { | ||||||
| 		o.registry = r | 		o.registry = r | ||||||
|   | |||||||
| @@ -28,12 +28,18 @@ type rpcClient struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func newRpcClient(opt ...Option) Client { | func newRpcClient(opt ...Option) Client { | ||||||
| 	var opts options | 	opts := options{ | ||||||
|  | 		codecs: make(map[string]CodecFunc), | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	for _, o := range opt { | 	for _, o := range opt { | ||||||
| 		o(&opts) | 		o(&opts) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if len(opts.contentType) == 0 { | ||||||
|  | 		opts.contentType = defaultContentType | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if opts.transport == nil { | 	if opts.transport == nil { | ||||||
| 		opts.transport = transport.DefaultTransport | 		opts.transport = transport.DefaultTransport | ||||||
| 	} | 	} | ||||||
| @@ -55,6 +61,16 @@ func (t *headerRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) | |||||||
| 	return t.r.RoundTrip(r) | 	return t.r.RoundTrip(r) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (r *rpcClient) codecFunc(contentType string) (codecFunc, error) { | ||||||
|  | 	if cf, ok := r.opts.codecs[contentType]; ok { | ||||||
|  | 		return codecWrap(cf), nil | ||||||
|  | 	} | ||||||
|  | 	if cf, ok := defaultCodecs[contentType]; ok { | ||||||
|  | 		return cf, nil | ||||||
|  | 	} | ||||||
|  | 	return nil, fmt.Errorf("Unsupported Content-Type: %s", contentType) | ||||||
|  | } | ||||||
|  |  | ||||||
| func (r *rpcClient) call(ctx context.Context, address string, request Request, response interface{}) error { | func (r *rpcClient) call(ctx context.Context, address string, request Request, response interface{}) error { | ||||||
| 	msg := &transport.Message{ | 	msg := &transport.Message{ | ||||||
| 		Header: make(map[string]string), | 		Header: make(map[string]string), | ||||||
| @@ -69,13 +85,18 @@ func (r *rpcClient) call(ctx context.Context, address string, request Request, r | |||||||
|  |  | ||||||
| 	msg.Header["Content-Type"] = request.ContentType() | 	msg.Header["Content-Type"] = request.ContentType() | ||||||
|  |  | ||||||
|  | 	cf, err := r.codecFunc(request.ContentType()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errors.InternalServerError("go.micro.client", err.Error()) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	c, err := r.opts.transport.Dial(address) | 	c, err := r.opts.transport.Dial(address) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) | 		return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) | ||||||
| 	} | 	} | ||||||
| 	defer c.Close() | 	defer c.Close() | ||||||
|  |  | ||||||
| 	client := rpc.NewClientWithCodec(newRpcPlusCodec(msg, c)) | 	client := rpc.NewClientWithCodec(newRpcPlusCodec(msg, c, cf)) | ||||||
| 	err = client.Call(ctx, request.Method(), request.Request(), response) | 	err = client.Call(ctx, request.Method(), request.Request(), response) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| @@ -97,12 +118,17 @@ func (r *rpcClient) stream(ctx context.Context, address string, request Request, | |||||||
|  |  | ||||||
| 	msg.Header["Content-Type"] = request.ContentType() | 	msg.Header["Content-Type"] = request.ContentType() | ||||||
|  |  | ||||||
|  | 	cf, err := r.codecFunc(request.ContentType()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, errors.InternalServerError("go.micro.client", err.Error()) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	c, err := r.opts.transport.Dial(address, transport.WithStream()) | 	c, err := r.opts.transport.Dial(address, transport.WithStream()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) | 		return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	client := rpc.NewClientWithCodec(newRpcPlusCodec(msg, c)) | 	client := rpc.NewClientWithCodec(newRpcPlusCodec(msg, c, cf)) | ||||||
| 	call := client.StreamGo(request.Method(), request.Request(), responseChan) | 	call := client.StreamGo(request.Method(), request.Request(), responseChan) | ||||||
|  |  | ||||||
| 	return &rpcStream{ | 	return &rpcStream{ | ||||||
| @@ -195,14 +221,14 @@ func (r *rpcClient) Publish(ctx context.Context, p Publication) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (r *rpcClient) NewPublication(topic string, message interface{}) Publication { | func (r *rpcClient) NewPublication(topic string, message interface{}) Publication { | ||||||
| 	return r.NewProtoPublication(topic, message) | 	return newRpcPublication(topic, message, r.opts.contentType) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *rpcClient) NewProtoPublication(topic string, message interface{}) Publication { | func (r *rpcClient) NewProtoPublication(topic string, message interface{}) Publication { | ||||||
| 	return newRpcPublication(topic, message, "application/octet-stream") | 	return newRpcPublication(topic, message, "application/octet-stream") | ||||||
| } | } | ||||||
| func (r *rpcClient) NewRequest(service, method string, request interface{}) Request { | func (r *rpcClient) NewRequest(service, method string, request interface{}) Request { | ||||||
| 	return r.NewProtoRequest(service, method, request) | 	return newRpcRequest(service, method, request, r.opts.contentType) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *rpcClient) NewProtoRequest(service, method string, request interface{}) Request { | func (r *rpcClient) NewProtoRequest(service, method string, request interface{}) Request { | ||||||
|   | |||||||
| @@ -2,12 +2,9 @@ package client | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"fmt" |  | ||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| 	rpc "github.com/youtube/vitess/go/rpcplus" | 	rpc "github.com/youtube/vitess/go/rpcplus" | ||||||
| 	js "github.com/youtube/vitess/go/rpcplus/jsonrpc" |  | ||||||
| 	pb "github.com/youtube/vitess/go/rpcplus/pbrpc" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type rpcPlusCodec struct { | type rpcPlusCodec struct { | ||||||
| @@ -37,50 +34,33 @@ func (rwc *readWriteCloser) Close() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func newRpcPlusCodec(req *transport.Message, client transport.Client) *rpcPlusCodec { | func newRpcPlusCodec(req *transport.Message, client transport.Client, cf codecFunc) *rpcPlusCodec { | ||||||
| 	r := &rpcPlusCodec{ | 	rwc := &readWriteCloser{ | ||||||
| 		req:    req, |  | ||||||
| 		client: client, |  | ||||||
| 		buf: &readWriteCloser{ |  | ||||||
| 		wbuf: bytes.NewBuffer(nil), | 		wbuf: bytes.NewBuffer(nil), | ||||||
| 		rbuf: bytes.NewBuffer(nil), | 		rbuf: bytes.NewBuffer(nil), | ||||||
| 		}, |  | ||||||
| 	} | 	} | ||||||
|  | 	r := &rpcPlusCodec{ | ||||||
| 	switch req.Header["Content-Type"] { | 		buf:    rwc, | ||||||
| 	case "application/octet-stream": | 		client: client, | ||||||
| 		r.codec = pb.NewClientCodec(r.buf) | 		codec:  cf(rwc), | ||||||
| 	case "application/json": | 		req:    req, | ||||||
| 		r.codec = js.NewClientCodec(r.buf) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return r | 	return r | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *rpcPlusCodec) WriteRequest(req *rpc.Request, body interface{}) error { | func (c *rpcPlusCodec) WriteRequest(req *rpc.Request, body interface{}) error { | ||||||
| 	if c.codec == nil { |  | ||||||
| 		return fmt.Errorf("unsupported request type: %s", c.req.Header["Content-Type"]) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if err := c.codec.WriteRequest(req, body); err != nil { | 	if err := c.codec.WriteRequest(req, body); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c.req.Body = c.buf.wbuf.Bytes() | 	c.req.Body = c.buf.wbuf.Bytes() | ||||||
| 	return c.client.Send(c.req) | 	return c.client.Send(c.req) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *rpcPlusCodec) ReadResponseHeader(r *rpc.Response) error { | func (c *rpcPlusCodec) ReadResponseHeader(r *rpc.Response) error { | ||||||
| 	var m transport.Message | 	var m transport.Message | ||||||
|  |  | ||||||
| 	if err := c.client.Recv(&m); err != nil { | 	if err := c.client.Recv(&m); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if c.codec == nil { |  | ||||||
| 		return fmt.Errorf("%s", string(m.Body)) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	c.buf.rbuf.Reset() | 	c.buf.rbuf.Reset() | ||||||
| 	c.buf.rbuf.Write(m.Body) | 	c.buf.rbuf.Write(m.Body) | ||||||
| 	return c.codec.ReadResponseHeader(r) | 	return c.codec.ReadResponseHeader(r) | ||||||
|   | |||||||
							
								
								
									
										73
									
								
								server/codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								server/codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,73 @@ | |||||||
|  | package server | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"io" | ||||||
|  | 	"net/rpc" | ||||||
|  |  | ||||||
|  | 	"github.com/youtube/vitess/go/rpcplus" | ||||||
|  | 	"github.com/youtube/vitess/go/rpcplus/jsonrpc" | ||||||
|  | 	"github.com/youtube/vitess/go/rpcplus/pbrpc" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type CodecFunc func(io.ReadWriteCloser) rpc.ServerCodec | ||||||
|  |  | ||||||
|  | // for internal use only | ||||||
|  | type codecFunc func(io.ReadWriteCloser) rpcplus.ServerCodec | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	defaultCodecs = map[string]codecFunc{ | ||||||
|  | 		"application/json":         jsonrpc.NewServerCodec, | ||||||
|  | 		"application/json-rpc":     jsonrpc.NewServerCodec, | ||||||
|  | 		"application/protobuf":     pbrpc.NewServerCodec, | ||||||
|  | 		"application/proto-rpc":    pbrpc.NewServerCodec, | ||||||
|  | 		"application/octet-stream": pbrpc.NewServerCodec, | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // wraps an net/rpc ServerCodec to provide an rpcplus.ServerCodec | ||||||
|  | // temporary until we strip out use of rpcplus | ||||||
|  | type rpcCodecWrap struct { | ||||||
|  | 	r rpc.ServerCodec | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cw *rpcCodecWrap) ReadRequestHeader(r *rpcplus.Request) error { | ||||||
|  | 	rc := &rpc.Request{ | ||||||
|  | 		ServiceMethod: r.ServiceMethod, | ||||||
|  | 		Seq:           r.Seq, | ||||||
|  | 	} | ||||||
|  | 	err := cw.r.ReadRequestHeader(rc) | ||||||
|  | 	r.ServiceMethod = rc.ServiceMethod | ||||||
|  | 	r.Seq = rc.Seq | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cw *rpcCodecWrap) ReadRequestBody(b interface{}) error { | ||||||
|  | 	return cw.r.ReadRequestBody(b) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cw *rpcCodecWrap) WriteResponse(r *rpcplus.Response, b interface{}, l bool) error { | ||||||
|  | 	rc := &rpc.Response{ | ||||||
|  | 		ServiceMethod: r.ServiceMethod, | ||||||
|  | 		Seq:           r.Seq, | ||||||
|  | 		Error:         r.Error, | ||||||
|  | 	} | ||||||
|  | 	err := cw.r.WriteResponse(rc, b) | ||||||
|  | 	r.ServiceMethod = rc.ServiceMethod | ||||||
|  | 	r.Seq = rc.Seq | ||||||
|  | 	r.Error = r.Error | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cw *rpcCodecWrap) Close() error { | ||||||
|  | 	return cw.r.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // wraps a CodecFunc to provide an internal codecFunc | ||||||
|  | // temporary until we strip rpcplus out | ||||||
|  | func codecWrap(cf CodecFunc) codecFunc { | ||||||
|  | 	return func(rwc io.ReadWriteCloser) rpcplus.ServerCodec { | ||||||
|  | 		return &rpcCodecWrap{ | ||||||
|  | 			r: cf(rwc), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| type options struct { | type options struct { | ||||||
|  | 	codecs    map[string]CodecFunc | ||||||
| 	broker    broker.Broker | 	broker    broker.Broker | ||||||
| 	registry  registry.Registry | 	registry  registry.Registry | ||||||
| 	transport transport.Transport | 	transport transport.Transport | ||||||
| @@ -19,7 +20,9 @@ type options struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func newOptions(opt ...Option) options { | func newOptions(opt ...Option) options { | ||||||
| 	var opts options | 	opts := options{ | ||||||
|  | 		codecs: make(map[string]CodecFunc), | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	for _, o := range opt { | 	for _, o := range opt { | ||||||
| 		o(&opts) | 		o(&opts) | ||||||
| @@ -116,6 +119,12 @@ func Broker(b broker.Broker) Option { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Codec(contentType string, cf CodecFunc) Option { | ||||||
|  | 	return func(o *options) { | ||||||
|  | 		o.codecs[contentType] = cf | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func Registry(r registry.Registry) Option { | func Registry(r registry.Registry) Option { | ||||||
| 	return func(o *options) { | 	return func(o *options) { | ||||||
| 		o.registry = r | 		o.registry = r | ||||||
|   | |||||||
| @@ -2,11 +2,9 @@ package server | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| 	rpc "github.com/youtube/vitess/go/rpcplus" | 	rpc "github.com/youtube/vitess/go/rpcplus" | ||||||
| 	js "github.com/youtube/vitess/go/rpcplus/jsonrpc" |  | ||||||
| 	pb "github.com/youtube/vitess/go/rpcplus/pbrpc" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type rpcPlusCodec struct { | type rpcPlusCodec struct { | ||||||
| @@ -36,30 +34,21 @@ func (rwc *readWriteCloser) Close() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func newRpcPlusCodec(req *transport.Message, socket transport.Socket) rpc.ServerCodec { | func newRpcPlusCodec(req *transport.Message, socket transport.Socket, cf codecFunc) rpc.ServerCodec { | ||||||
| 	r := &rpcPlusCodec{ | 	rwc := &readWriteCloser{ | ||||||
| 		socket: socket, |  | ||||||
| 		req:    req, |  | ||||||
| 		buf: &readWriteCloser{ |  | ||||||
| 		rbuf: bytes.NewBuffer(req.Body), | 		rbuf: bytes.NewBuffer(req.Body), | ||||||
| 		wbuf: bytes.NewBuffer(nil), | 		wbuf: bytes.NewBuffer(nil), | ||||||
| 		}, |  | ||||||
| 	} | 	} | ||||||
|  | 	r := &rpcPlusCodec{ | ||||||
| 	switch req.Header["Content-Type"] { | 		buf:    rwc, | ||||||
| 	case "application/octet-stream": | 		codec:  cf(rwc), | ||||||
| 		r.codec = pb.NewServerCodec(r.buf) | 		req:    req, | ||||||
| 	case "application/json": | 		socket: socket, | ||||||
| 		r.codec = js.NewServerCodec(r.buf) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return r | 	return r | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *rpcPlusCodec) ReadRequestHeader(r *rpc.Request) error { | func (c *rpcPlusCodec) ReadRequestHeader(r *rpc.Request) error { | ||||||
| 	if c.codec == nil { |  | ||||||
| 		return fmt.Errorf("unsupported content type %s", c.req.Header["Content-Type"]) |  | ||||||
| 	} |  | ||||||
| 	return c.codec.ReadRequestHeader(r) | 	return c.codec.ReadRequestHeader(r) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -68,19 +57,14 @@ func (c *rpcPlusCodec) ReadRequestBody(r interface{}) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *rpcPlusCodec) WriteResponse(r *rpc.Response, body interface{}, last bool) error { | func (c *rpcPlusCodec) WriteResponse(r *rpc.Response, body interface{}, last bool) error { | ||||||
| 	if c.codec == nil { |  | ||||||
| 		return fmt.Errorf("unsupported request type: %s", c.req.Header["Content-Type"]) |  | ||||||
| 	} |  | ||||||
| 	c.buf.wbuf.Reset() | 	c.buf.wbuf.Reset() | ||||||
| 	if err := c.codec.WriteResponse(r, body, last); err != nil { | 	if err := c.codec.WriteResponse(r, body, last); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return c.socket.Send(&transport.Message{ | 	return c.socket.Send(&transport.Message{ | ||||||
| 		Header: map[string]string{"Content-Type": c.req.Header["Content-Type"]}, | 		Header: map[string]string{"Content-Type": c.req.Header["Content-Type"]}, | ||||||
| 		Body:   c.buf.wbuf.Bytes(), | 		Body:   c.buf.wbuf.Bytes(), | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *rpcPlusCodec) Close() error { | func (c *rpcPlusCodec) Close() error { | ||||||
|   | |||||||
| @@ -43,7 +43,20 @@ func (s *rpcServer) accept(sock transport.Socket) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	codec := newRpcPlusCodec(&msg, sock) | 	cf, err := s.codecFunc(msg.Header["Content-Type"]) | ||||||
|  | 	// TODO: needs better error handling | ||||||
|  | 	if err != nil { | ||||||
|  | 		sock.Send(&transport.Message{ | ||||||
|  | 			Header: map[string]string{ | ||||||
|  | 				"Content-Type": "text/plain", | ||||||
|  | 			}, | ||||||
|  | 			Body: []byte(err.Error()), | ||||||
|  | 		}) | ||||||
|  | 		sock.Close() | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	codec := newRpcPlusCodec(&msg, sock, cf) | ||||||
|  |  | ||||||
| 	// strip our headers | 	// strip our headers | ||||||
| 	hdr := make(map[string]string) | 	hdr := make(map[string]string) | ||||||
| @@ -55,11 +68,21 @@ func (s *rpcServer) accept(sock transport.Socket) { | |||||||
| 	ctx := c.WithMetadata(context.Background(), hdr) | 	ctx := c.WithMetadata(context.Background(), hdr) | ||||||
| 	// TODO: needs better error handling | 	// TODO: needs better error handling | ||||||
| 	if err := s.rpc.ServeRequestWithContext(ctx, codec); err != nil { | 	if err := s.rpc.ServeRequestWithContext(ctx, codec); err != nil { | ||||||
| 		log.Errorf("Unexpected error servinc request, closing socket: %v", err) | 		log.Errorf("Unexpected error serving request, closing socket: %v", err) | ||||||
| 		sock.Close() | 		sock.Close() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *rpcServer) codecFunc(contentType string) (codecFunc, error) { | ||||||
|  | 	if cf, ok := s.opts.codecs[contentType]; ok { | ||||||
|  | 		return codecWrap(cf), nil | ||||||
|  | 	} | ||||||
|  | 	if cf, ok := defaultCodecs[contentType]; ok { | ||||||
|  | 		return cf, nil | ||||||
|  | 	} | ||||||
|  | 	return nil, fmt.Errorf("Unsupported Content-Type: %s", contentType) | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *rpcServer) Config() options { | func (s *rpcServer) Config() options { | ||||||
| 	s.RLock() | 	s.RLock() | ||||||
| 	opts := s.opts | 	opts := s.opts | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user