diff --git a/util/reflect/struct.go b/util/reflect/struct.go index 1db52034..a1ac91ca 100644 --- a/util/reflect/struct.go +++ b/util/reflect/struct.go @@ -30,17 +30,19 @@ func StructFieldByTag(src interface{}, tkey string, tval string) (interface{}, e if !val.CanSet() || len(fld.PkgPath) != 0 { continue } - switch val.Kind() { - default: - ts, ok := fld.Tag.Lookup(tkey) - if !ok { - continue - } + + if ts, ok := fld.Tag.Lookup(tkey); ok { for _, p := range strings.Split(ts, ",") { if p == tval { + if val.Kind() != reflect.Ptr && val.CanAddr() { + val = val.Addr() + } return val.Interface(), nil } } + } + + switch val.Kind() { case reflect.Ptr: if val = val.Elem(); val.Kind() == reflect.Struct { if iface, err := StructFieldByTag(val.Interface(), tkey, tval); err == nil { @@ -72,11 +74,14 @@ func StructFieldByName(src interface{}, tkey string) (interface{}, error) { if !val.CanSet() || len(fld.PkgPath) != 0 { continue } - switch val.Kind() { - default: - if fld.Name == tkey { - return val.Interface(), nil + if fld.Name == tkey { + if val.Kind() != reflect.Ptr && val.CanAddr() { + val = val.Addr() } + return val.Interface(), nil + } + + switch val.Kind() { case reflect.Ptr: if val = val.Elem(); val.Kind() == reflect.Struct { if iface, err := StructFieldByName(val.Interface(), tkey); err == nil { diff --git a/util/reflect/struct_test.go b/util/reflect/struct_test.go index ab65ef1c..582f3587 100644 --- a/util/reflect/struct_test.go +++ b/util/reflect/struct_test.go @@ -17,9 +17,9 @@ func TestStructByTag(t *testing.T) { t.Fatal(err) } - if v, ok := iface.([]string); !ok { - t.Fatalf("not []string %v", iface) - } else if len(v) != 2 { + if v, ok := iface.(*[]string); !ok { + t.Fatalf("not *[]string %v", iface) + } else if len(*v) != 2 { t.Fatalf("invalid number %v", iface) } } @@ -36,9 +36,9 @@ func TestStructByName(t *testing.T) { t.Fatal(err) } - if v, ok := iface.([]string); !ok { - t.Fatalf("not []string %v", iface) - } else if len(v) != 2 { + if v, ok := iface.(*[]string); !ok { + t.Fatalf("not *[]string %v", iface) + } else if len(*v) != 2 { t.Fatalf("invalid number %v", iface) } }