diff --git a/codec/codec.go b/codec/codec.go index 81b2ee1e..051fa54c 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -28,6 +28,8 @@ var ( DefaultMaxMsgSize int = 1024 * 1024 * 4 // 4Mb // DefaultCodec is the global default codec DefaultCodec Codec = NewCodec() + // DefaultTagName specifies struct tag name to control codec Marshal/Unmarshal + DefaultTagName = "codec" ) // MessageType specifies message type for codec diff --git a/util/reflect/reflect.go b/util/reflect/reflect.go index c76386f4..06a686d4 100644 --- a/util/reflect/reflect.go +++ b/util/reflect/reflect.go @@ -12,6 +12,7 @@ import ( var ( ErrInvalidStruct = errors.New("invalid struct specified") ErrInvalidValue = errors.New("invalid value specified") + ErrNotFound = errors.New("struct field not found") ) type Option func(*Options) diff --git a/util/reflect/struct.go b/util/reflect/struct.go index 06082db3..1db52034 100644 --- a/util/reflect/struct.go +++ b/util/reflect/struct.go @@ -14,6 +14,84 @@ var ErrInvalidParam = errors.New("invalid url query param provided") var bracketSplitter = regexp.MustCompile(`\[|\]`) +func StructFieldByTag(src interface{}, tkey string, tval string) (interface{}, error) { + sv := reflect.ValueOf(src) + if sv.Kind() == reflect.Ptr { + sv = sv.Elem() + } + if sv.Kind() != reflect.Struct { + return nil, ErrInvalidStruct + } + + typ := sv.Type() + for idx := 0; idx < typ.NumField(); idx++ { + fld := typ.Field(idx) + val := sv.Field(idx) + if !val.CanSet() || len(fld.PkgPath) != 0 { + continue + } + switch val.Kind() { + default: + ts, ok := fld.Tag.Lookup(tkey) + if !ok { + continue + } + for _, p := range strings.Split(ts, ",") { + if p == tval { + return val.Interface(), nil + } + } + case reflect.Ptr: + if val = val.Elem(); val.Kind() == reflect.Struct { + if iface, err := StructFieldByTag(val.Interface(), tkey, tval); err == nil { + return iface, nil + } + } + case reflect.Struct: + if iface, err := StructFieldByTag(val.Interface(), tkey, tval); err == nil { + return iface, nil + } + } + } + return nil, ErrNotFound +} + +func StructFieldByName(src interface{}, tkey string) (interface{}, error) { + sv := reflect.ValueOf(src) + if sv.Kind() == reflect.Ptr { + sv = sv.Elem() + } + if sv.Kind() != reflect.Struct { + return nil, ErrInvalidStruct + } + + typ := sv.Type() + for idx := 0; idx < typ.NumField(); idx++ { + fld := typ.Field(idx) + val := sv.Field(idx) + if !val.CanSet() || len(fld.PkgPath) != 0 { + continue + } + switch val.Kind() { + default: + if fld.Name == tkey { + return val.Interface(), nil + } + case reflect.Ptr: + if val = val.Elem(); val.Kind() == reflect.Struct { + if iface, err := StructFieldByName(val.Interface(), tkey); err == nil { + return iface, nil + } + } + case reflect.Struct: + if iface, err := StructFieldByName(val.Interface(), tkey); err == nil { + return iface, nil + } + } + } + return nil, ErrNotFound +} + // StructFields returns slice of struct fields func StructFields(src interface{}) ([]reflect.StructField, error) { var fields []reflect.StructField diff --git a/util/reflect/struct_test.go b/util/reflect/struct_test.go index 4c030a78..ab65ef1c 100644 --- a/util/reflect/struct_test.go +++ b/util/reflect/struct_test.go @@ -5,6 +5,44 @@ import ( "testing" ) +func TestStructByTag(t *testing.T) { + type Str struct { + Name []string `json:"name" codec:"flatten"` + } + + val := &Str{Name: []string{"first", "second"}} + + iface, err := StructFieldByTag(val, "codec", "flatten") + if err != nil { + t.Fatal(err) + } + + if v, ok := iface.([]string); !ok { + t.Fatalf("not []string %v", iface) + } else if len(v) != 2 { + t.Fatalf("invalid number %v", iface) + } +} + +func TestStructByName(t *testing.T) { + type Str struct { + Name []string `json:"name" codec:"flatten"` + } + + val := &Str{Name: []string{"first", "second"}} + + iface, err := StructFieldByName(val, "Name") + if err != nil { + t.Fatal(err) + } + + if v, ok := iface.([]string); !ok { + t.Fatalf("not []string %v", iface) + } else if len(v) != 2 { + t.Fatalf("invalid number %v", iface) + } +} + func TestStructURLValues(t *testing.T) { type Str struct { Str *Str `json:"str"`