fix for multiple server registration

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
2021-03-09 14:19:59 +03:00
parent d3c5a503c6
commit a46c9d395a
2 changed files with 68 additions and 43 deletions

View File

@@ -34,13 +34,13 @@ type patHandler struct {
}
type httpHandler struct {
name string
opts server.HandlerOptions
sopts server.Options
eps []*register.Endpoint
hd interface{}
handlers map[string][]patHandler
errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int)
name string
opts server.HandlerOptions
sopts server.Options
eps []*register.Endpoint
hd interface{}
handlers map[string][]patHandler
//errorHandler func(context.Context, server.Handler, http.ResponseWriter, *http.Request, error, int)
}
func (h *httpHandler) newCodec(ct string) (codec.Codec, error) {
@@ -66,14 +66,14 @@ func (h *httpHandler) Options() server.HandlerOptions {
return h.opts
}
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := metadata.NewContext(r.Context(), nil)
defer r.Body.Close()
path := r.URL.Path
if !strings.HasPrefix(path, "/") {
h.errorHandler(ctx, h, w, r, fmt.Errorf("path must contains /"), http.StatusBadRequest)
h.errorHandler(ctx, nil, w, r, fmt.Errorf("path must contains /"), http.StatusBadRequest)
}
ct := DefaultContentType
@@ -83,7 +83,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cf, err := h.newCodec(ct)
if err != nil {
h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest)
h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest)
}
components := strings.Split(path[1:], "/")
@@ -91,7 +91,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var verb string
idx := strings.LastIndex(components[l-1], ":")
if idx == 0 {
h.errorHandler(ctx, h, w, r, fmt.Errorf("not found"), http.StatusNotFound)
h.errorHandler(ctx, nil, w, r, fmt.Errorf("not found"), http.StatusNotFound)
return
}
if idx > 0 {
@@ -102,19 +102,25 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
matches := make(map[string]interface{})
var match bool
var hldr patHandler
for _, hldr = range h.handlers[r.Method] {
mp, err := hldr.pat.Match(components, verb)
if err == nil {
match = true
for k, v := range mp {
matches[k] = v
var handler *httpHandler
for _, hpat := range h.handlers {
handlertmp := hpat.(*httpHandler)
for _, hldrtmp := range handlertmp.handlers[r.Method] {
mp, err := hldrtmp.pat.Match(components, verb)
if err == nil {
match = true
for k, v := range mp {
matches[k] = v
}
hldr = hldrtmp
handler = handlertmp
break
}
break
}
}
if !match {
h.errorHandler(ctx, h, w, r, fmt.Errorf("not matching route found"), http.StatusNotFound)
h.errorHandler(ctx, nil, w, r, fmt.Errorf("not matching route found"), http.StatusNotFound)
return
}
@@ -131,7 +137,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(r.URL.RawQuery) > 0 {
umd, err := rflutil.URLMap(r.URL.RawQuery)
if err != nil {
h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest)
h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest)
}
for k, v := range umd {
matches[k] = v
@@ -161,24 +167,24 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var returnValues []reflect.Value
if err = cf.ReadBody(r.Body, argv.Interface()); err != nil && err != io.EOF {
h.errorHandler(ctx, h, w, r, err, http.StatusInternalServerError)
h.errorHandler(ctx, handler, w, r, err, http.StatusInternalServerError)
}
matches = rflutil.FlattenMap(matches)
if err = rflutil.MergeMap(argv.Interface(), matches); err != nil {
h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest)
h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest)
return
}
b, err := cf.Marshal(argv.Interface())
if err != nil {
h.errorHandler(ctx, h, w, r, err, http.StatusBadRequest)
h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest)
return
}
hr := &rpcRequest{
codec: cf,
service: h.sopts.Name,
service: handler.sopts.Name,
contentType: ct,
method: fmt.Sprintf("%s.%s", hldr.name, hldr.mtype.method.Name),
body: b,
@@ -206,8 +212,8 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// wrap the handler func
for i := len(h.sopts.HdlrWrappers); i > 0; i-- {
fn = h.sopts.HdlrWrappers[i-1](fn)
for i := len(handler.sopts.HdlrWrappers); i > 0; i-- {
fn = handler.sopts.HdlrWrappers[i-1](fn)
}
if appErr := fn(ctx, hr, replyv.Interface()); appErr != nil {
@@ -223,8 +229,8 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else {
b, err = cf.Marshal(replyv.Interface())
}
if err != nil && h.sopts.Logger.V(logger.ErrorLevel) {
h.sopts.Logger.Errorf(h.sopts.Context, "handler err: %v", err)
if err != nil && handler.sopts.Logger.V(logger.ErrorLevel) {
handler.sopts.Logger.Errorf(handler.sopts.Context, "handler err: %v", err)
return
}
@@ -238,7 +244,7 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if scode != 0 {
w.WriteHeader(scode)
} else {
h.sopts.Logger.Warn(h.sopts.Context, "response code not set in handler via SetRspCode(ctx, http.StatusXXX)")
handler.sopts.Logger.Warn(handler.sopts.Context, "response code not set in handler via SetRspCode(ctx, http.StatusXXX)")
}
w.Write(b)
}