diff --git a/requestid.go b/requestid.go index bdba1a0..0f47e01 100644 --- a/requestid.go +++ b/requestid.go @@ -5,12 +5,22 @@ import ( "net/textproto" "go.unistack.org/micro/v3/client" + "go.unistack.org/micro/v3/logger" "go.unistack.org/micro/v3/metadata" "go.unistack.org/micro/v3/server" "go.unistack.org/micro/v3/util/id" ) -var XRequestIDKey struct{} +func init() { + logger.DefaultContextAttrFuncs = append(logger.DefaultContextAttrFuncs, func(ctx context.Context) []interface{} { + if v, ok := ctx.Value(XRequestIDKey{}).(string); ok { + return []interface{}{DefaultMetadataKey, v} + } + return nil + }) +} + +type XRequestIDKey struct{} // DefaultMetadataKey contains metadata key var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") @@ -19,41 +29,31 @@ var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") var DefaultMetadataFunc = func(ctx context.Context) (context.Context, error) { var xid string var err error - var ook, iok bool - if _, ok := ctx.Value(XRequestIDKey).(string); !ok { + if _, ok := ctx.Value(XRequestIDKey{}).(string); !ok { xid, err = id.New() if err != nil { return ctx, err } - ctx = context.WithValue(ctx, XRequestIDKey, xid) + ctx = context.WithValue(ctx, XRequestIDKey{}, xid) } imd, ok := metadata.FromIncomingContext(ctx) if !ok { imd = metadata.New(1) imd.Set(DefaultMetadataKey, xid) + ctx = metadata.NewIncomingContext(ctx, imd) } else if _, ok = imd.Get(DefaultMetadataKey); !ok { imd.Set(DefaultMetadataKey, xid) - } else { - iok = true } omd, ok := metadata.FromOutgoingContext(ctx) if !ok { omd = metadata.New(1) omd.Set(DefaultMetadataKey, xid) + ctx = metadata.NewOutgoingContext(ctx, imd) } else if _, ok = omd.Get(DefaultMetadataKey); !ok { omd.Set(DefaultMetadataKey, xid) - } else { - ook = true - } - - if !iok { - ctx = metadata.NewIncomingContext(ctx, imd) - } - if !ook { - ctx = metadata.NewOutgoingContext(ctx, omd) } return ctx, nil @@ -124,21 +124,10 @@ func NewServerSubscriberWrapper() server.SubscriberWrapper { return func(fn server.SubscriberFunc) server.SubscriberFunc { return func(ctx context.Context, msg server.Message) error { var err error - imd, ok := metadata.FromIncomingContext(ctx) - if !ok { - imd = metadata.New(1) + if xid, ok := msg.Header()[DefaultMetadataKey]; ok { + ctx = context.WithValue(ctx, XRequestIDKey{}, xid) } - omd, ok := metadata.FromOutgoingContext(ctx) - if !ok { - omd = metadata.New(1) - } - if id, ok := msg.Header()[DefaultMetadataKey]; ok { - imd.Set(DefaultMetadataKey, id) - omd.Set(DefaultMetadataKey, id) - ctx = context.WithValue(ctx, XRequestIDKey, id) - ctx = metadata.NewIncomingContext(ctx, imd) - ctx = metadata.NewOutgoingContext(ctx, omd) - } else if ctx, err = DefaultMetadataFunc(ctx); err != nil { + if ctx, err = DefaultMetadataFunc(ctx); err != nil { return err } return fn(ctx, msg)