From 927c7ea3c255dfcc7fb627db7e9bc31a72cb3a97 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Tue, 9 Feb 2021 12:46:14 +0300 Subject: [PATCH] metadata: allow to modify metadata via SetXXX functions Signed-off-by: Vasiliy Tolstov --- metadata/context.go | 58 +++++++++++++++++++++++++++++++++------ metadata/metadata.go | 4 +++ metadata/metadata_test.go | 23 ++++++++++++++++ 3 files changed, 76 insertions(+), 9 deletions(-) diff --git a/metadata/context.go b/metadata/context.go index c8000fbf..b7540c14 100644 --- a/metadata/context.go +++ b/metadata/context.go @@ -15,8 +15,11 @@ func FromIncomingContext(ctx context.Context) (Metadata, bool) { if ctx == nil { return nil, false } - md, ok := ctx.Value(mdIncomingKey{}).(Metadata) - return md, ok + md, ok := ctx.Value(mdIncomingKey{}).(*rawMetadata) + if !ok { + return nil, false + } + return md.md, ok } // FromOutgoingContext returns metadata from outgoing ctx @@ -25,8 +28,11 @@ func FromOutgoingContext(ctx context.Context) (Metadata, bool) { if ctx == nil { return nil, false } - md, ok := ctx.Value(mdOutgoingKey{}).(Metadata) - return md, ok + md, ok := ctx.Value(mdOutgoingKey{}).(*rawMetadata) + if !ok { + return nil, false + } + return md.md, ok } // FromContext returns metadata from the given context @@ -37,8 +43,11 @@ func FromContext(ctx context.Context) (Metadata, bool) { if ctx == nil { return nil, false } - md, ok := ctx.Value(mdKey{}).(Metadata) - return md, ok + md, ok := ctx.Value(mdKey{}).(*rawMetadata) + if !ok { + return nil, false + } + return md.md, ok } // NewContext creates a new context with the given metadata @@ -48,7 +57,34 @@ func NewContext(ctx context.Context, md Metadata) context.Context { if ctx == nil { ctx = context.Background() } - return context.WithValue(ctx, mdKey{}, md) + ctx = context.WithValue(ctx, mdKey{}, &rawMetadata{md}) + ctx = context.WithValue(ctx, mdIncomingKey{}, &rawMetadata{}) + ctx = context.WithValue(ctx, mdOutgoingKey{}, &rawMetadata{}) + return ctx +} + +// SetOutgoingContext modify outgoing context with given metadata +func SetOutgoingContext(ctx context.Context, md Metadata) bool { + if ctx == nil { + return false + } + if omd, ok := ctx.Value(mdOutgoingKey{}).(*rawMetadata); ok { + omd.md = md + return true + } + return false +} + +// SetIncomingContext modify incoming context with given metadata +func SetIncomingContext(ctx context.Context, md Metadata) bool { + if ctx == nil { + return false + } + if omd, ok := ctx.Value(mdIncomingKey{}).(*rawMetadata); ok { + omd.md = md + return true + } + return false } // NewIncomingContext creates a new context with incoming metadata attached @@ -56,7 +92,9 @@ func NewIncomingContext(ctx context.Context, md Metadata) context.Context { if ctx == nil { ctx = context.Background() } - return context.WithValue(ctx, mdIncomingKey{}, md) + ctx = context.WithValue(ctx, mdIncomingKey{}, &rawMetadata{md}) + ctx = context.WithValue(ctx, mdOutgoingKey{}, &rawMetadata{}) + return ctx } // NewOutgoingContext creates a new context with outcoming metadata attached @@ -64,5 +102,7 @@ func NewOutgoingContext(ctx context.Context, md Metadata) context.Context { if ctx == nil { ctx = context.Background() } - return context.WithValue(ctx, mdOutgoingKey{}, md) + ctx = context.WithValue(ctx, mdOutgoingKey{}, &rawMetadata{md}) + ctx = context.WithValue(ctx, mdIncomingKey{}, &rawMetadata{}) + return ctx } diff --git a/metadata/metadata.go b/metadata/metadata.go index 635e6707..b1b46478 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -16,6 +16,10 @@ var ( // from Transport headers. type Metadata map[string]string +type rawMetadata struct { + md Metadata +} + var ( // defaultMetadataSize used when need to init new Metadata defaultMetadataSize = 2 diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index 34d077fe..d1b91efb 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -2,9 +2,32 @@ package metadata import ( "context" + "fmt" "testing" ) +func testCtx(ctx context.Context) { + md := New(2) + md.Set("Key1", "Val1_new") + md.Set("Key3", "Val3") + SetOutgoingContext(ctx, md) +} + +func TestPassing(t *testing.T) { + ctx := context.TODO() + md1 := New(2) + md1.Set("Key1", "Val1") + md1.Set("Key2", "Val2") + + ctx = NewIncomingContext(ctx, md1) + testCtx(ctx) + md, ok := FromOutgoingContext(ctx) + if !ok { + t.Fatalf("missing metadata from outgoing context") + } + fmt.Printf("%#+v\n", md) +} + func TestMerge(t *testing.T) { omd := Metadata{ "key1": "val1",