From 4cb41721f15a1e8fe4eda35aaabc21b76bc48e7d Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Tue, 8 Jan 2019 15:38:25 +0000 Subject: [PATCH 1/2] further codec changes --- client/rpc_client.go | 2 +- client/rpc_codec.go | 12 +++---- client/rpc_stream.go | 2 +- cmd/cmd.go | 2 +- codec/codec.go | 2 +- codec/jsonrpc/client.go | 10 +++--- codec/jsonrpc/server.go | 46 +++++++------------------- codec/proto/proto.go | 9 +++++- codec/protorpc/protorpc.go | 18 ++++++++--- server/rpc_codec.go | 58 +++++++++++++++++++++------------ server/rpc_codec_test.go | 11 +++---- server/rpc_router.go | 66 +++++++++++++++++++++----------------- server/rpc_stream.go | 22 +++++++------ 13 files changed, 139 insertions(+), 121 deletions(-) diff --git a/client/rpc_client.go b/client/rpc_client.go index f229b495..04be269e 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -97,7 +97,7 @@ func (r *rpcClient) call(ctx context.Context, address string, req Request, resp request: req, closed: make(chan bool), codec: newRpcCodec(msg, c, cf), - seq: seq, + seq: fmt.Sprintf("%v", seq), } defer stream.Close() diff --git a/client/rpc_codec.go b/client/rpc_codec.go index 7b0a8628..d4ca0702 100644 --- a/client/rpc_codec.go +++ b/client/rpc_codec.go @@ -4,7 +4,6 @@ import ( "bytes" errs "errors" "fmt" - "strconv" "github.com/micro/go-micro/codec" raw "github.com/micro/go-micro/codec/bytes" @@ -55,13 +54,13 @@ type clientCodec interface { type request struct { Service string ServiceMethod string // format: "Service.Method" - Seq uint64 // sequence number chosen by client + Seq string // sequence number chosen by client next *request // for free list in Server } type response struct { ServiceMethod string // echoes that of the Request - Seq uint64 // echoes that of the request + Seq string // echoes that of the request Error string // error, if any. next *response // for free list in Server } @@ -115,7 +114,7 @@ func (c *rpcCodec) Write(req *request, body interface{}) error { Method: req.ServiceMethod, Type: codec.Request, Header: map[string]string{ - "X-Micro-Id": fmt.Sprintf("%d", req.Seq), + "X-Micro-Id": fmt.Sprintf("%v", req.Seq), "X-Micro-Service": req.Service, "X-Micro-Method": req.ServiceMethod, }, @@ -161,9 +160,8 @@ func (c *rpcCodec) Read(r *response, b interface{}) error { r.ServiceMethod = me.Header["X-Micro-Method"] } - if me.Id == 0 && len(me.Header["X-Micro-Id"]) > 0 { - id, _ := strconv.ParseInt(me.Header["X-Micro-Id"], 10, 64) - r.Seq = uint64(id) + if len(me.Id) == 0 { + r.Seq = me.Header["X-Micro-Id"] } if err != nil { diff --git a/client/rpc_stream.go b/client/rpc_stream.go index 3ad9bccf..275d9133 100644 --- a/client/rpc_stream.go +++ b/client/rpc_stream.go @@ -9,7 +9,7 @@ import ( // Implements the streamer interface type rpcStream struct { sync.RWMutex - seq uint64 + seq string closed chan bool err error request Request diff --git a/cmd/cmd.go b/cmd/cmd.go index 12cbb874..7f08ffd7 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -181,7 +181,7 @@ var ( "default": selector.NewSelector, "dns": dns.NewSelector, "cache": selector.NewSelector, - "static": static.NewSelector, + "static": static.NewSelector, } DefaultServers = map[string]func(...server.Option) server.Server{ diff --git a/codec/codec.go b/codec/codec.go index edc659ff..7909d3f8 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -34,7 +34,7 @@ type Codec interface { // the communication, likely followed by the body. // In the case of an error, body may be nil. type Message struct { - Id uint64 + Id string Type MessageType Target string Method string diff --git a/codec/jsonrpc/client.go b/codec/jsonrpc/client.go index ebf8be91..f5ec5636 100644 --- a/codec/jsonrpc/client.go +++ b/codec/jsonrpc/client.go @@ -19,17 +19,17 @@ type clientCodec struct { resp clientResponse sync.Mutex - pending map[uint64]string + pending map[interface{}]string } type clientRequest struct { Method string `json:"method"` Params [1]interface{} `json:"params"` - ID uint64 `json:"id"` + ID interface{} `json:"id"` } type clientResponse struct { - ID uint64 `json:"id"` + ID interface{} `json:"id"` Result *json.RawMessage `json:"result"` Error interface{} `json:"error"` } @@ -39,7 +39,7 @@ func newClientCodec(conn io.ReadWriteCloser) *clientCodec { dec: json.NewDecoder(conn), enc: json.NewEncoder(conn), c: conn, - pending: make(map[uint64]string), + pending: make(map[interface{}]string), } } @@ -71,7 +71,7 @@ func (c *clientCodec) ReadHeader(m *codec.Message) error { c.Unlock() m.Error = "" - m.Id = c.resp.ID + m.Id = fmt.Sprintf("%v", c.resp.ID) if c.resp.Error != nil { x, ok := c.resp.Error.(string) if !ok { diff --git a/codec/jsonrpc/server.go b/codec/jsonrpc/server.go index bf78ca43..53f681ef 100644 --- a/codec/jsonrpc/server.go +++ b/codec/jsonrpc/server.go @@ -2,9 +2,8 @@ package jsonrpc import ( "encoding/json" - "errors" + "fmt" "io" - "sync" "github.com/micro/go-micro/codec" ) @@ -17,30 +16,25 @@ type serverCodec struct { // temporary work space req serverRequest resp serverResponse - - sync.Mutex - seq uint64 - pending map[uint64]*json.RawMessage } type serverRequest struct { Method string `json:"method"` Params *json.RawMessage `json:"params"` - ID *json.RawMessage `json:"id"` + ID interface{} `json:"id"` } type serverResponse struct { - ID *json.RawMessage `json:"id"` - Result interface{} `json:"result"` - Error interface{} `json:"error"` + ID interface{} `json:"id"` + Result interface{} `json:"result"` + Error interface{} `json:"error"` } func newServerCodec(conn io.ReadWriteCloser) *serverCodec { return &serverCodec{ - dec: json.NewDecoder(conn), - enc: json.NewEncoder(conn), - c: conn, - pending: make(map[uint64]*json.RawMessage), + dec: json.NewDecoder(conn), + enc: json.NewEncoder(conn), + c: conn, } } @@ -50,7 +44,7 @@ func (r *serverRequest) reset() { *r.Params = (*r.Params)[0:0] } if r.ID != nil { - *r.ID = (*r.ID)[0:0] + r.ID = nil } } @@ -60,14 +54,8 @@ func (c *serverCodec) ReadHeader(m *codec.Message) error { return err } m.Method = c.req.Method - - c.Lock() - c.seq++ - c.pending[c.seq] = c.req.ID + m.Id = fmt.Sprintf("%v", c.req.ID) c.req.ID = nil - m.Id = c.seq - c.Unlock() - return nil } @@ -84,19 +72,7 @@ var null = json.RawMessage([]byte("null")) func (c *serverCodec) Write(m *codec.Message, x interface{}) error { var resp serverResponse - c.Lock() - b, ok := c.pending[m.Id] - if !ok { - c.Unlock() - return errors.New("invalid sequence number in response") - } - c.Unlock() - - if b == nil { - // Invalid request so no id. Use JSON null. - b = &null - } - resp.ID = b + resp.ID = m.Id resp.Result = x if m.Error == "" { resp.Error = nil diff --git a/codec/proto/proto.go b/codec/proto/proto.go index 339c08c1..b70b4c84 100644 --- a/codec/proto/proto.go +++ b/codec/proto/proto.go @@ -22,11 +22,18 @@ func (c *Codec) ReadBody(b interface{}) error { if err != nil { return err } + if b == nil { + return nil + } return proto.Unmarshal(buf, b.(proto.Message)) } func (c *Codec) Write(m *codec.Message, b interface{}) error { - buf, err := proto.Marshal(b.(proto.Message)) + p, ok := b.(proto.Message) + if !ok { + return nil + } + buf, err := proto.Marshal(p) if err != nil { return err } diff --git a/codec/protorpc/protorpc.go b/codec/protorpc/protorpc.go index 111026c2..de05552f 100644 --- a/codec/protorpc/protorpc.go +++ b/codec/protorpc/protorpc.go @@ -5,6 +5,7 @@ import ( "bytes" "fmt" "io" + "strconv" "sync" "github.com/golang/protobuf/proto" @@ -31,13 +32,22 @@ func (c *protoCodec) String() string { return "proto-rpc" } +func id(id string) *uint64 { + p, err := strconv.ParseInt(id, 10, 64) + if err != nil { + p = 0 + } + i := uint64(p) + return &i +} + func (c *protoCodec) Write(m *codec.Message, b interface{}) error { switch m.Type { case codec.Request: c.Lock() defer c.Unlock() // This is protobuf, of course we copy it. - pbr := &Request{ServiceMethod: &m.Method, Seq: &m.Id} + pbr := &Request{ServiceMethod: &m.Method, Seq: id(m.Id)} data, err := proto.Marshal(pbr) if err != nil { return err @@ -63,7 +73,7 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error { case codec.Response: c.Lock() defer c.Unlock() - rtmp := &Response{ServiceMethod: &m.Method, Seq: &m.Id, Error: &m.Error} + rtmp := &Response{ServiceMethod: &m.Method, Seq: id(m.Id), Error: &m.Error} data, err := proto.Marshal(rtmp) if err != nil { return err @@ -117,7 +127,7 @@ func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { return err } m.Method = rtmp.GetServiceMethod() - m.Id = rtmp.GetSeq() + m.Id = fmt.Sprintf("%d", rtmp.GetSeq()) case codec.Response: data, err := ReadNetString(c.rwc) if err != nil { @@ -129,7 +139,7 @@ func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { return err } m.Method = rtmp.GetServiceMethod() - m.Id = rtmp.GetSeq() + m.Id = fmt.Sprintf("%d", rtmp.GetSeq()) m.Error = rtmp.GetError() case codec.Publication: _, err := io.Copy(c.buf, c.rwc) diff --git a/server/rpc_codec.go b/server/rpc_codec.go index a9e5e0b7..a4392445 100644 --- a/server/rpc_codec.go +++ b/server/rpc_codec.go @@ -2,8 +2,6 @@ package server import ( "bytes" - "fmt" - "strconv" "github.com/micro/go-micro/codec" raw "github.com/micro/go-micro/codec/bytes" @@ -19,6 +17,7 @@ import ( type rpcCodec struct { socket transport.Socket codec codec.Codec + first bool req *transport.Message buf *readWriteCloser @@ -65,12 +64,13 @@ func (rwc *readWriteCloser) Close() error { return nil } -func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCodec) serverCodec { +func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCodec) codec.Codec { rwc := &readWriteCloser{ rbuf: bytes.NewBuffer(req.Body), wbuf: bytes.NewBuffer(nil), } r := &rpcCodec{ + first: true, buf: rwc, codec: c(rwc), req: req, @@ -79,36 +79,43 @@ func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCod return r } -func (c *rpcCodec) ReadHeader(r *request, first bool) error { +func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error { m := codec.Message{Header: c.req.Header} - if !first { + // if its a follow on request read it + if !c.first { var tm transport.Message + + // read off the socket if err := c.socket.Recv(&tm); err != nil { return err } + // reset the read buffer c.buf.rbuf.Reset() + + // write the body to the buffer if _, err := c.buf.rbuf.Write(tm.Body); err != nil { return err } + // set the message header m.Header = tm.Header } + // no longer first read + c.first = false + // set some internal things m.Target = m.Header["X-Micro-Service"] m.Method = m.Header["X-Micro-Method"] - - // set id - if len(m.Header["X-Micro-Id"]) > 0 { - id, _ := strconv.ParseInt(m.Header["X-Micro-Id"], 10, 64) - m.Id = uint64(id) - } + m.Id = m.Header["X-Micro-Id"] // read header via codec err := c.codec.ReadHeader(&m, codec.Request) - r.ServiceMethod = m.Method - r.Seq = m.Id + + // set the method/id + r.Method = m.Method + r.Id = m.Id return err } @@ -117,21 +124,28 @@ func (c *rpcCodec) ReadBody(b interface{}) error { return c.codec.ReadBody(b) } -func (c *rpcCodec) Write(r *response, body interface{}, last bool) error { +func (c *rpcCodec) Write(r *codec.Message, body interface{}) error { c.buf.wbuf.Reset() + + // create a new message m := &codec.Message{ - Method: r.ServiceMethod, - Id: r.Seq, + Method: r.Method, + Id: r.Id, Error: r.Error, - Type: codec.Response, + Type: r.Type, Header: map[string]string{ - "X-Micro-Id": fmt.Sprintf("%d", r.Seq), - "X-Micro-Method": r.ServiceMethod, + "X-Micro-Id": r.Id, + "X-Micro-Method": r.Method, "X-Micro-Error": r.Error, + "Content-Type": c.req.Header["Content-Type"], }, } + + // write to the body if err := c.codec.Write(m, body); err != nil { c.buf.wbuf.Reset() + + // write an error if it failed m.Error = errors.Wrapf(err, "Unable to encode body").Error() m.Header["X-Micro-Error"] = m.Error if err := c.codec.Write(m, nil); err != nil { @@ -139,7 +153,7 @@ func (c *rpcCodec) Write(r *response, body interface{}, last bool) error { } } - m.Header["Content-Type"] = c.req.Header["Content-Type"] + // send on the socket return c.socket.Send(&transport.Message{ Header: m.Header, Body: c.buf.wbuf.Bytes(), @@ -151,3 +165,7 @@ func (c *rpcCodec) Close() error { c.codec.Close() return c.socket.Close() } + +func (c *rpcCodec) String() string { + return "rpc" +} diff --git a/server/rpc_codec_test.go b/server/rpc_codec_test.go index 977c27c0..59035e3e 100644 --- a/server/rpc_codec_test.go +++ b/server/rpc_codec_test.go @@ -47,12 +47,11 @@ func TestCodecWriteError(t *testing.T) { socket: socket, } - err := c.Write(&response{ - ServiceMethod: "Service.Method", - Seq: 0, - Error: "", - next: nil, - }, "body", false) + err := c.Write(&codec.Message{ + Method: "Service.Method", + Id: "0", + Error: "", + }, "body") if err != nil { t.Fatalf(`Expected Write to fail; got "%+v" instead`, err) diff --git a/server/rpc_router.go b/server/rpc_router.go index a3ab1fe6..fa38fb79 100644 --- a/server/rpc_router.go +++ b/server/rpc_router.go @@ -17,6 +17,7 @@ import ( "unicode/utf8" "github.com/micro/go-log" + "github.com/micro/go-micro/codec" ) var ( @@ -48,16 +49,13 @@ type service struct { } type request struct { - ServiceMethod string // format: "Service.Method" - Seq uint64 // sequence number chosen by client - next *request // for free list in Server + msg *codec.Message + next *request // for free list in Server } type response struct { - ServiceMethod string // echoes that of the Request - Seq uint64 // echoes that of the request - Error string // error, if any. - next *response // for free list in Server + msg *codec.Message + next *response // for free list in Server } // router represents an RPC router. @@ -215,30 +213,34 @@ func (router *router) Handle(h Handler) error { return nil } -func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, codec serverCodec, errmsg string, last bool) (err error) { +func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, cc codec.Codec, errmsg string, last bool) (err error) { + msg := new(codec.Message) + msg.Type = codec.Response resp := router.getResponse() + resp.msg = msg + // Encode the response header - resp.ServiceMethod = req.ServiceMethod + resp.msg.Method = req.msg.Method if errmsg != "" { - resp.Error = errmsg + resp.msg.Error = errmsg reply = invalidRequest } - resp.Seq = req.Seq + resp.msg.Id = req.msg.Id sending.Lock() - err = codec.Write(resp, reply, last) + err = cc.Write(resp.msg, reply) sending.Unlock() router.freeResponse(resp) return err } -func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec serverCodec, ct string) { +func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec codec.Codec, ct string) { function := mtype.method.Func var returnValues []reflect.Value r := &rpcRequest{ service: router.name, contentType: ct, - method: req.ServiceMethod, + method: req.msg.Method, } if !mtype.stream { @@ -282,7 +284,7 @@ func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, context: ctx, codec: codec, request: r, - seq: req.Seq, + id: req.msg.Id, } // Invoke the method, providing a new value for the reply. @@ -326,21 +328,21 @@ func (m *methodType) prepareContext(ctx context.Context) reflect.Value { return reflect.Zero(m.ContextType) } -func (router *router) ServeRequest(ctx context.Context, codec serverCodec, ct string) error { +func (router *router) ServeRequest(ctx context.Context, cc codec.Codec, ct string) error { sending := new(sync.Mutex) - service, mtype, req, argv, replyv, keepReading, err := router.readRequest(codec) + service, mtype, req, argv, replyv, keepReading, err := router.readRequest(cc) if err != nil { if !keepReading { return err } // send a response if we actually managed to read a header. if req != nil { - router.sendResponse(sending, req, invalidRequest, codec, err.Error(), true) + router.sendResponse(sending, req, invalidRequest, cc, err.Error(), true) router.freeRequest(req) } return err } - service.call(ctx, router, sending, mtype, req, argv, replyv, codec, ct) + service.call(ctx, router, sending, mtype, req, argv, replyv, cc, ct) return nil } @@ -384,19 +386,19 @@ func (router *router) freeResponse(resp *response) { router.respLock.Unlock() } -func (router *router) readRequest(codec serverCodec) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { - service, mtype, req, keepReading, err = router.readHeader(codec) +func (router *router) readRequest(cc codec.Codec) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { + service, mtype, req, keepReading, err = router.readHeader(cc) if err != nil { if !keepReading { return } // discard body - codec.ReadBody(nil) + cc.ReadBody(nil) return } // is it a streaming request? then we don't read the body if mtype.stream { - codec.ReadBody(nil) + cc.ReadBody(nil) return } @@ -409,7 +411,7 @@ func (router *router) readRequest(codec serverCodec) (service *service, mtype *m argIsValue = true } // argv guaranteed to be a pointer now. - if err = codec.ReadBody(argv.Interface()); err != nil { + if err = cc.ReadBody(argv.Interface()); err != nil { return } if argIsValue { @@ -422,10 +424,14 @@ func (router *router) readRequest(codec serverCodec) (service *service, mtype *m return } -func (router *router) readHeader(codec serverCodec) (service *service, mtype *methodType, req *request, keepReading bool, err error) { +func (router *router) readHeader(cc codec.Codec) (service *service, mtype *methodType, req *request, keepReading bool, err error) { // Grab the request header. + msg := new(codec.Message) + msg.Type = codec.Request req = router.getRequest() - err = codec.ReadHeader(req, true) + req.msg = msg + + err = cc.ReadHeader(msg, msg.Type) if err != nil { req = nil if err == io.EOF || err == io.ErrUnexpectedEOF { @@ -439,9 +445,9 @@ func (router *router) readHeader(codec serverCodec) (service *service, mtype *me // we can still recover and move on to the next request. keepReading = true - serviceMethod := strings.Split(req.ServiceMethod, ".") + serviceMethod := strings.Split(req.msg.Method, ".") if len(serviceMethod) != 2 { - err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod) + err = errors.New("rpc: service/method request ill-formed: " + req.msg.Method) return } // Look up the request. @@ -449,12 +455,12 @@ func (router *router) readHeader(codec serverCodec) (service *service, mtype *me service = router.serviceMap[serviceMethod[0]] router.mu.Unlock() if service == nil { - err = errors.New("rpc: can't find service " + req.ServiceMethod) + err = errors.New("rpc: can't find service " + req.msg.Method) return } mtype = service.method[serviceMethod[1]] if mtype == nil { - err = errors.New("rpc: can't find method " + req.ServiceMethod) + err = errors.New("rpc: can't find method " + req.msg.Method) } return } diff --git a/server/rpc_stream.go b/server/rpc_stream.go index 645f288a..efbf0d55 100644 --- a/server/rpc_stream.go +++ b/server/rpc_stream.go @@ -3,16 +3,18 @@ package server import ( "context" "sync" + + "github.com/micro/go-micro/codec" ) // Implements the Streamer interface type rpcStream struct { sync.RWMutex - seq uint64 + id string closed bool err error request Request - codec serverCodec + codec codec.Codec context context.Context } @@ -28,28 +30,30 @@ func (r *rpcStream) Send(msg interface{}) error { r.Lock() defer r.Unlock() - resp := response{ - ServiceMethod: r.request.Method(), - Seq: r.seq, + resp := codec.Message{ + Method: r.request.Method(), + Id: r.id, + Type: codec.Response, } - return r.codec.Write(&resp, msg, false) + return r.codec.Write(&resp, msg) } func (r *rpcStream) Recv(msg interface{}) error { r.Lock() defer r.Unlock() - req := request{} + req := new(codec.Message) + req.Type = codec.Request - if err := r.codec.ReadHeader(&req, false); err != nil { + if err := r.codec.ReadHeader(req, req.Type); err != nil { // discard body r.codec.ReadBody(nil) return err } // we need to stay up to date with sequence numbers - r.seq = req.Seq + r.id = req.Id return r.codec.ReadBody(msg) } From f46828be33834dd683f995e7a0b3fed2f1facf44 Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Tue, 8 Jan 2019 20:32:47 +0000 Subject: [PATCH 2/2] Add Router interface --- server/options.go | 11 +++++++++++ server/rpc_router.go | 8 ++++---- server/rpc_server.go | 16 +++++++++++++--- server/server.go | 2 +- 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/server/options.go b/server/options.go index 5b9aa98e..f1355122 100644 --- a/server/options.go +++ b/server/options.go @@ -25,8 +25,12 @@ type Options struct { HdlrWrappers []HandlerWrapper SubWrappers []SubscriberWrapper + // The register expiry time RegisterTTL time.Duration + // The router for requests + Router Router + // Debug Handler which can be set by a user DebugHandler debug.DebugHandler @@ -164,6 +168,13 @@ func RegisterTTL(t time.Duration) Option { } } +// WithRouter sets the request router +func WithRouter(r Router) Option { + return func(o *Options) { + o.Router = r + } +} + // Wait tells the server to wait for requests to finish before exiting func Wait(b bool) Option { return func(o *Options) { diff --git a/server/rpc_router.go b/server/rpc_router.go index fa38fb79..cd25994f 100644 --- a/server/rpc_router.go +++ b/server/rpc_router.go @@ -233,13 +233,13 @@ func (router *router) sendResponse(sending sync.Locker, req *request, reply inte return err } -func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec codec.Codec, ct string) { +func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec codec.Codec) { function := mtype.method.Func var returnValues []reflect.Value r := &rpcRequest{ service: router.name, - contentType: ct, + contentType: req.msg.Header["Content-Type"], method: req.msg.Method, } @@ -328,7 +328,7 @@ func (m *methodType) prepareContext(ctx context.Context) reflect.Value { return reflect.Zero(m.ContextType) } -func (router *router) ServeRequest(ctx context.Context, cc codec.Codec, ct string) error { +func (router *router) ServeRequest(ctx context.Context, cc codec.Codec) error { sending := new(sync.Mutex) service, mtype, req, argv, replyv, keepReading, err := router.readRequest(cc) if err != nil { @@ -342,7 +342,7 @@ func (router *router) ServeRequest(ctx context.Context, cc codec.Codec, ct strin } return err } - service.call(ctx, router, sending, mtype, req, argv, replyv, cc, ct) + service.call(ctx, router, sending, mtype, req, argv, replyv, cc) return nil } diff --git a/server/rpc_server.go b/server/rpc_server.go index adbb0b0e..6d73f63b 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -45,7 +45,8 @@ func newRpcServer(opts ...Option) Server { } } -func (s *rpcServer) accept(sock transport.Socket) { +// ServeConn serves a single connection +func (s *rpcServer) ServeConn(sock transport.Socket) { defer func() { // close socket sock.Close() @@ -92,6 +93,7 @@ func (s *rpcServer) accept(sock transport.Socket) { // no content type if len(ct) == 0 { + msg.Header["Content-Type"] = DefaultContentType ct = DefaultContentType } @@ -111,8 +113,16 @@ func (s *rpcServer) accept(sock transport.Socket) { // create the internal server codec codec := newRpcCodec(&msg, sock, cf) + // set router + var r Router + r = s.router + + if s.opts.Router != nil { + r = s.opts.Router + } + // TODO: needs better error handling - if err := s.router.ServeRequest(ctx, codec, ct); err != nil { + if err := r.ServeRequest(ctx, codec); err != nil { s.wg.Done() log.Logf("Unexpected error serving request, closing socket: %v", err) return @@ -402,7 +412,7 @@ func (s *rpcServer) Start() error { go func() { for { - err := ts.Accept(s.accept) + err := ts.Accept(s.ServeConn) // check if we're supposed to exit select { diff --git a/server/server.go b/server/server.go index 3e72b6a4..5e454824 100644 --- a/server/server.go +++ b/server/server.go @@ -30,7 +30,7 @@ type Server interface { // Router handle serving messages type Router interface { - ServeCodec(context.Context, codec.Codec) error + ServeRequest(context.Context, codec.Codec) error } // Message is an async message interface