util/reflect: add Equal func with ability to skip some fields #244
| @@ -508,3 +508,74 @@ func FieldName(name string) string { | ||||
|  | ||||
| 	return string(newstr) | ||||
| } | ||||
|  | ||||
| func Equal(src interface{}, dst interface{}, excptFields ...string) bool { | ||||
| 	srcVal := reflect.ValueOf(src) | ||||
| 	dstVal := reflect.ValueOf(dst) | ||||
|  | ||||
| 	switch srcVal.Kind() { | ||||
| 	case reflect.Array, reflect.Slice: | ||||
| 		for i := 0; i < srcVal.Len(); i++ { | ||||
| 			e := srcVal.Index(i).Interface() | ||||
| 			a := dstVal.Index(i).Interface() | ||||
| 			if !Equal(e, a, excptFields...) { | ||||
| 				return false | ||||
| 			} | ||||
| 		} | ||||
| 		return true | ||||
| 	case reflect.Map: | ||||
| 		for i := 0; i < len(srcVal.MapKeys()); i++ { | ||||
| 			key := srcVal.MapKeys()[i] | ||||
| 			keyStr := fmt.Sprintf("%v", key.Interface()) | ||||
| 			if stringContains(keyStr, excptFields) { | ||||
| 				continue | ||||
| 			} | ||||
| 			s := srcVal.MapIndex(key) | ||||
| 			d := dstVal.MapIndex(key) | ||||
| 			if !Equal(s.Interface(), d.Interface(), excptFields...) { | ||||
| 				return false | ||||
| 			} | ||||
| 		} | ||||
| 		return true | ||||
| 	case reflect.Struct, reflect.Interface: | ||||
| 		for i := 0; i < srcVal.NumField(); i++ { | ||||
| 			typeField := srcVal.Type().Field(i) | ||||
| 			if stringContains(typeField.Name, excptFields) { | ||||
| 				continue | ||||
| 			} | ||||
| 			s := srcVal.Field(i) | ||||
| 			d := dstVal.FieldByName(typeField.Name) | ||||
| 			if s.CanInterface() && d.CanInterface() { | ||||
| 				if !Equal(s.Interface(), d.Interface(), excptFields...) { | ||||
| 					return false | ||||
| 				} | ||||
| 			} else { | ||||
| 				return false | ||||
| 			} | ||||
| 		} | ||||
| 		return true | ||||
| 	case reflect.Ptr: | ||||
| 		if srcVal.IsNil() { | ||||
| 			return dstVal.IsNil() | ||||
| 		} | ||||
| 		s := srcVal.Elem() | ||||
| 		d := reflect.Indirect(dstVal) | ||||
| 		if s.CanInterface() && d.CanInterface() { | ||||
| 			return Equal(s.Interface(), d.Interface(), excptFields...) | ||||
| 		} | ||||
| 		return false | ||||
| 	case reflect.String, reflect.Int, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: | ||||
| 		return src == dst | ||||
| 	default: | ||||
| 		return srcVal.Interface() == dstVal.Interface() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func stringContains(a string, list []string) bool { | ||||
| 	for _, b := range list { | ||||
| 		if b == a { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|   | ||||
| @@ -133,3 +133,16 @@ func TestMergeNested(t *testing.T) { | ||||
| 		t.Fatalf("merge error: %#+v", dst.Nested) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEqual(t *testing.T) { | ||||
| 	type tstr struct { | ||||
| 		Key1 string | ||||
| 		Key2 string | ||||
| 	} | ||||
|  | ||||
| 	src := &tstr{Key1: "val1", Key2: "micro:generate"} | ||||
| 	dst := &tstr{Key1: "val1", Key2: "val2"} | ||||
| 	if !Equal(src, dst, "Key2") { | ||||
| 		t.Fatal("invalid Equal test") | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user