diff --git a/requestid.go b/requestid.go index 6a98d02..3609c03 100644 --- a/requestid.go +++ b/requestid.go @@ -10,24 +10,36 @@ import ( "go.unistack.org/micro/v3/util/id" ) -// MetadataKey contains metadata key -var MetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") +// DefaultMetadataKey contains metadata key +var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") -// MetadataFunc wil be used if user not provide own func to fill metadata -var MetadataFunc = func(ctx context.Context) (context.Context, error) { - md, ok := metadata.FromIncomingContext(ctx) +// DefaultMetadataFunc wil be used if user not provide own func to fill metadata +var DefaultMetadataFunc = func(ctx context.Context) (context.Context, error) { + imd, ok := metadata.FromIncomingContext(ctx) if !ok { - md = metadata.New(1) + imd = metadata.New(1) } - if _, ok = md.Get(MetadataKey); ok { - return ctx, nil + omd, ok := metadata.FromOutgoingContext(ctx) + if !ok { + omd = metadata.New(1) } - uid, err := id.New() - if err != nil { - return ctx, err + v, iok := imd.Get(DefaultMetadataKey) + if iok { + if _, ook := omd.Get(DefaultMetadataKey); ook { + return ctx, nil + } } - md.Set(MetadataKey, uid) - ctx = metadata.NewIncomingContext(ctx, md) + if !iok { + uid, err := id.New() + if err != nil { + return ctx, err + } + v = uid + } + imd.Set(DefaultMetadataKey, v) + omd.Set(DefaultMetadataKey, v) + ctx = metadata.NewIncomingContext(ctx, imd) + ctx = metadata.NewOutgoingContext(ctx, omd) return ctx, nil } @@ -48,7 +60,7 @@ func NewClientCallWrapper() client.CallWrapper { return func(fn client.CallFunc) client.CallFunc { return func(ctx context.Context, addr string, req client.Request, rsp interface{}, opts client.CallOptions) error { var err error - if ctx, err = MetadataFunc(ctx); err != nil { + if ctx, err = DefaultMetadataFunc(ctx); err != nil { return err } return fn(ctx, addr, req, rsp, opts) @@ -58,7 +70,7 @@ func NewClientCallWrapper() client.CallWrapper { func (w *wrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { var err error - if ctx, err = MetadataFunc(ctx); err != nil { + if ctx, err = DefaultMetadataFunc(ctx); err != nil { return err } return w.Client.Call(ctx, req, rsp, opts...) @@ -66,7 +78,7 @@ func (w *wrapper) Call(ctx context.Context, req client.Request, rsp interface{}, func (w *wrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { var err error - if ctx, err = MetadataFunc(ctx); err != nil { + if ctx, err = DefaultMetadataFunc(ctx); err != nil { return nil, err } return w.Client.Stream(ctx, req, opts...) @@ -74,7 +86,7 @@ func (w *wrapper) Stream(ctx context.Context, req client.Request, opts ...client func (w *wrapper) Publish(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { var err error - if ctx, err = MetadataFunc(ctx); err != nil { + if ctx, err = DefaultMetadataFunc(ctx); err != nil { return err } return w.Client.Publish(ctx, msg, opts...) @@ -84,7 +96,7 @@ func NewServerHandlerWrapper() server.HandlerWrapper { return func(fn server.HandlerFunc) server.HandlerFunc { return func(ctx context.Context, req server.Request, rsp interface{}) error { var err error - if ctx, err = MetadataFunc(ctx); err != nil { + if ctx, err = DefaultMetadataFunc(ctx); err != nil { return err } return fn(ctx, req, rsp) @@ -96,14 +108,20 @@ func NewServerSubscriberWrapper() server.SubscriberWrapper { return func(fn server.SubscriberFunc) server.SubscriberFunc { return func(ctx context.Context, msg server.Message) error { var err error - md, ok := metadata.FromIncomingContext(ctx) + imd, ok := metadata.FromIncomingContext(ctx) if !ok { - md = metadata.New(1) + imd = metadata.New(1) } - if id, ok := msg.Header()[MetadataKey]; ok { - md.Set(MetadataKey, id) - ctx = metadata.NewIncomingContext(ctx, md) - } else if ctx, err = MetadataFunc(ctx); err != nil { + omd, ok := metadata.FromOutgoingContext(ctx) + if !ok { + omd = metadata.New(1) + } + if id, ok := msg.Header()[DefaultMetadataKey]; ok { + imd.Set(DefaultMetadataKey, id) + omd.Set(DefaultMetadataKey, id) + ctx = metadata.NewIncomingContext(ctx, imd) + ctx = metadata.NewOutgoingContext(ctx, omd) + } else if ctx, err = DefaultMetadataFunc(ctx); err != nil { return err } return fn(ctx, msg) diff --git a/requestid_test.go b/requestid_test.go new file mode 100644 index 0000000..ce7fad7 --- /dev/null +++ b/requestid_test.go @@ -0,0 +1,33 @@ +package requestid + +import ( + "context" + "testing" + + "go.unistack.org/micro/v3/metadata" +) + +func TestDefaultMetadataFunc(t *testing.T) { + ctx := context.TODO() + + nctx, err := DefaultMetadataFunc(ctx) + if err != nil { + t.Fatalf("%v", err) + } + + imd, ok := metadata.FromIncomingContext(nctx) + if !ok { + t.Fatalf("md missing in incoming context") + } + omd, ok := metadata.FromOutgoingContext(nctx) + if !ok { + t.Fatalf("md missing in outgoing context") + } + + _, iok := imd.Get(DefaultMetadataKey) + _, ook := omd.Get(DefaultMetadataKey) + + if !iok || !ook { + t.Fatalf("missing metadata key value") + } +}