Compare commits

...

6 Commits

2 changed files with 32 additions and 15 deletions

View File

@@ -334,12 +334,6 @@ func (h *Server) HTTPHandlerFunc(handler interface{}) (http.HandlerFunc, error)
} }
func (h *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// check for http.HandlerFunc handlers
if ph, _, err := h.pathHandlers.Search(r.Method, r.URL.Path); err == nil {
ph.(http.HandlerFunc)(w, r)
return
}
ct := DefaultContentType ct := DefaultContentType
if htype := r.Header.Get(metadata.HeaderContentType); htype != "" { if htype := r.Header.Get(metadata.HeaderContentType); htype != "" {
ct = htype ct = htype
@@ -355,19 +349,22 @@ func (h *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
md[k] = strings.Join(v, ", ") md[k] = strings.Join(v, ", ")
} }
md["RemoteAddr"] = r.RemoteAddr md["RemoteAddr"] = r.RemoteAddr
if r.TLS != nil {
md["Scheme"] = "https"
} else {
md["Scheme"] = "http"
}
md["Method"] = r.Method md["Method"] = r.Method
md["URL"] = r.URL.String() md["URL"] = r.URL.String()
md["Proto"] = r.Proto md["Proto"] = r.Proto
md["ContentLength"] = fmt.Sprintf("%d", r.ContentLength) md["ContentLength"] = fmt.Sprintf("%d", r.ContentLength)
md["TransferEncoding"] = strings.Join(r.TransferEncoding, ",") if len(r.TransferEncoding) > 0 {
md["TransferEncoding"] = strings.Join(r.TransferEncoding, ",")
}
md["Host"] = r.Host md["Host"] = r.Host
md["RequestURI"] = r.RequestURI md["RequestURI"] = r.RequestURI
ctx = metadata.NewIncomingContext(ctx, md) ctx = metadata.NewIncomingContext(ctx, md)
if r.Body != nil {
defer r.Body.Close()
}
path := r.URL.Path path := r.URL.Path
if !strings.HasPrefix(path, "/") { if !strings.HasPrefix(path, "/") {
h.errorHandler(ctx, nil, w, r, fmt.Errorf("path must starts with /"), http.StatusBadRequest) h.errorHandler(ctx, nil, w, r, fmt.Errorf("path must starts with /"), http.StatusBadRequest)
@@ -424,6 +421,11 @@ func (h *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
} else if !match { } else if !match {
// check for http.HandlerFunc handlers
if ph, _, err := h.pathHandlers.Search(r.Method, r.URL.Path); err == nil {
ph.(http.HandlerFunc)(w, r)
return
}
h.errorHandler(ctx, nil, w, r, fmt.Errorf("not matching route found"), http.StatusNotFound) h.errorHandler(ctx, nil, w, r, fmt.Errorf("not matching route found"), http.StatusNotFound)
return return
} }
@@ -440,6 +442,10 @@ func (h *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
if r.Body != nil {
defer r.Body.Close()
}
cf, err := h.newCodec(ct) cf, err := h.newCodec(ct)
if err != nil { if err != nil {
h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest) h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest)

19
http.go
View File

@@ -378,6 +378,17 @@ func (h *Server) Register() error {
} }
h.Lock() h.Lock()
h.registered = true
h.rsvc = service
h.Unlock()
return nil
}
func (h *Server) subscribe() error {
config := h.opts
for sb := range h.subscribers { for sb := range h.subscribers {
handler := h.createSubHandler(sb, config) handler := h.createSubHandler(sb, config)
var opts []broker.SubscribeOption var opts []broker.SubscribeOption
@@ -401,10 +412,6 @@ func (h *Server) Register() error {
h.subscribers[sb] = []broker.Subscriber{sub} h.subscribers[sb] = []broker.Subscriber{sub}
} }
h.registered = true
h.rsvc = service
h.Unlock()
return nil return nil
} }
@@ -539,6 +546,10 @@ func (h *Server) Start() error {
} }
} }
if err := h.subscribe(); err != nil {
return err
}
fn := handler fn := handler
if h.opts.Context != nil { if h.opts.Context != nil {