From 1de9911b7367126d1ae69f3da4ee91a6f40ee955 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Sat, 6 Feb 2021 18:13:43 +0300 Subject: [PATCH] util/reflect: add missing types for merge Signed-off-by: Vasiliy Tolstov --- util/reflect/reflect.go | 329 ++++++++++++++++++++++++++++++++--- util/reflect/reflect_test.go | 45 +++++ 2 files changed, 353 insertions(+), 21 deletions(-) create mode 100644 util/reflect/reflect_test.go diff --git a/util/reflect/reflect.go b/util/reflect/reflect.go index 949fdf9e..d3d648d9 100644 --- a/util/reflect/reflect.go +++ b/util/reflect/reflect.go @@ -3,26 +3,52 @@ package reflect import ( "errors" "fmt" + "net/url" "reflect" + "regexp" "strconv" "strings" + "unicode" ) var ( + bracketSplitter = regexp.MustCompile(`\[|\]`) ErrInvalidStruct = errors.New("invalid struct specified") + ErrInvalidParam = errors.New("invalid url query param provided") ) +func fieldName(name string) string { + newstr := make([]rune, 0) + upper := false + for idx, chr := range name { + if idx == 0 { + upper = true + } else if chr == '_' { + upper = true + continue + } + if upper { + newstr = append(newstr, unicode.ToUpper(chr)) + } else { + newstr = append(newstr, chr) + } + upper = false + } + + return string(newstr) +} + func IsEmpty(v reflect.Value) bool { - switch v.Kind() { + switch getKind(v) { case reflect.Array, reflect.Map, reflect.Slice, reflect.String: return v.Len() == 0 case reflect.Bool: return !v.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + case reflect.Int: return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + case reflect.Uint: return v.Uint() == 0 - case reflect.Float32, reflect.Float64: + case reflect.Float32: return v.Float() == 0 case reflect.Interface, reflect.Ptr: if v.IsNil() { @@ -123,12 +149,28 @@ func CopyFrom(a, b interface{}) { } } -func MergeMap(a interface{}, b map[string]interface{}) error { - var err error +func URLMap(query string) (map[string]interface{}, error) { + var ( + mp interface{} = make(map[string]interface{}) + ) + params := strings.Split(query, "&") + + for _, part := range params { + tm, err := queryToMap(part) + if err != nil { + return nil, err + } + mp = merge(mp, tm) + } + + return mp.(map[string]interface{}), nil +} + +func FlattenMap(a map[string]interface{}) map[string]interface{} { // preprocess map - nb := make(map[string]interface{}, len(b)) - for k, v := range b { + nb := make(map[string]interface{}, len(a)) + for k, v := range a { ps := strings.Split(k, ".") if len(ps) == 1 { nb[k] = v @@ -152,6 +194,11 @@ func MergeMap(a interface{}, b map[string]interface{}) error { nb[ps[0]] = em } } + return nb +} + +func MergeMap(a interface{}, b map[string]interface{}) error { + var err error ta := reflect.TypeOf(a) if ta.Kind() == reflect.Ptr { @@ -162,17 +209,16 @@ func MergeMap(a interface{}, b map[string]interface{}) error { va = va.Elem() } - for mk, mv := range nb { + for mk, mv := range b { vmv := reflect.ValueOf(mv) - // tmv := reflect.TypeOf(mv) - name := strings.Title(mk) + name := fieldName(mk) fva := va.FieldByName(name) fta, found := ta.FieldByName(name) if !found || !fva.IsValid() || !fva.CanSet() || fta.PkgPath != "" { continue } // fast path via direct assign - if vmv.Type().AssignableTo(fta.Type) { + if vmv.Type().AssignableTo(fta.Type) && !IsEmpty(vmv) { fva.Set(vmv) continue } @@ -185,8 +231,37 @@ func MergeMap(a interface{}, b map[string]interface{}) error { err = mergeInt(fva, vmv) case reflect.Uint: err = mergeUint(fva, vmv) - case reflect.Float64: + case reflect.Float32: err = mergeFloat(fva, vmv) + case reflect.Array: + //fmt.Printf("Array %#+v %#+v\n", fva, vmv) + case reflect.Slice: + err = mergeSlice(fva, vmv) + case reflect.Ptr: + if fva.IsNil() { + fva.Set(reflect.New(fva.Type().Elem())) + if fva.Elem().Type().Kind() == reflect.Struct { + for i := 0; i < fva.Elem().NumField(); i++ { + field := fva.Elem().Field(i) + if field.Type().Kind() == reflect.Ptr && field.IsNil() && fva.Elem().Type().Field(i).Anonymous == true { + field.Set(reflect.New(field.Type().Elem())) + } + } + } + } + if nmp, ok := vmv.Interface().(map[string]interface{}); ok { + err = MergeMap(fva.Interface(), nmp) + } else { + err = fmt.Errorf("cant fill") + } + case reflect.Struct: + if nmp, ok := vmv.Interface().(map[string]interface{}); ok { + err = MergeMap(fva.Interface(), nmp) + } else { + err = fmt.Errorf("cant fill") + } + case reflect.Map: + //fmt.Printf("Map %#+v %#+v\n", fva, vmv) } if err != nil { return err @@ -195,6 +270,75 @@ func MergeMap(a interface{}, b map[string]interface{}) error { return nil } +func mergeSlice(va, vb reflect.Value) error { + switch getKind(vb) { + /* + case reflect.Int: + if vb.Int() == 1 { + va.SetBool(true) + } + case reflect.Uint: + if vb.Uint() == 1 { + va.SetBool(true) + } + case reflect.Float64: + if vb.Float() == 1 { + va.SetBool(true) + } + */ + case reflect.String: + var err error + fn := func(c rune) bool { return c == ',' || c == ';' || c == ' ' } + slice := strings.FieldsFunc(vb.String(), fn) + va.Set(reflect.MakeSlice(va.Type(), len(slice), len(slice))) + for idx, sl := range slice { + vl := reflect.ValueOf(sl) + switch va.Type().Elem().Kind() { + case reflect.Bool: + err = mergeBool(va.Index(idx), vl) + case reflect.String: + err = mergeString(va.Index(idx), vl) + case reflect.Ptr: + if va.Index(idx).IsNil() { + va.Index(idx).Set(reflect.New(va.Index(idx).Type().Elem())) + } + switch va.Type().Elem().String() { + case "*wrapperspb.BoolValue": + if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() { + err = mergeBool(eva, vl) + } + case "*wrapperspb.BytesValue": + if eva := va.Index(idx).FieldByName("Value"); eva.IsValid() { + err = mergeUint(eva, vl) + } + case "*wrapperspb.DoubleValue", "*wrapperspb.FloatValue": + if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() { + err = mergeFloat(eva, vl) + } + case "*wrapperspb.Int32Value", "*wrapperspb.Int64Value": + if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() { + err = mergeInt(eva, vl) + } + case "*wrapperspb.StringValue": + if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() { + err = mergeString(eva, vl) + } + case "*wrapperspb.UInt32Value", "*wrapperspb.UInt64Value": + if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() { + err = mergeUint(eva, vl) + } + } + } + if err != nil { + return err + } + } + default: + return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind()) + } + return nil +} + func mergeBool(va, vb reflect.Value) error { switch getKind(vb) { case reflect.Int: @@ -205,13 +349,15 @@ func mergeBool(va, vb reflect.Value) error { if vb.Uint() == 1 { va.SetBool(true) } - case reflect.Float64: + case reflect.Float32: if vb.Float() == 1 { va.SetBool(true) } case reflect.String: - if vb.String() == "1" || vb.String() == "true" { - vb.SetBool(true) + if b, err := strconv.ParseBool(vb.String()); err != nil { + return err + } else { + va.SetBool(b) } default: return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind()) @@ -225,7 +371,7 @@ func mergeString(va, vb reflect.Value) error { va.SetString(fmt.Sprintf("%d", vb.Int())) case reflect.Uint: va.SetString(fmt.Sprintf("%d", vb.Uint())) - case reflect.Float64: + case reflect.Float32: va.SetString(fmt.Sprintf("%f", vb.Float())) case reflect.String: va.Set(vb) @@ -241,7 +387,7 @@ func mergeInt(va, vb reflect.Value) error { va.Set(vb) case reflect.Uint: va.SetInt(int64(vb.Uint())) - case reflect.Float64: + case reflect.Float32: va.SetInt(int64(vb.Float())) case reflect.String: if f, err := strconv.ParseInt(vb.String(), 10, va.Type().Bits()); err != nil { @@ -261,7 +407,7 @@ func mergeUint(va, vb reflect.Value) error { va.SetUint(uint64(vb.Int())) case reflect.Uint: va.Set(vb) - case reflect.Float64: + case reflect.Float32: va.SetUint(uint64(vb.Float())) case reflect.String: if f, err := strconv.ParseUint(vb.String(), 10, va.Type().Bits()); err != nil { @@ -281,7 +427,7 @@ func mergeFloat(va, vb reflect.Value) error { va.SetFloat(float64(vb.Int())) case reflect.Uint: va.SetFloat(float64(vb.Uint())) - case reflect.Float64: + case reflect.Float32: va.Set(vb) case reflect.String: if f, err := strconv.ParseFloat(vb.String(), va.Type().Bits()); err != nil { @@ -304,7 +450,148 @@ func getKind(val reflect.Value) reflect.Kind { case kind >= reflect.Uint && kind <= reflect.Uint64: return reflect.Uint case kind >= reflect.Float32 && kind <= reflect.Float64: - return reflect.Float64 + return reflect.Float32 } return kind } + +func btSplitter(str string) []string { + r := bracketSplitter.Split(str, -1) + for idx, s := range r { + if len(s) == 0 { + if len(r) > idx+1 { + copy(r[idx:], r[idx+1:]) + r = r[:len(r)-1] + } + } + } + return r +} + +// queryToMap turns something like a[b][c]=4 into +// map[string]interface{}{ +// "a": map[string]interface{}{ +// "b": map[string]interface{}{ +// "c": 4, +// }, +// }, +// } +func queryToMap(param string) (map[string]interface{}, error) { + rawKey, rawValue, err := splitKeyAndValue(param) + if err != nil { + return nil, err + } + rawValue, err = url.QueryUnescape(rawValue) + if err != nil { + return nil, err + } + rawKey, err = url.QueryUnescape(rawKey) + if err != nil { + return nil, err + } + + pieces := btSplitter(rawKey) + key := pieces[0] + + // If len==1 then rawKey has no [] chars and we can just + // decode this as key=value into {key: value} + if len(pieces) == 1 { + return map[string]interface{}{ + key: rawValue, + }, nil + } + + // If len > 1 then we have something like a[b][c]=2 + // so we need to turn this into {"a": {"b": {"c": 2}}} + // To do this we break our key into two pieces: + // a and b[c] + // and then we set {"a": queryToMap("b[c]", value)} + ret := make(map[string]interface{}) + ret[key], err = queryToMap(buildNewKey(rawKey) + "=" + rawValue) + if err != nil { + return nil, err + } + + // When URL params have a set of empty brackets (eg a[]=1) + // it is assumed to be an array. This will get us the + // correct value for the array item and return it as an + // []interface{} so that it can be merged properly. + if pieces[1] == "" { + temp := ret[key].(map[string]interface{}) + ret[key] = []interface{}{temp[""]} + } + return ret, nil +} + +// buildNewKey will take something like: +// origKey = "bar[one][two]" +// pieces = [bar one two ] +// and return "one[two]" +func buildNewKey(origKey string) string { + pieces := btSplitter(origKey) + + ret := origKey[len(pieces[0])+1:] + ret = ret[:len(pieces[1])] + ret[len(pieces[1])+1:] + return ret +} + +// splitKeyAndValue splits a URL param at the last equal +// sign and returns the two strings. If no equal sign is +// found, the ErrInvalidParam error is returned. +func splitKeyAndValue(param string) (string, string, error) { + li := strings.LastIndex(param, "=") + if li == -1 { + return "", "", ErrInvalidParam + } + return param[:li], param[li+1:], nil +} + +// merge merges a with b if they are either both slices +// or map[string]interface{} types. Otherwise it returns b. +func merge(a interface{}, b interface{}) interface{} { + if av, aok := a.(map[string]interface{}); aok { + if bv, bok := b.(map[string]interface{}); bok { + return mergeMapIface(av, bv) + } + } + if av, aok := a.([]interface{}); aok { + if bv, bok := b.([]interface{}); bok { + return mergeSliceIface(av, bv) + } + } + + va := reflect.ValueOf(a) + vb := reflect.ValueOf(b) + if (va.Type().Kind() == reflect.Slice) && (va.Type().Elem().Kind() == vb.Type().Kind() || vb.Type().ConvertibleTo(va.Type().Elem())) { + va = reflect.Append(va, vb.Convert(va.Type().Elem())) + return va.Interface() + } + + return b +} + +// mergeMap merges a with b, attempting to merge any nested +// values in nested maps but eventually overwriting anything +// in a that can't be merged with whatever is in b. +func mergeMapIface(a map[string]interface{}, b map[string]interface{}) map[string]interface{} { + for bK, bV := range b { + if aV, ok := a[bK]; ok { + if (reflect.ValueOf(aV).Type().Kind() == reflect.ValueOf(bV).Type().Kind()) || + ((reflect.ValueOf(aV).Type().Kind() == reflect.Slice) && reflect.ValueOf(aV).Type().Elem().Kind() == reflect.ValueOf(bV).Type().Kind()) { + nV := []interface{}{aV, bV} + a[bK] = nV + } else { + a[bK] = merge(a[bK], bV) + } + } else { + a[bK] = bV + } + } + return a +} + +// mergeSlice merges a with b and returns the result. +func mergeSliceIface(a []interface{}, b []interface{}) []interface{} { + a = append(a, b...) + return a +} diff --git a/util/reflect/reflect_test.go b/util/reflect/reflect_test.go new file mode 100644 index 00000000..68cf9e9c --- /dev/null +++ b/util/reflect/reflect_test.go @@ -0,0 +1,45 @@ +package reflect + +import ( + "net/url" + "testing" +) + +func TestURLSliceVars(t *testing.T) { + u, err := url.Parse("http://localhost/v1/test/call/my_name?key=arg1&key=arg2&key=arg3") + if err != nil { + t.Fatal(err) + } + + mp, err := URLMap(u.RawQuery) + if err != nil { + t.Fatal(err) + } + + v, ok := mp["key"] + if !ok { + t.Fatalf("key not exists: %#+v", mp) + } + + vm, ok := v.([]interface{}) + if !ok { + t.Fatalf("invalid key value") + } + + if len(vm) != 3 { + t.Fatalf("missing key value: %#+v", mp) + } +} + +func TestURLVars(t *testing.T) { + u, err := url.Parse("http://localhost/v1/test/call/my_name?req=key&arg1=arg1&arg2=12345&nested.string_args=str1&nested.string_args=str2&arg2=54321") + if err != nil { + t.Fatal(err) + } + + mp, err := URLMap(u.RawQuery) + if err != nil { + t.Fatal(err) + } + _ = mp +}