diff --git a/handler.go b/handler.go index bfe8e1d..6043365 100644 --- a/handler.go +++ b/handler.go @@ -73,6 +73,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } ctx := context.WithValue(r.Context(), rspCodeKey{}, &rspCodeVal{}) + ctx = context.WithValue(ctx, rspHeaderKey{}, &rspHeaderVal{}) md, ok := metadata.FromIncomingContext(ctx) if !ok { md = metadata.New(len(r.Header) + 8) @@ -264,6 +265,13 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set(k, v) } } + if md := getRspHeader(ctx); md != nil { + for k, v := range md { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + } if nct := w.Header().Get(metadata.HeaderContentType); nct != ct { if cf, err = h.newCodec(nct); err != nil { h.errorHandler(ctx, nil, w, r, err, http.StatusBadRequest) diff --git a/options.go b/options.go index 32fdc7b..2c06b0c 100644 --- a/options.go +++ b/options.go @@ -38,6 +38,20 @@ type ( } ) +type ( + rspHeaderKey struct{} + rspHeaderVal struct { + h http.Header + } +) + +// SetRspHeader add response headers +func SetRspHeader(ctx context.Context, h http.Header) { + if rsp, ok := ctx.Value(rspHeaderKey{}).(*rspHeaderVal); ok { + rsp.h = h + } +} + // SetRspCode saves response code in context, must be used by handler to specify http code func SetRspCode(ctx context.Context, code int) { if rsp, ok := ctx.Value(rspCodeKey{}).(*rspCodeVal); ok { @@ -45,6 +59,14 @@ func SetRspCode(ctx context.Context, code int) { } } +// getRspHeader get http.Header from context +func getRspHeader(ctx context.Context) http.Header { + if rsp, ok := ctx.Value(rspHeaderKey{}).(*rspHeaderVal); ok { + return rsp.h + } + return nil +} + // GetRspCode used internally by generated http server handler func GetRspCode(ctx context.Context) int { var code int