From a46c9d395a6da7f33f8854ba25ac31922cd3fdb6 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Tue, 9 Mar 2021 14:19:59 +0300 Subject: [PATCH] fix for multiple server registration Signed-off-by: Vasiliy Tolstov --- handler.go | 64 +++++++++++++++++++++++++++++------------------------- http.go | 47 +++++++++++++++++++++++++++------------ 2 files changed, 68 insertions(+), 43 deletions(-) diff --git a/handler.go b/handler.go index 9b04c46..d1bcca1 100644 --- a/handler.go +++ b/handler.go @@ -34,13 +34,13 @@ type patHandler struct { } type httpHandler struct { - name string - opts server.HandlerOptions - sopts server.Options - eps []*register.Endpoint - hd interface{} - handlers map[string][]patHandler - errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int) + name string + opts server.HandlerOptions + sopts server.Options + eps []*register.Endpoint + hd interface{} + handlers map[string][]patHandler + //errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int) } func (h *httpHandler) newCodec(ct string) (codec.Codec, error) { @@ -66,14 +66,14 @@ func (h *httpHandler) Options() server.HandlerOptions { return h.opts } -func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := metadata.NewContext(r.Context(), nil) defer r.Body.Close() path := r.URL.Path if !strings.HasPrefix(path, "/") { - h.errorHandler(ctx, h, w, r, fmt.Errorf("path must contains /"), http.StatusBadRequest) + h.errorHandler(ctx, nil, w, r, fmt.Errorf("path must contains /"), http.StatusBadRequest) } ct := DefaultContentType @@ -83,7 +83,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { cf, err := h.newCodec(ct) if err != nil { - h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest) + h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest) } components := strings.Split(path[1:], "/") @@ -91,7 +91,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var verb string idx := strings.LastIndex(components[l-1], ":") if idx == 0 { - h.errorHandler(ctx, h, w, r, fmt.Errorf("not found"), http.StatusNotFound) + h.errorHandler(ctx, nil, w, r, fmt.Errorf("not found"), http.StatusNotFound) return } if idx > 0 { @@ -102,19 +102,25 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { matches := make(map[string]interface{}) var match bool var hldr patHandler - for _, hldr = range h.handlers[r.Method] { - mp, err := hldr.pat.Match(components, verb) - if err == nil { - match = true - for k, v := range mp { - matches[k] = v + var handler *httpHandler + for _, hpat := range h.handlers { + handlertmp := hpat.(*httpHandler) + for _, hldrtmp := range handlertmp.handlers[r.Method] { + mp, err := hldrtmp.pat.Match(components, verb) + if err == nil { + match = true + for k, v := range mp { + matches[k] = v + } + hldr = hldrtmp + handler = handlertmp + break } - break } } if !match { - h.errorHandler(ctx, h, w, r, fmt.Errorf("not matching route found"), http.StatusNotFound) + h.errorHandler(ctx, nil, w, r, fmt.Errorf("not matching route found"), http.StatusNotFound) return } @@ -131,7 +137,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(r.URL.RawQuery) > 0 { umd, err := rflutil.URLMap(r.URL.RawQuery) if err != nil { - h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest) + h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest) } for k, v := range umd { matches[k] = v @@ -161,24 +167,24 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var returnValues []reflect.Value if err = cf.ReadBody(r.Body, argv.Interface()); err != nil && err != io.EOF { - h.errorHandler(ctx, h, w, r, err, http.StatusInternalServerError) + h.errorHandler(ctx, handler, w, r, err, http.StatusInternalServerError) } matches = rflutil.FlattenMap(matches) if err = rflutil.MergeMap(argv.Interface(), matches); err != nil { - h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest) + h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest) return } b, err := cf.Marshal(argv.Interface()) if err != nil { - h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest) + h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest) return } hr := &rpcRequest{ codec: cf, - service: h.sopts.Name, + service: handler.sopts.Name, contentType: ct, method: fmt.Sprintf("%s.%s", hldr.name, hldr.mtype.method.Name), body: b, @@ -206,8 +212,8 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // wrap the handler func - for i := len(h.sopts.HdlrWrappers); i > 0; i-- { - fn = h.sopts.HdlrWrappers[i-1](fn) + for i := len(handler.sopts.HdlrWrappers); i > 0; i-- { + fn = handler.sopts.HdlrWrappers[i-1](fn) } if appErr := fn(ctx, hr, replyv.Interface()); appErr != nil { @@ -223,8 +229,8 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { b, err = cf.Marshal(replyv.Interface()) } - if err != nil && h.sopts.Logger.V(logger.ErrorLevel) { - h.sopts.Logger.Errorf(h.sopts.Context, "handler err: %v", err) + if err != nil && handler.sopts.Logger.V(logger.ErrorLevel) { + handler.sopts.Logger.Errorf(handler.sopts.Context, "handler err: %v", err) return } @@ -238,7 +244,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if scode != 0 { w.WriteHeader(scode) } else { - h.sopts.Logger.Warn(h.sopts.Context, "response code not set in handler via SetRspCode(ctx, http.StatusXXX)") + handler.sopts.Logger.Warn(handler.sopts.Context, "response code not set in handler via SetRspCode(ctx, http.StatusXXX)") } w.Write(b) } diff --git a/http.go b/http.go index c7a1b37..c4d2132 100644 --- a/http.go +++ b/http.go @@ -4,7 +4,6 @@ package http import ( "context" "crypto/tls" - "errors" "fmt" "net" "net/http" @@ -26,6 +25,7 @@ import ( type httpServer struct { sync.RWMutex opts server.Options + handlers map[string]server.Handler hd server.Handler exit chan chan error subscribers map[*httpSubscriber][]broker.Subscriber @@ -33,6 +33,8 @@ type httpServer struct { registered bool // register service instance rsvc *register.Service + + errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int) } func (h *httpServer) newCodec(ct string) (codec.Codec, error) { @@ -54,13 +56,30 @@ func (h *httpServer) Init(opts ...server.Option) error { for _, o := range opts { o(&h.opts) } + h.errorHandler = DefaultErrorHandler + if fn, ok := h.opts.Context.Value(errorHandlerKey{}).(func(ctx context.Context, s server.Handler, w http.ResponseWriter, r *http.Request, err error, status int)); ok && fn != nil { + h.errorHandler = fn + } + + h.handlers = make(map[string]server.Handler) h.Unlock() return nil } func (h *httpServer) Handle(handler server.Handler) error { h.Lock() - h.hd = handler + if hdlr, ok := handler.(*httpHandler); ok { + if h.handlers == nil { + h.handlers = make(map[string]server.Handler) + } + if _, ok := hdlr.hd.(http.Handler); ok { + h.hd = handler + } else { + h.handlers[handler.Name()] = handler + } + } else { + h.hd = handler + } h.Unlock() return nil } @@ -86,10 +105,6 @@ func (h *httpServer) NewHandler(handler interface{}, opts ...server.HandlerOptio sopts: h.opts, } - hdlr.errorHandler = DefaultErrorHandler - if fn, ok := options.Context.Value(errorHandlerKey{}).(func(ctx context.Context, s server.Handler, w http.ResponseWriter, r *http.Request, err error, status int)); ok && fn != nil { - hdlr.errorHandler = fn - } tp := reflect.TypeOf(handler) /* @@ -184,8 +199,11 @@ func (h *httpServer) Subscribe(sb server.Subscriber) error { } func (h *httpServer) Register() error { + var eps []*register.Endpoint h.RLock() - eps := h.hd.Endpoints() + for _, hdlr := range h.handlers { + eps = append(eps, hdlr.Endpoints()...) + } rsvc := h.rsvc config := h.opts h.RUnlock() @@ -322,7 +340,6 @@ func (h *httpServer) Deregister() error { func (h *httpServer) Start() error { h.RLock() config := h.opts - hd := h.hd h.RUnlock() // micro: config.Transport.Listen(config.Address) @@ -357,14 +374,16 @@ func (h *httpServer) Start() error { h.opts.Address = ts.Addr().String() h.Unlock() - handler, ok := hd.Handler().(http.Handler) - if !ok { - handler, ok = hd.(http.Handler) + var handler http.Handler + if h.hd == nil { + handler = h + } else if hdlr, ok := h.hd.Handler().(http.Handler); ok { + handler = hdlr } - if !ok { - return errors.New("Server required http.Handler") - } + //if !ok { + // return errors.New("Server required http.Handler") + //} if err := config.Broker.Connect(h.opts.Context); err != nil { return err