diff --git a/server/rpc_service.go b/server/rpc_router.go similarity index 81% rename from server/rpc_service.go rename to server/rpc_router.go index 14b667c9..6e67265a 100644 --- a/server/rpc_service.go +++ b/server/rpc_router.go @@ -60,8 +60,8 @@ type response struct { next *response // for free list in Server } -// server represents an RPC Server. -type server struct { +// router represents an RPC router. +type router struct { name string mu sync.Mutex // protects the serviceMap serviceMap map[string]*service @@ -72,6 +72,14 @@ type server struct { hdlrWrappers []HandlerWrapper } +func newRpcRouter(opts Options) *router { + return &router{ + name: opts.Name, + hdlrWrappers: opts.HdlrWrappers, + serviceMap: make(map[string]*service), + } +} + // Is this an exported - upper case - name? func isExported(name string) bool { rune, _ := utf8.DecodeRuneInString(name) @@ -158,11 +166,11 @@ func prepareMethod(method reflect.Method) *methodType { return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream} } -func (server *server) register(rcvr interface{}) error { - server.mu.Lock() - defer server.mu.Unlock() - if server.serviceMap == nil { - server.serviceMap = make(map[string]*service) +func (router *router) register(rcvr interface{}) error { + router.mu.Lock() + defer router.mu.Unlock() + if router.serviceMap == nil { + router.serviceMap = make(map[string]*service) } s := new(service) s.typ = reflect.TypeOf(rcvr) @@ -176,7 +184,7 @@ func (server *server) register(rcvr interface{}) error { log.Log(s) return errors.New(s) } - if _, present := server.serviceMap[sname]; present { + if _, present := router.serviceMap[sname]; present { return errors.New("rpc: service already defined: " + sname) } s.name = sname @@ -195,12 +203,12 @@ func (server *server) register(rcvr interface{}) error { log.Log(s) return errors.New(s) } - server.serviceMap[s.name] = s + router.serviceMap[s.name] = s return nil } -func (server *server) sendResponse(sending sync.Locker, req *request, reply interface{}, codec serverCodec, errmsg string, last bool) (err error) { - resp := server.getResponse() +func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, codec serverCodec, errmsg string, last bool) (err error) { + resp := router.getResponse() // Encode the response header resp.ServiceMethod = req.ServiceMethod if errmsg != "" { @@ -211,16 +219,16 @@ func (server *server) sendResponse(sending sync.Locker, req *request, reply inte sending.Lock() err = codec.WriteResponse(resp, reply, last) sending.Unlock() - server.freeResponse(resp) + router.freeResponse(resp) return err } -func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec serverCodec, ct string) { +func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec serverCodec, ct string) { function := mtype.method.Func var returnValues []reflect.Value r := &rpcRequest{ - service: server.name, + service: router.name, contentType: ct, method: req.ServiceMethod, } @@ -239,8 +247,8 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, return nil } - for i := len(server.hdlrWrappers); i > 0; i-- { - fn = server.hdlrWrappers[i-1](fn) + for i := len(router.hdlrWrappers); i > 0; i-- { + fn = router.hdlrWrappers[i-1](fn) } errmsg := "" @@ -249,11 +257,11 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, errmsg = err.Error() } - err = server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true) + err = router.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true) if err != nil { log.Log("rpc call: unable to send response: ", err) } - server.freeRequest(req) + router.freeRequest(req) return } @@ -284,8 +292,8 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, } } - for i := len(server.hdlrWrappers); i > 0; i-- { - fn = server.hdlrWrappers[i-1](fn) + for i := len(router.hdlrWrappers); i > 0; i-- { + fn = router.hdlrWrappers[i-1](fn) } // client.Stream request @@ -299,8 +307,8 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, // this is the last packet, we don't do anything with // the error here (well sendStreamResponse will log it // already) - server.sendResponse(sending, req, nil, codec, errmsg, true) - server.freeRequest(req) + router.sendResponse(sending, req, nil, codec, errmsg, true) + router.freeRequest(req) } func (m *methodType) prepareContext(ctx context.Context) reflect.Value { @@ -310,66 +318,66 @@ func (m *methodType) prepareContext(ctx context.Context) reflect.Value { return reflect.Zero(m.ContextType) } -func (server *server) serveRequest(ctx context.Context, codec serverCodec, ct string) error { +func (router *router) serveRequest(ctx context.Context, codec serverCodec, ct string) error { sending := new(sync.Mutex) - service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) + service, mtype, req, argv, replyv, keepReading, err := router.readRequest(codec) if err != nil { if !keepReading { return err } // send a response if we actually managed to read a header. if req != nil { - server.sendResponse(sending, req, invalidRequest, codec, err.Error(), true) - server.freeRequest(req) + router.sendResponse(sending, req, invalidRequest, codec, err.Error(), true) + router.freeRequest(req) } return err } - service.call(ctx, server, sending, mtype, req, argv, replyv, codec, ct) + service.call(ctx, router, sending, mtype, req, argv, replyv, codec, ct) return nil } -func (server *server) getRequest() *request { - server.reqLock.Lock() - req := server.freeReq +func (router *router) getRequest() *request { + router.reqLock.Lock() + req := router.freeReq if req == nil { req = new(request) } else { - server.freeReq = req.next + router.freeReq = req.next *req = request{} } - server.reqLock.Unlock() + router.reqLock.Unlock() return req } -func (server *server) freeRequest(req *request) { - server.reqLock.Lock() - req.next = server.freeReq - server.freeReq = req - server.reqLock.Unlock() +func (router *router) freeRequest(req *request) { + router.reqLock.Lock() + req.next = router.freeReq + router.freeReq = req + router.reqLock.Unlock() } -func (server *server) getResponse() *response { - server.respLock.Lock() - resp := server.freeResp +func (router *router) getResponse() *response { + router.respLock.Lock() + resp := router.freeResp if resp == nil { resp = new(response) } else { - server.freeResp = resp.next + router.freeResp = resp.next *resp = response{} } - server.respLock.Unlock() + router.respLock.Unlock() return resp } -func (server *server) freeResponse(resp *response) { - server.respLock.Lock() - resp.next = server.freeResp - server.freeResp = resp - server.respLock.Unlock() +func (router *router) freeResponse(resp *response) { + router.respLock.Lock() + resp.next = router.freeResp + router.freeResp = resp + router.respLock.Unlock() } -func (server *server) readRequest(codec serverCodec) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { - service, mtype, req, keepReading, err = server.readRequestHeader(codec) +func (router *router) readRequest(codec serverCodec) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { + service, mtype, req, keepReading, err = router.readRequestHeader(codec) if err != nil { if !keepReading { return @@ -406,16 +414,16 @@ func (server *server) readRequest(codec serverCodec) (service *service, mtype *m return } -func (server *server) readRequestHeader(codec serverCodec) (service *service, mtype *methodType, req *request, keepReading bool, err error) { +func (router *router) readRequestHeader(codec serverCodec) (service *service, mtype *methodType, req *request, keepReading bool, err error) { // Grab the request header. - req = server.getRequest() + req = router.getRequest() err = codec.ReadRequestHeader(req, true) if err != nil { req = nil if err == io.EOF || err == io.ErrUnexpectedEOF { return } - err = errors.New("rpc: server cannot decode request: " + err.Error()) + err = errors.New("rpc: router cannot decode request: " + err.Error()) return } @@ -429,9 +437,9 @@ func (server *server) readRequestHeader(codec serverCodec) (service *service, mt return } // Look up the request. - server.mu.Lock() - service = server.serviceMap[serviceMethod[0]] - server.mu.Unlock() + router.mu.Lock() + service = router.serviceMap[serviceMethod[0]] + router.mu.Unlock() if service == nil { err = errors.New("rpc: can't find service " + req.ServiceMethod) return diff --git a/server/rpc_server.go b/server/rpc_server.go index 171848b0..984e2b70 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -21,8 +21,8 @@ import ( ) type rpcServer struct { - rpc *server - exit chan chan error + router *router + exit chan chan error sync.RWMutex opts Options @@ -37,12 +37,8 @@ type rpcServer struct { func newRpcServer(opts ...Option) Server { options := newOptions(opts...) return &rpcServer{ - opts: options, - rpc: &server{ - name: options.Name, - serviceMap: make(map[string]*service), - hdlrWrappers: options.HdlrWrappers, - }, + opts: options, + router: newRpcRouter(options), handlers: make(map[string]Handler), subscribers: make(map[*subscriber][]broker.Subscriber), exit: make(chan chan error), @@ -111,7 +107,7 @@ func (s *rpcServer) accept(sock transport.Socket) { } // TODO: needs better error handling - if err := s.rpc.serveRequest(ctx, codec, ct); err != nil { + if err := s.router.serveRequest(ctx, codec, ct); err != nil { s.wg.Done() log.Logf("Unexpected error serving request, closing socket: %v", err) return @@ -142,12 +138,12 @@ func (s *rpcServer) Init(opts ...Option) error { for _, opt := range opts { opt(&s.opts) } - // update internal server - s.rpc = &server{ - name: s.opts.Name, - serviceMap: s.rpc.serviceMap, - hdlrWrappers: s.opts.HdlrWrappers, - } + + // update router + r := newRpcRouter(s.opts) + r.serviceMap = s.router.serviceMap + s.router = r + s.Unlock() return nil } @@ -160,7 +156,7 @@ func (s *rpcServer) Handle(h Handler) error { s.Lock() defer s.Unlock() - if err := s.rpc.register(h.Handler()); err != nil { + if err := s.router.register(h.Handler()); err != nil { return err } diff --git a/server/server.go b/server/server.go index cdbb4fca..da839f9c 100644 --- a/server/server.go +++ b/server/server.go @@ -27,6 +27,11 @@ type Server interface { String() string } +// Router handle serving messages +type Router interface { + ServeRequest(context.Context, Stream) error +} + // Message is an async message interface type Message interface { Topic() string @@ -97,6 +102,7 @@ var ( DefaultVersion = "1.0.0" DefaultId = uuid.New().String() DefaultServer Server = newRpcServer() + DefaultRouter = newRpcRouter(newOptions()) ) // DefaultOptions returns config options for the default service