diff --git a/handler.go b/handler.go index 3a9ede4..62ca13c 100644 --- a/handler.go +++ b/handler.go @@ -66,6 +66,14 @@ func (h *httpHandler) Options() server.HandlerOptions { } func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + for exp, ph := range h.pathHandlers { + if exp.MatchString(r.URL.String()) { + ph(w, r) + return + } + } + ctx := metadata.NewContext(r.Context(), nil) defer r.Body.Close() @@ -76,9 +84,9 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ct := strings.Split(DefaultContentType, ";")[0] + ct := DefaultContentType if htype := r.Header.Get("Content-Type"); htype != "" { - ct = strings.Split(htype, ";")[0] + ct = htype } var cf codec.Codec @@ -87,7 +95,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "application/x-www-form-urlencoded": cf, err = h.newCodec(strings.Split(DefaultContentType, ";")[0]) default: - cf, err = h.newCodec(ct) + cf, err = h.newCodec(strings.Split(ct, ";")[0]) } if err != nil { diff --git a/http.go b/http.go index 69f27af..ac181d9 100644 --- a/http.go +++ b/http.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "reflect" + "regexp" "sort" "strings" "sync" @@ -35,6 +36,7 @@ type httpServer struct { rsvc *register.Service init bool errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int) + pathHandlers map[*regexp.Regexp]http.HandlerFunc } func (h *httpServer) newCodec(ct string) (codec.Codec, error) { @@ -57,6 +59,8 @@ func (h *httpServer) Init(opts ...server.Option) error { } h.Lock() + defer h.Unlock() + for _, o := range opts { o(&h.opts) } @@ -66,8 +70,18 @@ func (h *httpServer) Init(opts ...server.Option) error { if h.handlers == nil { h.handlers = make(map[string]server.Handler) } - h.Unlock() - + if h.pathHandlers == nil { + h.pathHandlers = make(map[*regexp.Regexp]http.HandlerFunc) + } + if phs, ok := h.opts.Context.Value(pathHandlerKey{}).(*pathHandlerVal); ok && phs.h != nil { + for pp, ph := range phs.h { + exp, err := regexp.Compile(pp) + if err != nil { + return err + } + h.pathHandlers[exp] = ph + } + } if err := h.opts.Register.Init(); err != nil { return err } @@ -554,5 +568,6 @@ func NewServer(opts ...server.Option) server.Server { exit: make(chan chan error), subscribers: make(map[*httpSubscriber][]broker.Subscriber), errorHandler: DefaultErrorHandler, + pathHandlers: make(map[*regexp.Regexp]http.HandlerFunc), } } diff --git a/options.go b/options.go index a8e24af..b116a23 100644 --- a/options.go +++ b/options.go @@ -65,5 +65,22 @@ func ErrorHandler(fn func(ctx context.Context, s server.Handler, w http.Response return server.SetOption(errorHandlerKey{}, fn) } -// type pathHandlerKey struct{} -// PathHandler specifies http handler for path +type pathHandlerKey struct{} +type pathHandlerVal struct { + h map[string]http.HandlerFunc +} + +// PathHandler specifies http handler for path regexp +func PathHandler(path string, h http.HandlerFunc) server.Option { + return func(o *server.Options) { + if o.Context == nil { + o.Context = context.Background() + } + v, ok := o.Context.Value(pathHandlerKey{}).(*pathHandlerVal) + if !ok { + v = &pathHandlerVal{h: make(map[string]http.HandlerFunc)} + } + v.h[path] = h + o.Context = context.WithValue(o.Context, pathHandlerKey{}, v) + } +}