AppendContext with overwrite flag

This commit is contained in:
Sumanth Chinthagunta 2019-10-25 08:27:28 -07:00
parent 1f658cfbff
commit 1c6b85e05d
2 changed files with 41 additions and 23 deletions

View File

@ -29,15 +29,18 @@ func NewContext(ctx context.Context, md Metadata) context.Context {
return context.WithValue(ctx, metaKey{}, md) return context.WithValue(ctx, metaKey{}, md)
} }
// PatchContext : will add/replace source metadata fields with given patch metadata fields func AppendContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context {
func PatchContext(ctx context.Context, patchMd Metadata) context.Context {
md, _ := ctx.Value(metaKey{}).(Metadata) md, _ := ctx.Value(metaKey{}).(Metadata)
cmd := make(Metadata) cmd := make(Metadata)
for k, v := range md { for k, v := range md {
cmd[k] = v cmd[k] = v
} }
for k, v := range patchMd { for k, v := range patchMd {
cmd[k] = v if _, ok := cmd[k]; ok && !overwrite {
// skip
} else {
cmd[k] = v
}
} }
return context.WithValue(ctx, metaKey{}, cmd) return context.WithValue(ctx, metaKey{}, cmd)

View File

@ -2,6 +2,7 @@ package metadata
import ( import (
"context" "context"
"reflect"
"testing" "testing"
) )
@ -40,28 +41,42 @@ func TestMetadataContext(t *testing.T) {
t.Errorf("Expected metadata length 1 got %d", i) t.Errorf("Expected metadata length 1 got %d", i)
} }
} }
func TestPatchContext(t *testing.T) {
original := Metadata{ func TestAppendContext(t *testing.T) {
"foo": "bar", type args struct {
existing Metadata
append Metadata
overwrite bool
} }
tests := []struct {
patch := Metadata{ name string
"sumo": "demo", args args
want Metadata
}{
{
name: "matching key, overwrite false",
args: args{
existing: Metadata{"foo": "bar", "sumo": "demo"},
append: Metadata{"sumo": "demo2"},
overwrite: false,
},
want: Metadata{"foo": "bar", "sumo": "demo"},
},
{
name: "matching key, overwrite true",
args: args{
existing: Metadata{"foo": "bar", "sumo": "demo"},
append: Metadata{"sumo": "demo2"},
overwrite: true,
},
want: Metadata{"foo": "bar", "sumo": "demo2"},
},
} }
ctx := NewContext(context.TODO(), original) for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
patchedCtx := PatchContext(ctx, patch) if got, _ := FromContext(AppendContext(NewContext(context.TODO(), tt.args.existing), tt.args.append, tt.args.overwrite)); !reflect.DeepEqual(got, tt.want) {
t.Errorf("AppendContext() = %v, want %v", got, tt.want)
patchedMd, ok := FromContext(patchedCtx) }
if !ok { })
t.Errorf("Unexpected error retrieving metadata, got %t", ok)
}
if patchedMd["sumo"] != patch["sumo"] {
t.Errorf("Expected key: %s val: %s, got key: %s val: %s", "sumo", patch["sumo"], "sumo", patchedMd["sumo"])
}
if patchedMd["foo"] != original["foo"] {
t.Errorf("Expected key: %s val: %s, got key: %s val: %s", "foo", original["foo"], "foo", patchedMd["foo"])
} }
} }