diff --git a/metadata/metadata.go b/metadata/metadata.go index 606c1f05..58c70c3d 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -29,15 +29,18 @@ func NewContext(ctx context.Context, md Metadata) context.Context { return context.WithValue(ctx, metaKey{}, md) } -// PatchContext : will add/replace source metadata fields with given patch metadata fields -func PatchContext(ctx context.Context, patchMd Metadata) context.Context { +func AppendContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context { md, _ := ctx.Value(metaKey{}).(Metadata) cmd := make(Metadata) for k, v := range md { cmd[k] = v } for k, v := range patchMd { - cmd[k] = v + if _, ok := cmd[k]; ok && !overwrite { + // skip + } else { + cmd[k] = v + } } return context.WithValue(ctx, metaKey{}, cmd) diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index cfd1c1f8..f05e4229 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -2,6 +2,7 @@ package metadata import ( "context" + "reflect" "testing" ) @@ -40,28 +41,42 @@ func TestMetadataContext(t *testing.T) { t.Errorf("Expected metadata length 1 got %d", i) } } -func TestPatchContext(t *testing.T) { - original := Metadata{ - "foo": "bar", +func TestAppendContext(t *testing.T) { + type args struct { + existing Metadata + append Metadata + overwrite bool } - - patch := Metadata{ - "sumo": "demo", + tests := []struct { + name string + args args + want Metadata + }{ + { + name: "matching key, overwrite false", + args: args{ + existing: Metadata{"foo": "bar", "sumo": "demo"}, + append: Metadata{"sumo": "demo2"}, + overwrite: false, + }, + want: Metadata{"foo": "bar", "sumo": "demo"}, + }, + { + name: "matching key, overwrite true", + args: args{ + existing: Metadata{"foo": "bar", "sumo": "demo"}, + append: Metadata{"sumo": "demo2"}, + overwrite: true, + }, + want: Metadata{"foo": "bar", "sumo": "demo2"}, + }, } - ctx := NewContext(context.TODO(), original) - - patchedCtx := PatchContext(ctx, patch) - - patchedMd, ok := FromContext(patchedCtx) - if !ok { - t.Errorf("Unexpected error retrieving metadata, got %t", ok) - } - - if patchedMd["sumo"] != patch["sumo"] { - t.Errorf("Expected key: %s val: %s, got key: %s val: %s", "sumo", patch["sumo"], "sumo", patchedMd["sumo"]) - } - if patchedMd["foo"] != original["foo"] { - t.Errorf("Expected key: %s val: %s, got key: %s val: %s", "foo", original["foo"], "foo", patchedMd["foo"]) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, _ := FromContext(AppendContext(NewContext(context.TODO(), tt.args.existing), tt.args.append, tt.args.overwrite)); !reflect.DeepEqual(got, tt.want) { + t.Errorf("AppendContext() = %v, want %v", got, tt.want) + } + }) } }