diff --git a/util/reflect/reflect.go b/util/reflect/reflect.go index 031dcc72..7d34b068 100644 --- a/util/reflect/reflect.go +++ b/util/reflect/reflect.go @@ -508,3 +508,74 @@ func FieldName(name string) string { return string(newstr) } + +func Equal(src interface{}, dst interface{}, excptFields ...string) bool { + srcVal := reflect.ValueOf(src) + dstVal := reflect.ValueOf(dst) + + switch srcVal.Kind() { + case reflect.Array, reflect.Slice: + for i := 0; i < srcVal.Len(); i++ { + e := srcVal.Index(i).Interface() + a := dstVal.Index(i).Interface() + if !Equal(e, a, excptFields...) { + return false + } + } + return true + case reflect.Map: + for i := 0; i < len(srcVal.MapKeys()); i++ { + key := srcVal.MapKeys()[i] + keyStr := fmt.Sprintf("%v", key.Interface()) + if stringContains(keyStr, excptFields) { + continue + } + s := srcVal.MapIndex(key) + d := dstVal.MapIndex(key) + if !Equal(s.Interface(), d.Interface(), excptFields...) { + return false + } + } + return true + case reflect.Struct, reflect.Interface: + for i := 0; i < srcVal.NumField(); i++ { + typeField := srcVal.Type().Field(i) + if stringContains(typeField.Name, excptFields) { + continue + } + s := srcVal.Field(i) + d := dstVal.FieldByName(typeField.Name) + if s.CanInterface() && d.CanInterface() { + if !Equal(s.Interface(), d.Interface(), excptFields...) { + return false + } + } else { + return false + } + } + return true + case reflect.Ptr: + if srcVal.IsNil() { + return dstVal.IsNil() + } + s := srcVal.Elem() + d := reflect.Indirect(dstVal) + if s.CanInterface() && d.CanInterface() { + return Equal(s.Interface(), d.Interface(), excptFields...) + } + return false + case reflect.String, reflect.Int, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: + return src == dst + default: + return srcVal.Interface() == dstVal.Interface() + } +} + +func stringContains(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} diff --git a/util/reflect/reflect_test.go b/util/reflect/reflect_test.go index b475bc08..a2cef0de 100644 --- a/util/reflect/reflect_test.go +++ b/util/reflect/reflect_test.go @@ -133,3 +133,16 @@ func TestMergeNested(t *testing.T) { t.Fatalf("merge error: %#+v", dst.Nested) } } + +func TestEqual(t *testing.T) { + type tstr struct { + Key1 string + Key2 string + } + + src := &tstr{Key1: "val1", Key2: "micro:generate"} + dst := &tstr{Key1: "val1", Key2: "val2"} + if !Equal(src, dst, "Key2") { + t.Fatal("invalid Equal test") + } +}