From b4092c6619c83e7a66332e9cc6a7463d0c13a2e1 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Fri, 5 Feb 2021 18:27:16 +0300 Subject: [PATCH] util/reflect: improve merge for map Signed-off-by: Vasiliy Tolstov --- util/reflect/reflect.go | 189 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) diff --git a/util/reflect/reflect.go b/util/reflect/reflect.go index 185a192e..949fdf9e 100644 --- a/util/reflect/reflect.go +++ b/util/reflect/reflect.go @@ -2,7 +2,10 @@ package reflect import ( "errors" + "fmt" "reflect" + "strconv" + "strings" ) var ( @@ -119,3 +122,189 @@ func CopyFrom(a, b interface{}) { } } } + +func MergeMap(a interface{}, b map[string]interface{}) error { + var err error + + // preprocess map + nb := make(map[string]interface{}, len(b)) + for k, v := range b { + ps := strings.Split(k, ".") + if len(ps) == 1 { + nb[k] = v + continue + } + em := make(map[string]interface{}) + em[ps[len(ps)-1]] = v + for i := len(ps) - 2; i > 0; i-- { + nm := make(map[string]interface{}) + nm[ps[i]] = em + em = nm + } + if vm, ok := nb[ps[0]]; ok { + // nested map + nm := vm.(map[string]interface{}) + for vk, vv := range em { + nm[vk] = vv + } + nb[ps[0]] = nm + } else { + nb[ps[0]] = em + } + } + + ta := reflect.TypeOf(a) + if ta.Kind() == reflect.Ptr { + ta = ta.Elem() + } + va := reflect.ValueOf(a) + if va.Kind() == reflect.Ptr { + va = va.Elem() + } + + for mk, mv := range nb { + vmv := reflect.ValueOf(mv) + // tmv := reflect.TypeOf(mv) + name := strings.Title(mk) + fva := va.FieldByName(name) + fta, found := ta.FieldByName(name) + if !found || !fva.IsValid() || !fva.CanSet() || fta.PkgPath != "" { + continue + } + // fast path via direct assign + if vmv.Type().AssignableTo(fta.Type) { + fva.Set(vmv) + continue + } + switch getKind(fva) { + case reflect.Bool: + err = mergeBool(fva, vmv) + case reflect.String: + err = mergeString(fva, vmv) + case reflect.Int: + err = mergeInt(fva, vmv) + case reflect.Uint: + err = mergeUint(fva, vmv) + case reflect.Float64: + err = mergeFloat(fva, vmv) + } + if err != nil { + return err + } + } + return nil +} + +func mergeBool(va, vb reflect.Value) error { + switch getKind(vb) { + case reflect.Int: + if vb.Int() == 1 { + va.SetBool(true) + } + case reflect.Uint: + if vb.Uint() == 1 { + va.SetBool(true) + } + case reflect.Float64: + if vb.Float() == 1 { + va.SetBool(true) + } + case reflect.String: + if vb.String() == "1" || vb.String() == "true" { + vb.SetBool(true) + } + default: + return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind()) + } + return nil +} + +func mergeString(va, vb reflect.Value) error { + switch getKind(vb) { + case reflect.Int: + va.SetString(fmt.Sprintf("%d", vb.Int())) + case reflect.Uint: + va.SetString(fmt.Sprintf("%d", vb.Uint())) + case reflect.Float64: + va.SetString(fmt.Sprintf("%f", vb.Float())) + case reflect.String: + va.Set(vb) + default: + return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind()) + } + return nil +} + +func mergeInt(va, vb reflect.Value) error { + switch getKind(vb) { + case reflect.Int: + va.Set(vb) + case reflect.Uint: + va.SetInt(int64(vb.Uint())) + case reflect.Float64: + va.SetInt(int64(vb.Float())) + case reflect.String: + if f, err := strconv.ParseInt(vb.String(), 10, va.Type().Bits()); err != nil { + return err + } else { + va.SetInt(f) + } + default: + return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind()) + } + return nil +} + +func mergeUint(va, vb reflect.Value) error { + switch getKind(vb) { + case reflect.Int: + va.SetUint(uint64(vb.Int())) + case reflect.Uint: + va.Set(vb) + case reflect.Float64: + va.SetUint(uint64(vb.Float())) + case reflect.String: + if f, err := strconv.ParseUint(vb.String(), 10, va.Type().Bits()); err != nil { + return err + } else { + va.SetUint(f) + } + default: + return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind()) + } + return nil +} + +func mergeFloat(va, vb reflect.Value) error { + switch getKind(vb) { + case reflect.Int: + va.SetFloat(float64(vb.Int())) + case reflect.Uint: + va.SetFloat(float64(vb.Uint())) + case reflect.Float64: + va.Set(vb) + case reflect.String: + if f, err := strconv.ParseFloat(vb.String(), va.Type().Bits()); err != nil { + return err + } else { + va.SetFloat(f) + } + default: + return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind()) + } + + return nil +} + +func getKind(val reflect.Value) reflect.Kind { + kind := val.Kind() + switch { + case kind >= reflect.Int && kind <= reflect.Int64: + return reflect.Int + case kind >= reflect.Uint && kind <= reflect.Uint64: + return reflect.Uint + case kind >= reflect.Float32 && kind <= reflect.Float64: + return reflect.Float64 + } + return kind +}