diff --git a/metadata/metadata.go b/metadata/metadata.go index 393e0853..380740c7 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -17,12 +17,5 @@ func FromContext(ctx context.Context) (Metadata, bool) { } func NewContext(ctx context.Context, md Metadata) context.Context { - if emd, ok := ctx.Value(metaKey{}).(Metadata); ok { - for k, v := range emd { - if _, ok := md[k]; !ok { - md[k] = v - } - } - } return context.WithValue(ctx, metaKey{}, md) } diff --git a/wrapper.go b/wrapper.go index 961d7793..a3bac9af 100644 --- a/wrapper.go +++ b/wrapper.go @@ -12,17 +12,30 @@ type clientWrapper struct { headers metadata.Metadata } +func (c *clientWrapper) setHeaders(ctx context.Context) context.Context { + md, ok := metadata.FromContext(ctx) + if !ok { + md = metadata.Metadata{} + } + for k, v := range c.headers { + if _, ok := md[k]; !ok { + md[k] = v + } + } + return metadata.NewContext(ctx, md) +} + func (c *clientWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - ctx = metadata.NewContext(ctx, c.headers) + ctx = c.setHeaders(ctx) return c.Client.Call(ctx, req, rsp, opts...) } func (c *clientWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Streamer, error) { - ctx = metadata.NewContext(ctx, c.headers) + ctx = c.setHeaders(ctx) return c.Client.Stream(ctx, req, opts...) } func (c *clientWrapper) Publish(ctx context.Context, p client.Publication, opts ...client.PublishOption) error { - ctx = metadata.NewContext(ctx, c.headers) + ctx = c.setHeaders(ctx) return c.Client.Publish(ctx, p, opts...) } diff --git a/wrapper_test.go b/wrapper_test.go new file mode 100644 index 00000000..439ed10d --- /dev/null +++ b/wrapper_test.go @@ -0,0 +1,55 @@ +package micro + +import ( + "testing" + + "github.com/micro/go-micro/metadata" + + "golang.org/x/net/context" +) + +func TestWrapper(t *testing.T) { + testData := []struct { + existing metadata.Metadata + headers metadata.Metadata + overwrite bool + }{ + { + existing: metadata.Metadata{}, + headers: metadata.Metadata{ + "foo": "bar", + }, + overwrite: true, + }, + { + existing: metadata.Metadata{ + "foo": "bar", + }, + headers: metadata.Metadata{ + "foo": "baz", + }, + overwrite: false, + }, + } + + for _, d := range testData { + c := &clientWrapper{ + headers: d.headers, + } + + ctx := metadata.NewContext(context.Background(), d.existing) + c.setHeaders(ctx) + + md, _ := metadata.FromContext(ctx) + + for k, v := range d.headers { + if d.overwrite && md[k] != v { + t.Fatalf("Expected %s=%s got %s=%s", k, v, k, md[k]) + } + if !d.overwrite && md[k] != d.existing[k] { + t.Fatalf("Expected %s=%s got %s=%s", k, d.existing[k], k, md[k]) + } + } + } + +}