util/reflect: add missing types for merge

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2021-02-06 18:13:43 +03:00
parent b4092c6619
commit 1de9911b73
2 changed files with 353 additions and 21 deletions

View File

@ -3,26 +3,52 @@ package reflect
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/url"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
"unicode"
) )
var ( var (
bracketSplitter = regexp.MustCompile(`\[|\]`)
ErrInvalidStruct = errors.New("invalid struct specified") ErrInvalidStruct = errors.New("invalid struct specified")
ErrInvalidParam = errors.New("invalid url query param provided")
) )
func fieldName(name string) string {
newstr := make([]rune, 0)
upper := false
for idx, chr := range name {
if idx == 0 {
upper = true
} else if chr == '_' {
upper = true
continue
}
if upper {
newstr = append(newstr, unicode.ToUpper(chr))
} else {
newstr = append(newstr, chr)
}
upper = false
}
return string(newstr)
}
func IsEmpty(v reflect.Value) bool { func IsEmpty(v reflect.Value) bool {
switch v.Kind() { switch getKind(v) {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String: case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0 return v.Len() == 0
case reflect.Bool: case reflect.Bool:
return !v.Bool() return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int:
return v.Int() == 0 return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint:
return v.Uint() == 0 return v.Uint() == 0
case reflect.Float32, reflect.Float64: case reflect.Float32:
return v.Float() == 0 return v.Float() == 0
case reflect.Interface, reflect.Ptr: case reflect.Interface, reflect.Ptr:
if v.IsNil() { if v.IsNil() {
@ -123,12 +149,28 @@ func CopyFrom(a, b interface{}) {
} }
} }
func MergeMap(a interface{}, b map[string]interface{}) error { func URLMap(query string) (map[string]interface{}, error) {
var err error var (
mp interface{} = make(map[string]interface{})
)
params := strings.Split(query, "&")
for _, part := range params {
tm, err := queryToMap(part)
if err != nil {
return nil, err
}
mp = merge(mp, tm)
}
return mp.(map[string]interface{}), nil
}
func FlattenMap(a map[string]interface{}) map[string]interface{} {
// preprocess map // preprocess map
nb := make(map[string]interface{}, len(b)) nb := make(map[string]interface{}, len(a))
for k, v := range b { for k, v := range a {
ps := strings.Split(k, ".") ps := strings.Split(k, ".")
if len(ps) == 1 { if len(ps) == 1 {
nb[k] = v nb[k] = v
@ -152,6 +194,11 @@ func MergeMap(a interface{}, b map[string]interface{}) error {
nb[ps[0]] = em nb[ps[0]] = em
} }
} }
return nb
}
func MergeMap(a interface{}, b map[string]interface{}) error {
var err error
ta := reflect.TypeOf(a) ta := reflect.TypeOf(a)
if ta.Kind() == reflect.Ptr { if ta.Kind() == reflect.Ptr {
@ -162,17 +209,16 @@ func MergeMap(a interface{}, b map[string]interface{}) error {
va = va.Elem() va = va.Elem()
} }
for mk, mv := range nb { for mk, mv := range b {
vmv := reflect.ValueOf(mv) vmv := reflect.ValueOf(mv)
// tmv := reflect.TypeOf(mv) name := fieldName(mk)
name := strings.Title(mk)
fva := va.FieldByName(name) fva := va.FieldByName(name)
fta, found := ta.FieldByName(name) fta, found := ta.FieldByName(name)
if !found || !fva.IsValid() || !fva.CanSet() || fta.PkgPath != "" { if !found || !fva.IsValid() || !fva.CanSet() || fta.PkgPath != "" {
continue continue
} }
// fast path via direct assign // fast path via direct assign
if vmv.Type().AssignableTo(fta.Type) { if vmv.Type().AssignableTo(fta.Type) && !IsEmpty(vmv) {
fva.Set(vmv) fva.Set(vmv)
continue continue
} }
@ -185,8 +231,37 @@ func MergeMap(a interface{}, b map[string]interface{}) error {
err = mergeInt(fva, vmv) err = mergeInt(fva, vmv)
case reflect.Uint: case reflect.Uint:
err = mergeUint(fva, vmv) err = mergeUint(fva, vmv)
case reflect.Float64: case reflect.Float32:
err = mergeFloat(fva, vmv) err = mergeFloat(fva, vmv)
case reflect.Array:
//fmt.Printf("Array %#+v %#+v\n", fva, vmv)
case reflect.Slice:
err = mergeSlice(fva, vmv)
case reflect.Ptr:
if fva.IsNil() {
fva.Set(reflect.New(fva.Type().Elem()))
if fva.Elem().Type().Kind() == reflect.Struct {
for i := 0; i < fva.Elem().NumField(); i++ {
field := fva.Elem().Field(i)
if field.Type().Kind() == reflect.Ptr && field.IsNil() && fva.Elem().Type().Field(i).Anonymous == true {
field.Set(reflect.New(field.Type().Elem()))
}
}
}
}
if nmp, ok := vmv.Interface().(map[string]interface{}); ok {
err = MergeMap(fva.Interface(), nmp)
} else {
err = fmt.Errorf("cant fill")
}
case reflect.Struct:
if nmp, ok := vmv.Interface().(map[string]interface{}); ok {
err = MergeMap(fva.Interface(), nmp)
} else {
err = fmt.Errorf("cant fill")
}
case reflect.Map:
//fmt.Printf("Map %#+v %#+v\n", fva, vmv)
} }
if err != nil { if err != nil {
return err return err
@ -195,6 +270,75 @@ func MergeMap(a interface{}, b map[string]interface{}) error {
return nil return nil
} }
func mergeSlice(va, vb reflect.Value) error {
switch getKind(vb) {
/*
case reflect.Int:
if vb.Int() == 1 {
va.SetBool(true)
}
case reflect.Uint:
if vb.Uint() == 1 {
va.SetBool(true)
}
case reflect.Float64:
if vb.Float() == 1 {
va.SetBool(true)
}
*/
case reflect.String:
var err error
fn := func(c rune) bool { return c == ',' || c == ';' || c == ' ' }
slice := strings.FieldsFunc(vb.String(), fn)
va.Set(reflect.MakeSlice(va.Type(), len(slice), len(slice)))
for idx, sl := range slice {
vl := reflect.ValueOf(sl)
switch va.Type().Elem().Kind() {
case reflect.Bool:
err = mergeBool(va.Index(idx), vl)
case reflect.String:
err = mergeString(va.Index(idx), vl)
case reflect.Ptr:
if va.Index(idx).IsNil() {
va.Index(idx).Set(reflect.New(va.Index(idx).Type().Elem()))
}
switch va.Type().Elem().String() {
case "*wrapperspb.BoolValue":
if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() {
err = mergeBool(eva, vl)
}
case "*wrapperspb.BytesValue":
if eva := va.Index(idx).FieldByName("Value"); eva.IsValid() {
err = mergeUint(eva, vl)
}
case "*wrapperspb.DoubleValue", "*wrapperspb.FloatValue":
if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() {
err = mergeFloat(eva, vl)
}
case "*wrapperspb.Int32Value", "*wrapperspb.Int64Value":
if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() {
err = mergeInt(eva, vl)
}
case "*wrapperspb.StringValue":
if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() {
err = mergeString(eva, vl)
}
case "*wrapperspb.UInt32Value", "*wrapperspb.UInt64Value":
if eva := reflect.Indirect(va.Index(idx)).FieldByName("Value"); eva.IsValid() {
err = mergeUint(eva, vl)
}
}
}
if err != nil {
return err
}
}
default:
return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind())
}
return nil
}
func mergeBool(va, vb reflect.Value) error { func mergeBool(va, vb reflect.Value) error {
switch getKind(vb) { switch getKind(vb) {
case reflect.Int: case reflect.Int:
@ -205,13 +349,15 @@ func mergeBool(va, vb reflect.Value) error {
if vb.Uint() == 1 { if vb.Uint() == 1 {
va.SetBool(true) va.SetBool(true)
} }
case reflect.Float64: case reflect.Float32:
if vb.Float() == 1 { if vb.Float() == 1 {
va.SetBool(true) va.SetBool(true)
} }
case reflect.String: case reflect.String:
if vb.String() == "1" || vb.String() == "true" { if b, err := strconv.ParseBool(vb.String()); err != nil {
vb.SetBool(true) return err
} else {
va.SetBool(b)
} }
default: default:
return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind()) return fmt.Errorf("cant merge %v %s with %v %s", va, va.Kind(), vb, vb.Kind())
@ -225,7 +371,7 @@ func mergeString(va, vb reflect.Value) error {
va.SetString(fmt.Sprintf("%d", vb.Int())) va.SetString(fmt.Sprintf("%d", vb.Int()))
case reflect.Uint: case reflect.Uint:
va.SetString(fmt.Sprintf("%d", vb.Uint())) va.SetString(fmt.Sprintf("%d", vb.Uint()))
case reflect.Float64: case reflect.Float32:
va.SetString(fmt.Sprintf("%f", vb.Float())) va.SetString(fmt.Sprintf("%f", vb.Float()))
case reflect.String: case reflect.String:
va.Set(vb) va.Set(vb)
@ -241,7 +387,7 @@ func mergeInt(va, vb reflect.Value) error {
va.Set(vb) va.Set(vb)
case reflect.Uint: case reflect.Uint:
va.SetInt(int64(vb.Uint())) va.SetInt(int64(vb.Uint()))
case reflect.Float64: case reflect.Float32:
va.SetInt(int64(vb.Float())) va.SetInt(int64(vb.Float()))
case reflect.String: case reflect.String:
if f, err := strconv.ParseInt(vb.String(), 10, va.Type().Bits()); err != nil { if f, err := strconv.ParseInt(vb.String(), 10, va.Type().Bits()); err != nil {
@ -261,7 +407,7 @@ func mergeUint(va, vb reflect.Value) error {
va.SetUint(uint64(vb.Int())) va.SetUint(uint64(vb.Int()))
case reflect.Uint: case reflect.Uint:
va.Set(vb) va.Set(vb)
case reflect.Float64: case reflect.Float32:
va.SetUint(uint64(vb.Float())) va.SetUint(uint64(vb.Float()))
case reflect.String: case reflect.String:
if f, err := strconv.ParseUint(vb.String(), 10, va.Type().Bits()); err != nil { if f, err := strconv.ParseUint(vb.String(), 10, va.Type().Bits()); err != nil {
@ -281,7 +427,7 @@ func mergeFloat(va, vb reflect.Value) error {
va.SetFloat(float64(vb.Int())) va.SetFloat(float64(vb.Int()))
case reflect.Uint: case reflect.Uint:
va.SetFloat(float64(vb.Uint())) va.SetFloat(float64(vb.Uint()))
case reflect.Float64: case reflect.Float32:
va.Set(vb) va.Set(vb)
case reflect.String: case reflect.String:
if f, err := strconv.ParseFloat(vb.String(), va.Type().Bits()); err != nil { if f, err := strconv.ParseFloat(vb.String(), va.Type().Bits()); err != nil {
@ -304,7 +450,148 @@ func getKind(val reflect.Value) reflect.Kind {
case kind >= reflect.Uint && kind <= reflect.Uint64: case kind >= reflect.Uint && kind <= reflect.Uint64:
return reflect.Uint return reflect.Uint
case kind >= reflect.Float32 && kind <= reflect.Float64: case kind >= reflect.Float32 && kind <= reflect.Float64:
return reflect.Float64 return reflect.Float32
} }
return kind return kind
} }
func btSplitter(str string) []string {
r := bracketSplitter.Split(str, -1)
for idx, s := range r {
if len(s) == 0 {
if len(r) > idx+1 {
copy(r[idx:], r[idx+1:])
r = r[:len(r)-1]
}
}
}
return r
}
// queryToMap turns something like a[b][c]=4 into
// map[string]interface{}{
// "a": map[string]interface{}{
// "b": map[string]interface{}{
// "c": 4,
// },
// },
// }
func queryToMap(param string) (map[string]interface{}, error) {
rawKey, rawValue, err := splitKeyAndValue(param)
if err != nil {
return nil, err
}
rawValue, err = url.QueryUnescape(rawValue)
if err != nil {
return nil, err
}
rawKey, err = url.QueryUnescape(rawKey)
if err != nil {
return nil, err
}
pieces := btSplitter(rawKey)
key := pieces[0]
// If len==1 then rawKey has no [] chars and we can just
// decode this as key=value into {key: value}
if len(pieces) == 1 {
return map[string]interface{}{
key: rawValue,
}, nil
}
// If len > 1 then we have something like a[b][c]=2
// so we need to turn this into {"a": {"b": {"c": 2}}}
// To do this we break our key into two pieces:
// a and b[c]
// and then we set {"a": queryToMap("b[c]", value)}
ret := make(map[string]interface{})
ret[key], err = queryToMap(buildNewKey(rawKey) + "=" + rawValue)
if err != nil {
return nil, err
}
// When URL params have a set of empty brackets (eg a[]=1)
// it is assumed to be an array. This will get us the
// correct value for the array item and return it as an
// []interface{} so that it can be merged properly.
if pieces[1] == "" {
temp := ret[key].(map[string]interface{})
ret[key] = []interface{}{temp[""]}
}
return ret, nil
}
// buildNewKey will take something like:
// origKey = "bar[one][two]"
// pieces = [bar one two ]
// and return "one[two]"
func buildNewKey(origKey string) string {
pieces := btSplitter(origKey)
ret := origKey[len(pieces[0])+1:]
ret = ret[:len(pieces[1])] + ret[len(pieces[1])+1:]
return ret
}
// splitKeyAndValue splits a URL param at the last equal
// sign and returns the two strings. If no equal sign is
// found, the ErrInvalidParam error is returned.
func splitKeyAndValue(param string) (string, string, error) {
li := strings.LastIndex(param, "=")
if li == -1 {
return "", "", ErrInvalidParam
}
return param[:li], param[li+1:], nil
}
// merge merges a with b if they are either both slices
// or map[string]interface{} types. Otherwise it returns b.
func merge(a interface{}, b interface{}) interface{} {
if av, aok := a.(map[string]interface{}); aok {
if bv, bok := b.(map[string]interface{}); bok {
return mergeMapIface(av, bv)
}
}
if av, aok := a.([]interface{}); aok {
if bv, bok := b.([]interface{}); bok {
return mergeSliceIface(av, bv)
}
}
va := reflect.ValueOf(a)
vb := reflect.ValueOf(b)
if (va.Type().Kind() == reflect.Slice) && (va.Type().Elem().Kind() == vb.Type().Kind() || vb.Type().ConvertibleTo(va.Type().Elem())) {
va = reflect.Append(va, vb.Convert(va.Type().Elem()))
return va.Interface()
}
return b
}
// mergeMap merges a with b, attempting to merge any nested
// values in nested maps but eventually overwriting anything
// in a that can't be merged with whatever is in b.
func mergeMapIface(a map[string]interface{}, b map[string]interface{}) map[string]interface{} {
for bK, bV := range b {
if aV, ok := a[bK]; ok {
if (reflect.ValueOf(aV).Type().Kind() == reflect.ValueOf(bV).Type().Kind()) ||
((reflect.ValueOf(aV).Type().Kind() == reflect.Slice) && reflect.ValueOf(aV).Type().Elem().Kind() == reflect.ValueOf(bV).Type().Kind()) {
nV := []interface{}{aV, bV}
a[bK] = nV
} else {
a[bK] = merge(a[bK], bV)
}
} else {
a[bK] = bV
}
}
return a
}
// mergeSlice merges a with b and returns the result.
func mergeSliceIface(a []interface{}, b []interface{}) []interface{} {
a = append(a, b...)
return a
}

View File

@ -0,0 +1,45 @@
package reflect
import (
"net/url"
"testing"
)
func TestURLSliceVars(t *testing.T) {
u, err := url.Parse("http://localhost/v1/test/call/my_name?key=arg1&key=arg2&key=arg3")
if err != nil {
t.Fatal(err)
}
mp, err := URLMap(u.RawQuery)
if err != nil {
t.Fatal(err)
}
v, ok := mp["key"]
if !ok {
t.Fatalf("key not exists: %#+v", mp)
}
vm, ok := v.([]interface{})
if !ok {
t.Fatalf("invalid key value")
}
if len(vm) != 3 {
t.Fatalf("missing key value: %#+v", mp)
}
}
func TestURLVars(t *testing.T) {
u, err := url.Parse("http://localhost/v1/test/call/my_name?req=key&arg1=arg1&arg2=12345&nested.string_args=str1&nested.string_args=str2&arg2=54321")
if err != nil {
t.Fatal(err)
}
mp, err := URLMap(u.RawQuery)
if err != nil {
t.Fatal(err)
}
_ = mp
}