diff --git a/handler.go b/handler.go index d1bcca1..431a20d 100644 --- a/handler.go +++ b/handler.go @@ -40,7 +40,6 @@ type httpHandler struct { eps []*register.Endpoint hd interface{} handlers map[string][]patHandler - //errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int) } func (h *httpHandler) newCodec(ct string) (codec.Codec, error) { @@ -74,6 +73,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := r.URL.Path if !strings.HasPrefix(path, "/") { h.errorHandler(ctx, nil, w, r, fmt.Errorf("path must contains /"), http.StatusBadRequest) + return } ct := DefaultContentType @@ -84,6 +84,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { cf, err := h.newCodec(ct) if err != nil { h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest) + return } components := strings.Split(path[1:], "/") @@ -138,6 +139,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { umd, err := rflutil.URLMap(r.URL.RawQuery) if err != nil { h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest) + return } for k, v := range umd { matches[k] = v @@ -168,6 +170,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err = cf.ReadBody(r.Body, argv.Interface()); err != nil && err != io.EOF { h.errorHandler(ctx, handler, w, r, err, http.StatusInternalServerError) + return } matches = rflutil.FlattenMap(matches) diff --git a/http.go b/http.go index c4d2132..254776b 100644 --- a/http.go +++ b/http.go @@ -375,15 +375,32 @@ func (h *httpServer) Start() error { h.Unlock() var handler http.Handler - if h.hd == nil { - handler = h - } else if hdlr, ok := h.hd.Handler().(http.Handler); ok { - handler = hdlr + var srvFunc func(net.Listener) error + + if h.opts.Context != nil { + if hs, ok := h.opts.Context.Value(serverKey{}).(*http.Server); ok && hs != nil { + if hs.Handler == nil && h.hd != nil { + if hdlr, ok := h.hd.Handler().(http.Handler); ok { + hs.Handler = hdlr + handler = hs.Handler + } + } else { + handler = hs.Handler + } + } } - //if !ok { - // return errors.New("Server required http.Handler") - //} + if handler == nil && h.hd == nil { + handler = h + } else if handler == nil && h.hd != nil { + if hdlr, ok := h.hd.Handler().(http.Handler); ok { + handler = hdlr + } + } + + if handler == nil { + return fmt.Errorf("cant process with nil handler") + } if err := config.Broker.Connect(h.opts.Context); err != nil { return err @@ -400,7 +417,6 @@ func (h *httpServer) Start() error { } fn := handler - var srvFunc func(net.Listener) error if h.opts.Context != nil { if mwf, ok := h.opts.Context.Value(middlewareKey{}).([]func(http.Handler) http.Handler); ok && len(mwf) > 0 {