diff --git a/server/options.go b/server/options.go index 5b9aa98e..f1355122 100644 --- a/server/options.go +++ b/server/options.go @@ -25,8 +25,12 @@ type Options struct { HdlrWrappers []HandlerWrapper SubWrappers []SubscriberWrapper + // The register expiry time RegisterTTL time.Duration + // The router for requests + Router Router + // Debug Handler which can be set by a user DebugHandler debug.DebugHandler @@ -164,6 +168,13 @@ func RegisterTTL(t time.Duration) Option { } } +// WithRouter sets the request router +func WithRouter(r Router) Option { + return func(o *Options) { + o.Router = r + } +} + // Wait tells the server to wait for requests to finish before exiting func Wait(b bool) Option { return func(o *Options) { diff --git a/server/rpc_router.go b/server/rpc_router.go index fa38fb79..cd25994f 100644 --- a/server/rpc_router.go +++ b/server/rpc_router.go @@ -233,13 +233,13 @@ func (router *router) sendResponse(sending sync.Locker, req *request, reply inte return err } -func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec codec.Codec, ct string) { +func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec codec.Codec) { function := mtype.method.Func var returnValues []reflect.Value r := &rpcRequest{ service: router.name, - contentType: ct, + contentType: req.msg.Header["Content-Type"], method: req.msg.Method, } @@ -328,7 +328,7 @@ func (m *methodType) prepareContext(ctx context.Context) reflect.Value { return reflect.Zero(m.ContextType) } -func (router *router) ServeRequest(ctx context.Context, cc codec.Codec, ct string) error { +func (router *router) ServeRequest(ctx context.Context, cc codec.Codec) error { sending := new(sync.Mutex) service, mtype, req, argv, replyv, keepReading, err := router.readRequest(cc) if err != nil { @@ -342,7 +342,7 @@ func (router *router) ServeRequest(ctx context.Context, cc codec.Codec, ct strin } return err } - service.call(ctx, router, sending, mtype, req, argv, replyv, cc, ct) + service.call(ctx, router, sending, mtype, req, argv, replyv, cc) return nil } diff --git a/server/rpc_server.go b/server/rpc_server.go index adbb0b0e..6d73f63b 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -45,7 +45,8 @@ func newRpcServer(opts ...Option) Server { } } -func (s *rpcServer) accept(sock transport.Socket) { +// ServeConn serves a single connection +func (s *rpcServer) ServeConn(sock transport.Socket) { defer func() { // close socket sock.Close() @@ -92,6 +93,7 @@ func (s *rpcServer) accept(sock transport.Socket) { // no content type if len(ct) == 0 { + msg.Header["Content-Type"] = DefaultContentType ct = DefaultContentType } @@ -111,8 +113,16 @@ func (s *rpcServer) accept(sock transport.Socket) { // create the internal server codec codec := newRpcCodec(&msg, sock, cf) + // set router + var r Router + r = s.router + + if s.opts.Router != nil { + r = s.opts.Router + } + // TODO: needs better error handling - if err := s.router.ServeRequest(ctx, codec, ct); err != nil { + if err := r.ServeRequest(ctx, codec); err != nil { s.wg.Done() log.Logf("Unexpected error serving request, closing socket: %v", err) return @@ -402,7 +412,7 @@ func (s *rpcServer) Start() error { go func() { for { - err := ts.Accept(s.accept) + err := ts.Accept(s.ServeConn) // check if we're supposed to exit select { diff --git a/server/server.go b/server/server.go index 3e72b6a4..5e454824 100644 --- a/server/server.go +++ b/server/server.go @@ -30,7 +30,7 @@ type Server interface { // Router handle serving messages type Router interface { - ServeCodec(context.Context, codec.Codec) error + ServeRequest(context.Context, codec.Codec) error } // Message is an async message interface