fix for multiple server registration

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2021-03-09 14:19:59 +03:00
parent d3c5a503c6
commit a46c9d395a
2 changed files with 68 additions and 43 deletions

View File

@ -34,13 +34,13 @@ type patHandler struct {
} }
type httpHandler struct { type httpHandler struct {
name string name string
opts server.HandlerOptions opts server.HandlerOptions
sopts server.Options sopts server.Options
eps []*register.Endpoint eps []*register.Endpoint
hd interface{} hd interface{}
handlers map[string][]patHandler handlers map[string][]patHandler
errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int) //errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int)
} }
func (h *httpHandler) newCodec(ct string) (codec.Codec, error) { func (h *httpHandler) newCodec(ct string) (codec.Codec, error) {
@ -66,14 +66,14 @@ func (h *httpHandler) Options() server.HandlerOptions {
return h.opts return h.opts
} }
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := metadata.NewContext(r.Context(), nil) ctx := metadata.NewContext(r.Context(), nil)
defer r.Body.Close() defer r.Body.Close()
path := r.URL.Path path := r.URL.Path
if !strings.HasPrefix(path, "/") { if !strings.HasPrefix(path, "/") {
h.errorHandler(ctx, h, w, r, fmt.Errorf("path must contains /"), http.StatusBadRequest) h.errorHandler(ctx, nil, w, r, fmt.Errorf("path must contains /"), http.StatusBadRequest)
} }
ct := DefaultContentType ct := DefaultContentType
@ -83,7 +83,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cf, err := h.newCodec(ct) cf, err := h.newCodec(ct)
if err != nil { if err != nil {
h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest) h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest)
} }
components := strings.Split(path[1:], "/") components := strings.Split(path[1:], "/")
@ -91,7 +91,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var verb string var verb string
idx := strings.LastIndex(components[l-1], ":") idx := strings.LastIndex(components[l-1], ":")
if idx == 0 { if idx == 0 {
h.errorHandler(ctx, h, w, r, fmt.Errorf("not found"), http.StatusNotFound) h.errorHandler(ctx, nil, w, r, fmt.Errorf("not found"), http.StatusNotFound)
return return
} }
if idx > 0 { if idx > 0 {
@ -102,19 +102,25 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
matches := make(map[string]interface{}) matches := make(map[string]interface{})
var match bool var match bool
var hldr patHandler var hldr patHandler
for _, hldr = range h.handlers[r.Method] { var handler *httpHandler
mp, err := hldr.pat.Match(components, verb) for _, hpat := range h.handlers {
if err == nil { handlertmp := hpat.(*httpHandler)
match = true for _, hldrtmp := range handlertmp.handlers[r.Method] {
for k, v := range mp { mp, err := hldrtmp.pat.Match(components, verb)
matches[k] = v if err == nil {
match = true
for k, v := range mp {
matches[k] = v
}
hldr = hldrtmp
handler = handlertmp
break
} }
break
} }
} }
if !match { if !match {
h.errorHandler(ctx, h, w, r, fmt.Errorf("not matching route found"), http.StatusNotFound) h.errorHandler(ctx, nil, w, r, fmt.Errorf("not matching route found"), http.StatusNotFound)
return return
} }
@ -131,7 +137,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(r.URL.RawQuery) > 0 { if len(r.URL.RawQuery) > 0 {
umd, err := rflutil.URLMap(r.URL.RawQuery) umd, err := rflutil.URLMap(r.URL.RawQuery)
if err != nil { if err != nil {
h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest) h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest)
} }
for k, v := range umd { for k, v := range umd {
matches[k] = v matches[k] = v
@ -161,24 +167,24 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var returnValues []reflect.Value var returnValues []reflect.Value
if err = cf.ReadBody(r.Body, argv.Interface()); err != nil && err != io.EOF { if err = cf.ReadBody(r.Body, argv.Interface()); err != nil && err != io.EOF {
h.errorHandler(ctx, h, w, r, err, http.StatusInternalServerError) h.errorHandler(ctx, handler, w, r, err, http.StatusInternalServerError)
} }
matches = rflutil.FlattenMap(matches) matches = rflutil.FlattenMap(matches)
if err = rflutil.MergeMap(argv.Interface(), matches); err != nil { if err = rflutil.MergeMap(argv.Interface(), matches); err != nil {
h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest) h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest)
return return
} }
b, err := cf.Marshal(argv.Interface()) b, err := cf.Marshal(argv.Interface())
if err != nil { if err != nil {
h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest) h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest)
return return
} }
hr := &rpcRequest{ hr := &rpcRequest{
codec: cf, codec: cf,
service: h.sopts.Name, service: handler.sopts.Name,
contentType: ct, contentType: ct,
method: fmt.Sprintf("%s.%s", hldr.name, hldr.mtype.method.Name), method: fmt.Sprintf("%s.%s", hldr.name, hldr.mtype.method.Name),
body: b, body: b,
@ -206,8 +212,8 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// wrap the handler func // wrap the handler func
for i := len(h.sopts.HdlrWrappers); i > 0; i-- { for i := len(handler.sopts.HdlrWrappers); i > 0; i-- {
fn = h.sopts.HdlrWrappers[i-1](fn) fn = handler.sopts.HdlrWrappers[i-1](fn)
} }
if appErr := fn(ctx, hr, replyv.Interface()); appErr != nil { if appErr := fn(ctx, hr, replyv.Interface()); appErr != nil {
@ -223,8 +229,8 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else { } else {
b, err = cf.Marshal(replyv.Interface()) b, err = cf.Marshal(replyv.Interface())
} }
if err != nil && h.sopts.Logger.V(logger.ErrorLevel) { if err != nil && handler.sopts.Logger.V(logger.ErrorLevel) {
h.sopts.Logger.Errorf(h.sopts.Context, "handler err: %v", err) handler.sopts.Logger.Errorf(handler.sopts.Context, "handler err: %v", err)
return return
} }
@ -238,7 +244,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if scode != 0 { if scode != 0 {
w.WriteHeader(scode) w.WriteHeader(scode)
} else { } else {
h.sopts.Logger.Warn(h.sopts.Context, "response code not set in handler via SetRspCode(ctx, http.StatusXXX)") handler.sopts.Logger.Warn(handler.sopts.Context, "response code not set in handler via SetRspCode(ctx, http.StatusXXX)")
} }
w.Write(b) w.Write(b)
} }

47
http.go
View File

@ -4,7 +4,6 @@ package http
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -26,6 +25,7 @@ import (
type httpServer struct { type httpServer struct {
sync.RWMutex sync.RWMutex
opts server.Options opts server.Options
handlers map[string]server.Handler
hd server.Handler hd server.Handler
exit chan chan error exit chan chan error
subscribers map[*httpSubscriber][]broker.Subscriber subscribers map[*httpSubscriber][]broker.Subscriber
@ -33,6 +33,8 @@ type httpServer struct {
registered bool registered bool
// register service instance // register service instance
rsvc *register.Service rsvc *register.Service
errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int)
} }
func (h *httpServer) newCodec(ct string) (codec.Codec, error) { func (h *httpServer) newCodec(ct string) (codec.Codec, error) {
@ -54,13 +56,30 @@ func (h *httpServer) Init(opts ...server.Option) error {
for _, o := range opts { for _, o := range opts {
o(&h.opts) o(&h.opts)
} }
h.errorHandler = DefaultErrorHandler
if fn, ok := h.opts.Context.Value(errorHandlerKey{}).(func(ctx context.Context, s server.Handler, w http.ResponseWriter, r *http.Request, err error, status int)); ok && fn != nil {
h.errorHandler = fn
}
h.handlers = make(map[string]server.Handler)
h.Unlock() h.Unlock()
return nil return nil
} }
func (h *httpServer) Handle(handler server.Handler) error { func (h *httpServer) Handle(handler server.Handler) error {
h.Lock() h.Lock()
h.hd = handler if hdlr, ok := handler.(*httpHandler); ok {
if h.handlers == nil {
h.handlers = make(map[string]server.Handler)
}
if _, ok := hdlr.hd.(http.Handler); ok {
h.hd = handler
} else {
h.handlers[handler.Name()] = handler
}
} else {
h.hd = handler
}
h.Unlock() h.Unlock()
return nil return nil
} }
@ -86,10 +105,6 @@ func (h *httpServer) NewHandler(handler interface{}, opts ...server.HandlerOptio
sopts: h.opts, sopts: h.opts,
} }
hdlr.errorHandler = DefaultErrorHandler
if fn, ok := options.Context.Value(errorHandlerKey{}).(func(ctx context.Context, s server.Handler, w http.ResponseWriter, r *http.Request, err error, status int)); ok && fn != nil {
hdlr.errorHandler = fn
}
tp := reflect.TypeOf(handler) tp := reflect.TypeOf(handler)
/* /*
@ -184,8 +199,11 @@ func (h *httpServer) Subscribe(sb server.Subscriber) error {
} }
func (h *httpServer) Register() error { func (h *httpServer) Register() error {
var eps []*register.Endpoint
h.RLock() h.RLock()
eps := h.hd.Endpoints() for _, hdlr := range h.handlers {
eps = append(eps, hdlr.Endpoints()...)
}
rsvc := h.rsvc rsvc := h.rsvc
config := h.opts config := h.opts
h.RUnlock() h.RUnlock()
@ -322,7 +340,6 @@ func (h *httpServer) Deregister() error {
func (h *httpServer) Start() error { func (h *httpServer) Start() error {
h.RLock() h.RLock()
config := h.opts config := h.opts
hd := h.hd
h.RUnlock() h.RUnlock()
// micro: config.Transport.Listen(config.Address) // micro: config.Transport.Listen(config.Address)
@ -357,14 +374,16 @@ func (h *httpServer) Start() error {
h.opts.Address = ts.Addr().String() h.opts.Address = ts.Addr().String()
h.Unlock() h.Unlock()
handler, ok := hd.Handler().(http.Handler) var handler http.Handler
if !ok { if h.hd == nil {
handler, ok = hd.(http.Handler) handler = h
} else if hdlr, ok := h.hd.Handler().(http.Handler); ok {
handler = hdlr
} }
if !ok { //if !ok {
return errors.New("Server required http.Handler") // return errors.New("Server required http.Handler")
} //}
if err := config.Broker.Connect(h.opts.Context); err != nil { if err := config.Broker.Connect(h.opts.Context); err != nil {
return err return err