From 316f6440907f96206b10a9f8b29f241e21c25d18 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Mon, 8 May 2023 22:23:34 +0300 Subject: [PATCH] allow to expose some method via http.HandlerFunc Signed-off-by: Vasiliy Tolstov --- handler.go | 297 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 288 insertions(+), 9 deletions(-) diff --git a/handler.go b/handler.go index a779aeb..a00b637 100644 --- a/handler.go +++ b/handler.go @@ -60,6 +60,279 @@ func (h *httpHandler) Options() server.HandlerOptions { return h.opts } +func (h *httpServer) HTTPHandlerFunc(handler interface{}) (http.HandlerFunc, error) { + if handler == nil { + return nil, fmt.Errorf("invalid handler specified: %v", handler) + } + + rtype := reflect.TypeOf(handler) + if rtype.NumIn() != 3 { + return nil, fmt.Errorf("invalid handler, NumIn != 3: %v", rtype.NumIn()) + } + + argType := rtype.In(1) + replyType := rtype.In(2) + + // First arg need not be a pointer. + if !isExportedOrBuiltinType(argType) { + return nil, fmt.Errorf("invalid handler, argument type not exported: %v", argType) + } + + if replyType.Kind() != reflect.Ptr { + return nil, fmt.Errorf("invalid handler, reply type not a pointer: %v", replyType) + } + + // Reply type must be exported. + if !isExportedOrBuiltinType(replyType) { + return nil, fmt.Errorf("invalid handler, reply type not exported: %v", replyType) + } + + if rtype.NumOut() != 1 { + return nil, fmt.Errorf("invalid handler, has wrong number of outs: %v", rtype.NumOut()) + } + + // The return type of the method must be error. + if returnType := rtype.Out(0); returnType != typeOfError { + return nil, fmt.Errorf("invalid handler, returns %v not error", returnType.String()) + } + + return func(w http.ResponseWriter, r *http.Request) { + ct := DefaultContentType + if htype := r.Header.Get(metadata.HeaderContentType); htype != "" { + ct = htype + } + + ctx := context.WithValue(r.Context(), rspCodeKey{}, &rspCodeVal{}) + ctx = context.WithValue(ctx, rspHeaderKey{}, &rspHeaderVal{}) + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + md = metadata.New(len(r.Header) + 8) + } + for k, v := range r.Header { + md[k] = strings.Join(v, ", ") + } + md["RemoteAddr"] = r.RemoteAddr + md["Method"] = r.Method + md["URL"] = r.URL.String() + md["Proto"] = r.Proto + md["ContentLength"] = fmt.Sprintf("%d", r.ContentLength) + md["TransferEncoding"] = strings.Join(r.TransferEncoding, ",") + md["Host"] = r.Host + md["RequestURI"] = r.RequestURI + ctx = metadata.NewIncomingContext(ctx, md) + + path := r.URL.Path + + if r.Body != nil { + defer r.Body.Close() + } + + matches := make(map[string]interface{}) + var match bool + var hldr *patHandler + var handler *httpHandler + + for _, shdlr := range h.handlers { + hdlr := shdlr.(*httpHandler) + fh, mp, err := hdlr.handlers.Search(r.Method, path) + if err == nil { + match = true + for k, v := range mp { + matches[k] = v + } + hldr = fh.(*patHandler) + handler = hdlr + break + } else if err == rhttp.ErrMethodNotAllowed && !h.registerRPC { + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = w.Write([]byte("not matching route found")) + return + } + } + + if !match && h.registerRPC { + microMethod, mok := md.Get(metadata.HeaderEndpoint) + if mok { + serviceMethod := strings.Split(microMethod, ".") + if len(serviceMethod) == 2 { + if shdlr, ok := h.handlers[serviceMethod[0]]; ok { + hdlr := shdlr.(*httpHandler) + fh, mp, err := hdlr.handlers.Search(http.MethodPost, "/"+microMethod) + if err == nil { + match = true + for k, v := range mp { + matches[k] = v + } + hldr = fh.(*patHandler) + handler = hdlr + } + } + } + } + } + + // get fields from url values + if len(r.URL.RawQuery) > 0 { + umd, cerr := rflutil.URLMap(r.URL.RawQuery) + if cerr != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(cerr.Error())) + return + } + for k, v := range umd { + matches[k] = v + } + } + + cf, err := h.newCodec(ct) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(err.Error())) + return + } + + var argv, replyv reflect.Value + + // Decode the argument value. + argIsValue := false // if true, need to indirect before calling. + if hldr.mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(hldr.mtype.ArgType.Elem()) + } else { + argv = reflect.New(hldr.mtype.ArgType) + argIsValue = true + } + + if argIsValue { + argv = argv.Elem() + } + + // reply value + replyv = reflect.New(hldr.mtype.ReplyType.Elem()) + + function := hldr.mtype.method.Func + var returnValues []reflect.Value + + if r.Body != nil { + var buf []byte + buf, err = io.ReadAll(r.Body) + if err != nil && err != io.EOF { + h.errorHandler(ctx, handler, w, r, err, http.StatusInternalServerError) + return + } + + if err = cf.Unmarshal(buf, argv.Interface()); err != nil { + h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest) + return + } + } + + matches = rflutil.FlattenMap(matches) + if err = rflutil.Merge(argv.Interface(), matches, rflutil.SliceAppend(true), rflutil.Tags([]string{"protobuf", "json"})); err != nil { + h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest) + return + } + + hr := &rpcRequest{ + codec: cf, + service: handler.sopts.Name, + contentType: ct, + method: fmt.Sprintf("%s.%s", hldr.name, hldr.mtype.method.Name), + endpoint: fmt.Sprintf("%s.%s", hldr.name, hldr.mtype.method.Name), + payload: argv.Interface(), + header: md, + } + + // define the handler func + fn := func(fctx context.Context, req server.Request, rsp interface{}) (err error) { + returnValues = function.Call([]reflect.Value{hldr.rcvr, hldr.mtype.prepareContext(fctx), argv, reflect.ValueOf(rsp)}) + + // The return value for the method is an error. + if rerr := returnValues[0].Interface(); rerr != nil { + err = rerr.(error) + } + + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + md = metadata.New(0) + } + if nmd, ok := metadata.FromOutgoingContext(fctx); ok { + for k, v := range nmd { + md.Set(k, v) + } + } + metadata.SetOutgoingContext(ctx, md) + + return err + } + + // wrap the handler func + for i := len(handler.sopts.HdlrWrappers); i > 0; i-- { + fn = handler.sopts.HdlrWrappers[i-1](fn) + } + + if ct == "application/x-www-form-urlencoded" { + cf, err = h.newCodec(DefaultContentType) + if err != nil { + h.errorHandler(ctx, handler, w, r, err, http.StatusInternalServerError) + return + } + ct = DefaultContentType + } + + scode := int(200) + appErr := fn(ctx, hr, replyv.Interface()) + + w.Header().Set(metadata.HeaderContentType, ct) + if md, ok := metadata.FromOutgoingContext(ctx); ok { + for k, v := range md { + w.Header().Set(k, v) + } + } + if md := getRspHeader(ctx); md != nil { + for k, v := range md { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + } + if nct := w.Header().Get(metadata.HeaderContentType); nct != ct { + if cf, err = h.newCodec(nct); err != nil { + h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest) + return + } + } + + var buf []byte + if appErr != nil { + switch verr := appErr.(type) { + case *errors.Error: + scode = int(verr.Code) + buf, err = cf.Marshal(verr) + case *Error: + buf, err = cf.Marshal(verr.err) + default: + buf, err = cf.Marshal(appErr) + } + } else { + buf, err = cf.Marshal(replyv.Interface()) + } + + if err != nil && handler.sopts.Logger.V(logger.ErrorLevel) { + handler.sopts.Logger.Errorf(handler.sopts.Context, "handler err: %v", err) + return + } + + if nscode := GetRspCode(ctx); nscode != 0 { + scode = nscode + } + w.WriteHeader(scode) + + if _, cerr := w.Write(buf); cerr != nil { + handler.sopts.Logger.Errorf(ctx, "write failed: %v", cerr) + } + }, nil +} + func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // check for http.HandlerFunc handlers if ph, _, err := h.pathHandlers.Search(r.Method, r.URL.Path); err == nil { @@ -91,7 +364,9 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { md["RequestURI"] = r.RequestURI ctx = metadata.NewIncomingContext(ctx, md) - defer r.Body.Close() + if r.Body != nil { + defer r.Body.Close() + } path := r.URL.Path if !strings.HasPrefix(path, "/") { @@ -192,15 +467,18 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { function := hldr.mtype.method.Func var returnValues []reflect.Value - buf, err := io.ReadAll(r.Body) - if err != nil && err != io.EOF { - h.errorHandler(ctx, handler, w, r, err, http.StatusInternalServerError) - return - } + if r.Body != nil { + var buf []byte + buf, err = io.ReadAll(r.Body) + if err != nil && err != io.EOF { + h.errorHandler(ctx, handler, w, r, err, http.StatusInternalServerError) + return + } - if err = cf.Unmarshal(buf, argv.Interface()); err != nil { - h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest) - return + if err = cf.Unmarshal(buf, argv.Interface()); err != nil { + h.errorHandler(ctx, handler, w, r, err, http.StatusBadRequest) + return + } } matches = rflutil.FlattenMap(matches) @@ -279,6 +557,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + var buf []byte if appErr != nil { switch verr := appErr.(type) { case *errors.Error: