diff --git a/requestid.go b/requestid.go index 50db766..a9d2ec2 100644 --- a/requestid.go +++ b/requestid.go @@ -1,4 +1,4 @@ -package requestid // import "go.unistack.org/micro-wrapper-requestid/v4" +package requestid import ( "context" @@ -14,14 +14,14 @@ import ( func init() { logger.DefaultContextAttrFuncs = append(logger.DefaultContextAttrFuncs, func(ctx context.Context) []interface{} { - if v, ok := ctx.Value(XRequestIDKey).(string); ok { + if v, ok := ctx.Value(XRequestIDKey{}).(string); ok { return []interface{}{DefaultMetadataKey, v} } return nil }) } -var XRequestIDKey struct{} +type XRequestIDKey struct{} // DefaultMetadataKey contains metadata key x-request-id var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") @@ -30,41 +30,31 @@ var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") var DefaultMetadataFunc = func(ctx context.Context) (context.Context, error) { var xid string var err error - var ook, iok bool - if _, ok := ctx.Value(XRequestIDKey).(string); !ok { + if _, ok := ctx.Value(XRequestIDKey{}).(string); !ok { xid, err = id.New() if err != nil { return ctx, err } - ctx = context.WithValue(ctx, XRequestIDKey, xid) + ctx = context.WithValue(ctx, XRequestIDKey{}, xid) } imd, ok := metadata.FromIncomingContext(ctx) if !ok { imd = metadata.New(1) imd.Set(DefaultMetadataKey, xid) + ctx = metadata.NewIncomingContext(ctx, imd) } else if _, ok = imd.Get(DefaultMetadataKey); !ok { imd.Set(DefaultMetadataKey, xid) - } else { - iok = true } omd, ok := metadata.FromOutgoingContext(ctx) if !ok { omd = metadata.New(1) omd.Set(DefaultMetadataKey, xid) + ctx = metadata.NewOutgoingContext(ctx, imd) } else if _, ok = omd.Get(DefaultMetadataKey); !ok { omd.Set(DefaultMetadataKey, xid) - } else { - ook = true - } - - if !iok { - ctx = metadata.NewIncomingContext(ctx, imd) - } - if !ook { - ctx = metadata.NewOutgoingContext(ctx, omd) } return ctx, nil diff --git a/requestid_test.go b/requestid_test.go index c62c4b7..a54da92 100644 --- a/requestid_test.go +++ b/requestid_test.go @@ -9,25 +9,29 @@ import ( func TestDefaultMetadataFunc(t *testing.T) { ctx := context.TODO() + var err error - nctx, err := DefaultMetadataFunc(ctx) + ctx, err = DefaultMetadataFunc(ctx) if err != nil { t.Fatalf("%v", err) } - imd, ok := metadata.FromIncomingContext(nctx) + imd, ok := metadata.FromIncomingContext(ctx) if !ok { t.Fatalf("md missing in incoming context") } - omd, ok := metadata.FromOutgoingContext(nctx) + omd, ok := metadata.FromOutgoingContext(ctx) if !ok { t.Fatalf("md missing in outgoing context") } - _, iok := imd.Get(DefaultMetadataKey) - _, ook := omd.Get(DefaultMetadataKey) + iv, iok := imd.Get(DefaultMetadataKey) + ov, ook := omd.Get(DefaultMetadataKey) if !iok || !ook { t.Fatalf("missing metadata key value") } + if iv != ov { + t.Fatalf("invalid metadata key value") + } }