diff --git a/server/options.go b/server/options.go index ca17fb3c..ec4d281f 100644 --- a/server/options.go +++ b/server/options.go @@ -18,6 +18,7 @@ type options struct { advertise string id string version string + wrappers []Wrapper } func newOptions(opt ...Option) options { @@ -153,3 +154,10 @@ func Metadata(md map[string]string) Option { o.metadata = md } } + +// Adds a handler Wrapper to a list of options passed into the server +func Wrap(w Wrapper) Option { + return func(o *options) { + o.wrappers = append(o.wrappers, w) + } +} diff --git a/server/rpc_server.go b/server/rpc_server.go index 64579bf4..b998a21c 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -28,9 +28,13 @@ type rpcServer struct { } func newRpcServer(opts ...Option) Server { + options := newOptions(opts...) return &rpcServer{ - opts: newOptions(opts...), - rpc: newServer(), + opts: options, + rpc: &server{ + serviceMap: make(map[string]*service), + wrappers: options.wrappers, + }, handlers: make(map[string]Handler), subscribers: make(map[*subscriber][]broker.Subscriber), exit: make(chan chan error), diff --git a/server/rpcplus_server.go b/server/rpcplus_server.go index ca2ae660..ec13cc76 100644 --- a/server/rpcplus_server.go +++ b/server/rpcplus_server.go @@ -18,11 +18,8 @@ import ( "golang.org/x/net/context" ) -const ( - lastStreamResponseError = "EOS" -) - var ( + lastStreamResponseError = errors.New("EOS") // A value sent as a placeholder for the server's response value when the server // receives an invalid request. It is never decoded by the client since the Response // contains an error when it is used. @@ -43,10 +40,6 @@ type methodType struct { numCalls uint } -func (m *methodType) TakesContext() bool { - return m.ContextType != nil -} - func (m *methodType) NumCalls() (n uint) { m.Lock() n = m.numCalls @@ -82,10 +75,7 @@ type server struct { freeReq *request respLock sync.Mutex // protects freeResp freeResp *response -} - -func newServer() *server { - return &server{serviceMap: make(map[string]*service)} + wrappers []Wrapper } // Is this an exported - upper case - name? @@ -122,11 +112,6 @@ func prepareMethod(method reflect.Method) *methodType { } switch mtype.NumIn() { - case 3: - // normal method - argType = mtype.In(1) - replyType = mtype.In(2) - contextType = nil case 4: // method that takes a context argType = mtype.In(2) @@ -259,20 +244,27 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, var returnValues []reflect.Value if !mtype.stream { - - // Invoke the method, providing a new value for the reply. - if mtype.TakesContext() { + fn := func(ctx context.Context, req interface{}, rsp interface{}) error { returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), argv, replyv}) - } else { - returnValues = function.Call([]reflect.Value{s.rcvr, argv, replyv}) + + // The return value for the method is an error. + if err := returnValues[0].Interface(); err != nil { + return err.(error) + } + + return nil } - // The return value for the method is an error. - errInter := returnValues[0].Interface() - errmsg := "" - if errInter != nil { - errmsg = errInter.(error).Error() + for i := len(server.wrappers); i > 0; i-- { + fn = server.wrappers[i-1](fn) } + + errmsg := "" + err := fn(ctx, argv.Interface(), replyv.Interface()) + if err != nil { + errmsg = err.Error() + } + server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true) server.freeRequest(req) return @@ -314,22 +306,28 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, } // Invoke the method, providing a new value for the reply. - if mtype.TakesContext() { + fn := func(ctx context.Context, req interface{}, rspFn interface{}) error { returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), argv, reflect.ValueOf(sendReply)}) - } else { - returnValues = function.Call([]reflect.Value{s.rcvr, argv, reflect.ValueOf(sendReply)}) + if err := returnValues[0].Interface(); err != nil { + // the function returned an error, we use that + return err.(error) + } else if lastError != nil { + // we had an error inside sendReply, we use that + return lastError + } else { + // no error, we send the special EOS error + return lastStreamResponseError + } + return nil } - errInter := returnValues[0].Interface() + + for i := len(server.wrappers); i > 0; i-- { + fn = server.wrappers[i-1](fn) + } + errmsg := "" - if errInter != nil { - // the function returned an error, we use that - errmsg = errInter.(error).Error() - } else if lastError != nil { - // we had an error inside sendReply, we use that - errmsg = lastError.Error() - } else { - // no error, we send the special EOS error - errmsg = lastStreamResponseError + if err := fn(ctx, argv.Interface(), reflect.ValueOf(sendReply).Interface()); err != nil { + errmsg = err.Error() } // this is the last packet, we don't do anything with diff --git a/server/server_wrapper.go b/server/server_wrapper.go new file mode 100644 index 00000000..3657846d --- /dev/null +++ b/server/server_wrapper.go @@ -0,0 +1,9 @@ +package server + +import ( + "golang.org/x/net/context" +) + +type HandlerFunc func(ctx context.Context, req interface{}, rsp interface{}) error + +type Wrapper func(HandlerFunc) HandlerFunc