diff --git a/handler.go b/handler.go index 90aab6e..d46ef67 100644 --- a/handler.go +++ b/handler.go @@ -243,7 +243,22 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } scode := int(200) - if appErr := fn(ctx, hr, replyv.Interface()); appErr != nil { + appErr := fn(ctx, hr, replyv.Interface()) + + w.Header().Set("Content-Type", ct) + if md, ok := metadata.FromOutgoingContext(ctx); ok { + for k, v := range md { + w.Header().Set(k, v) + } + } + if ct != w.Header().Get("Content-Type") { + if cf, err = h.newCodec(ct); err != nil { + h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest) + return + } + } + + if appErr != nil { switch verr := appErr.(type) { case *errors.Error: scode = int(verr.Code) @@ -256,19 +271,12 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { b, err = cf.Marshal(replyv.Interface()) } + if err != nil && handler.sopts.Logger.V(logger.ErrorLevel) { handler.sopts.Logger.Errorf(handler.sopts.Context, "handler err: %v", err) return } - w.Header().Set("Content-Type", ct) - - if md, ok := metadata.FromOutgoingContext(ctx); ok { - for k, v := range md { - w.Header().Set(k, v) - } - } - if nscode := GetRspCode(ctx); nscode != 0 { scode = nscode }