From 7a2461f7ce1b532e89466acb3c92d873fedcc4e4 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Wed, 23 Mar 2022 00:57:59 +0300 Subject: [PATCH] rework map and slice support Signed-off-by: Vasiliy Tolstov --- flag.go | 10 +- flag_test.go | 30 +++-- util.go | 307 +++++++++++++++++++++++++++++++++------------------ 3 files changed, 227 insertions(+), 120 deletions(-) diff --git a/flag.go b/flag.go index 905ae83..6448676 100644 --- a/flag.go +++ b/flag.go @@ -53,7 +53,15 @@ func (c *flagConfig) Init(opts ...config.Option) error { rcheck := true - switch sf.Value.Interface().(type) { + if !sf.Value.IsValid() { + continue + } + vi := sf.Value.Interface() + if vi == nil { + continue + } + + switch vi.(type) { case time.Duration: err = c.flagDuration(sf.Value, fn, fv, fd) rcheck = false diff --git a/flag_test.go b/flag_test.go index 0ae6ab3..1854744 100644 --- a/flag_test.go +++ b/flag_test.go @@ -17,20 +17,26 @@ func TestLoad(t *testing.T) { os.Args = append(os.Args, "-addr", "33,44") os.Args = append(os.Args, "-time", time.RFC822) os.Args = append(os.Args, "-metadata", "key=20") + os.Args = append(os.Args, "-components", "all=info,api=debug") + type NestedConfig struct { + Value string `flag:"name=nested_value"` + } type Config struct { - Broker string `flag:"name=broker,desc='description with, comma',default='127.0.0.1:9092'"` - Verbose bool `flag:"name=verbose,desc='verbose output',default='false'"` - Addr []string `flag:"name=addr,desc='addrs',default='127.0.0.1:9092'"` - Wait time.Duration `flag:"name=wait,desc='wait time',default='2s'"` - Time time.Time `flag:"name=time,desc='some time',default='02 Jan 06 15:04 MST'"` - Metadata map[string]int `flag:"name=metadata,desc='some meta',default=''"` - WithoutDefault string `flag:"name=without_default,desc='with'"` - WithoutDesc string `flag:"name=without_desc,default='without_default'"` - WithoutAll string `flag:"name=without_all"` + Broker string `flag:"name=broker,desc='description with, comma',default='127.0.0.1:9092'"` + Verbose bool `flag:"name=verbose,desc='verbose output',default='false'"` + Addr []string `flag:"name=addr,desc='addrs',default='127.0.0.1:9092'"` + Wait time.Duration `flag:"name=wait,desc='wait time',default='2s'"` + Time time.Time `flag:"name=time,desc='some time',default='02 Jan 06 15:04 MST'"` + Metadata map[string]int `flag:"name=metadata,desc='some meta',default=''"` + WithoutDefault string `flag:"name=without_default,desc='with'"` + WithoutDesc string `flag:"name=without_desc,default='without_default'"` + WithoutAll string `flag:"name=without_all"` + Components map[string]string `flag:"name=components,desc='components logging'"` + Nested *NestedConfig } ctx := context.Background() - cfg := &Config{} + cfg := &Config{Nested: &NestedConfig{}} c := NewConfig(config.Struct(cfg), TimeFormat(time.RFC822)) if err := c.Init(); err != nil { @@ -47,4 +53,8 @@ func TestLoad(t *testing.T) { if tf := cfg.Time.Format(time.RFC822); tf != "02 Jan 06 15:04 MST" { t.Fatalf("parse time error: %s != %s", tf, "02 Jan 06 15:04 MST") } + + if len(cfg.Components) != 2 { + t.Fatalf("cant parse map components %#+v", cfg) + } } diff --git a/util.go b/util.go index 413eafa..1ff1bf7 100644 --- a/util.go +++ b/util.go @@ -2,12 +2,131 @@ package flag import ( "flag" + "fmt" "reflect" "strconv" "strings" "time" ) +type mapValue struct { + v reflect.Value + delim string + def string +} + +func (v mapValue) String() string { + if v.v.Kind() != reflect.Invalid { + var kv []string + it := v.v.MapRange() + for it.Next() { + k := it.Key().Interface() + v := it.Value().Interface() + kv = append(kv, fmt.Sprintf("%v=%v", k, v)) + } + return strings.Join(kv, ",") + } + return v.def +} + +func (v mapValue) Set(s string) error { + ps := strings.Split(s, v.delim) + if len(ps) == 0 { + return nil + } + v.v.Set(reflect.MakeMapWithSize(v.v.Type(), len(ps))) + kt := v.v.Type().Key().Kind() + vt := v.v.Type().Elem().Kind() + + for i := 0; i < len(ps); i++ { + fs := strings.Split(ps[i], "=") + switch len(fs) { + case 0: + return nil + case 1: + if len(fs[0]) == 0 { + return nil + } + return ErrInvalidValue + case 2: + break + default: + return ErrInvalidValue + } + key, err := convertType(reflect.ValueOf(fs[0]), kt) + if err != nil { + return err + } + val, err := convertType(reflect.ValueOf(fs[1]), vt) + if err != nil { + return err + } + v.v.SetMapIndex(key.Convert(v.v.Type().Key()), val.Convert(v.v.Type().Elem())) + } + return nil +} + +type sliceValue struct { + v reflect.Value + delim string + def string +} + +func (v sliceValue) String() string { + if v.v.Kind() != reflect.Invalid { + var kv []string + for idx := 0; idx < v.v.Len(); idx++ { + kv = append(kv, fmt.Sprintf("%v", v.v.Index(idx).Interface())) + } + return strings.Join(kv, ",") + } + return v.def +} + +func (v sliceValue) Set(s string) error { + p := strings.Split(s, v.delim) + v.v.Set(reflect.MakeSlice(v.v.Type(), len(p), len(p))) + switch v.v.Type().Elem().Kind() { + case reflect.Int, reflect.Int64: + for idx := range p { + i, err := strconv.ParseInt(p[idx], 10, 64) + if err != nil { + return err + } + v.v.Index(idx).SetInt(i) + } + case reflect.Uint, reflect.Uint64: + for idx := range p { + i, err := strconv.ParseUint(p[idx], 10, 64) + if err != nil { + return err + } + v.v.Index(idx).SetUint(i) + } + case reflect.Float64: + for idx := range p { + i, err := strconv.ParseFloat(p[idx], 64) + if err != nil { + return err + } + v.v.Index(idx).SetFloat(i) + } + case reflect.Bool: + for idx := range p { + i, err := strconv.ParseBool(p[idx]) + if err != nil { + return err + } + v.v.Index(idx).SetBool(i) + } + case reflect.String: + for idx := range p { + v.v.Index(idx).SetString(p[idx]) + } + } + return nil +} + func convertType(v reflect.Value, t reflect.Kind) (reflect.Value, error) { switch v.Kind() { case reflect.String: @@ -51,50 +170,12 @@ func (c *flagConfig) flagSlice(v reflect.Value, fn, fv, fd string) error { } } - flag.Func(fn, fd, func(s string) error { - p := strings.Split(s, delim) - v.Set(reflect.MakeSlice(v.Type(), len(p), len(p))) - switch v.Type().Elem().Kind() { - case reflect.Int, reflect.Int64: - for idx := range p { - i, err := strconv.ParseInt(p[idx], 10, 64) - if err != nil { - return err - } - v.Index(idx).SetInt(i) - } - case reflect.Uint, reflect.Uint64: - for idx := range p { - i, err := strconv.ParseUint(p[idx], 10, 64) - if err != nil { - return err - } - v.Index(idx).SetUint(i) - } - case reflect.Float64: - for idx := range p { - i, err := strconv.ParseFloat(p[idx], 64) - if err != nil { - return err - } - v.Index(idx).SetFloat(i) - } - case reflect.Bool: - for idx := range p { - i, err := strconv.ParseBool(p[idx]) - if err != nil { - return err - } - v.Index(idx).SetBool(i) - } - case reflect.String: - for idx := range p { - v.Index(idx).SetString(p[idx]) - } - } - return nil - }) - + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + mp := &sliceValue{v: v, def: fv, delim: delim} + if err := mp.Set(fv); err != nil { + return err + } + flag.Var(mp, fn, fd) return nil } @@ -105,43 +186,12 @@ func (c *flagConfig) flagMap(v reflect.Value, fn, fv, fd string) error { delim = d } } - flag.Func(fn, fv, func(s string) error { - ps := strings.Split(s, delim) - if len(ps) == 0 { - return nil - } - v.Set(reflect.MakeMapWithSize(v.Type(), len(ps))) - kt := v.Type().Key().Kind() - vt := v.Type().Elem().Kind() - - for i := 0; i < len(ps); i++ { - fs := strings.Split(ps[i], "=") - switch len(fs) { - case 0: - return nil - case 1: - if len(fs[0]) == 0 { - return nil - } - return ErrInvalidValue - case 2: - break - default: - return ErrInvalidValue - } - key, err := convertType(reflect.ValueOf(fs[0]), kt) - if err != nil { - return err - } - val, err := convertType(reflect.ValueOf(fs[1]), vt) - if err != nil { - return err - } - v.SetMapIndex(key.Convert(v.Type().Key()), val.Convert(v.Type().Elem())) - } - return nil - }) - + v.Set(reflect.MakeMapWithSize(v.Type(), 0)) + mp := &mapValue{v: v, def: fv, delim: delim} + if err := mp.Set(fv); err != nil { + return err + } + flag.Var(mp, fn, fd) return nil } @@ -277,35 +327,74 @@ func (c *flagConfig) flagStringSlice(v reflect.Value, fn, fv, fd string) error { } func getFlagOpts(tf string) (string, string, string) { - ret := make([]string, 3) - vals := strings.Split(tf, ",") - f := 0 - for _, val := range vals { - p := strings.Split(val, "=") - switch p[0] { - case "name": - f = 0 - case "desc": - f = 1 - case "default": - f = 2 - default: - ret[f] += "," + val - continue - } - ret[f] = p[1] - } + var name, desc, def string + delim := "," - for idx := range ret { - if len(ret[idx]) == 0 { - continue - } - if ret[idx][0] == '\'' { - ret[idx] = ret[idx][1:] - } - if ret[idx][len(ret[idx])-1] == '\'' { - ret[idx] = ret[idx][:len(ret[idx])-1] + var buf string + for idx := 0; idx < len(tf); idx++ { + buf += string(tf[idx]) + switch buf { + case "name": + ndx := idx + 2 + stop := "," + var quote bool + for ; ndx < len(tf); ndx++ { + if string(tf[ndx]) == stop { + if quote { + ndx++ + } + break + } + if string(tf[ndx]) == "'" { + stop = "'" + quote = true + continue + } + name += string(tf[ndx]) + } + idx = ndx + buf = "" + case "desc": + ndx := idx + 2 + stop := "," + var quote bool + for ; ndx < len(tf); ndx++ { + if string(tf[ndx]) == stop { + if quote { + ndx++ + } + break + } + if string(tf[ndx]) == "'" { + stop = "'" + quote = true + continue + } + desc += string(tf[ndx]) + } + idx = ndx + buf = "" + case "default": + ndx := idx + 2 + stop := "," + var quote bool + for ; ndx < len(tf); ndx++ { + if string(tf[ndx]) == stop && (stop != delim) { + if quote { + ndx++ + } + break + } + if string(tf[ndx]) == "'" { + stop = "'" + quote = true + continue + } + def += string(tf[ndx]) + } + idx = ndx + buf = "" } } - return ret[0], ret[1], ret[2] + return name, desc, def }