diff --git a/requestid.go b/requestid.go index bb9b533..13d5603 100644 --- a/requestid.go +++ b/requestid.go @@ -30,31 +30,51 @@ var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") // DefaultMetadataFunc wil be used if user not provide own func to fill metadata var DefaultMetadataFunc = func(ctx context.Context) (context.Context, error) { var xid string - var err error - if _, ok := ctx.Value(XRequestIDKey{}).(string); !ok { + cid, cok := ctx.Value(XRequestIDKey{}).(string) + if cok && cid != "" { + xid = cid + } + + imd, iok := metadata.FromIncomingContext(ctx) + if !iok || imd == nil { + imd = metadata.New(1) + ctx = metadata.NewIncomingContext(ctx, imd) + } + + omd, ook := metadata.FromOutgoingContext(ctx) + if !ook || omd == nil { + omd = metadata.New(1) + ctx = metadata.NewOutgoingContext(ctx, omd) + } + + if xid == "" { + var id string + if id, iok = imd.Get(DefaultMetadataKey); iok && id != "" { + xid = id + } + if id, ook = omd.Get(DefaultMetadataKey); ook && id != "" { + xid = id + } + } + + if xid == "" { + var err error xid, err = id.New() if err != nil { return ctx, err } + } + + if !cok { ctx = context.WithValue(ctx, XRequestIDKey{}, xid) } - imd, ok := metadata.FromIncomingContext(ctx) - if !ok { - imd = metadata.New(1) - imd.Set(DefaultMetadataKey, xid) - ctx = metadata.NewIncomingContext(ctx, imd) - } else if _, ok = imd.Get(DefaultMetadataKey); !ok { + if !iok { imd.Set(DefaultMetadataKey, xid) } - omd, ok := metadata.FromOutgoingContext(ctx) - if !ok { - omd = metadata.New(1) - omd.Set(DefaultMetadataKey, xid) - ctx = metadata.NewOutgoingContext(ctx, omd) - } else if _, ok = omd.Get(DefaultMetadataKey); !ok { + if !ook { omd.Set(DefaultMetadataKey, xid) }