diff --git a/flag.go b/flag.go index bfacf04..9e37f79 100644 --- a/flag.go +++ b/flag.go @@ -28,6 +28,7 @@ var ( */ type flagConfig struct { + fset *flag.FlagSet opts config.Options } @@ -45,7 +46,7 @@ func (c *flagConfig) Init(opts ...config.Option) error { return err } - flag.CommandLine.Init(os.Args[0], flag.ContinueOnError) + // flag.CommandLine.Init(os.Args[0], flag.ContinueOnError) for _, sf := range fields { tf, ok := sf.Field.Tag.Lookup(c.opts.StructTag) if !ok { @@ -65,6 +66,8 @@ func (c *flagConfig) Init(opts ...config.Option) error { if f := flag.Lookup(fn); f != nil { return nil } + + fmt.Printf("register %s flag\n", fn) switch vi.(type) { case time.Duration: err = c.flagDuration(sf.Value, fn, fv, fd) @@ -151,5 +154,37 @@ func NewConfig(opts ...config.Option) config.Config { if len(options.StructTag) == 0 { options.StructTag = DefaultStructTag } - return &flagConfig{opts: options} + flagSet := flag.CommandLine + flagSetName := os.Args[0] + flagSetErrorHandling := flag.ExitOnError + var flagUsage func() + var isSet bool + + if options.Context != nil { + if v, ok := options.Context.Value(flagSetNameKey{}).(string); ok { + isSet = true + flagSetName = v + } + if v, ok := options.Context.Value(flagSetErrorHandlingKey{}).(flag.ErrorHandling); ok { + isSet = true + flagSetErrorHandling = v + } + if v, ok := options.Context.Value(flagSetKey{}).(*flag.FlagSet); ok { + flagSet = v + } + if v, ok := options.Context.Value(flagSetUsageKey{}).(func()); ok { + flagUsage = v + } + } + + if isSet { + flagSet.Init(flagSetName, flagSetErrorHandling) + } + if flagUsage != nil { + flagSet.Usage = flagUsage + } + + c := &flagConfig{opts: options, fset: flagSet} + + return c } diff --git a/flag_test.go b/flag_test.go index 94b9eb2..bc12bcb 100644 --- a/flag_test.go +++ b/flag_test.go @@ -2,6 +2,7 @@ package flag import ( "context" + "flag" "os" "testing" "time" @@ -40,7 +41,7 @@ func TestLoad(t *testing.T) { ctx := context.Background() cfg := &Config{Nested: &NestedConfig{}} - c := NewConfig(config.Struct(cfg), TimeFormat(time.RFC822)) + c := NewConfig(config.Struct(cfg), TimeFormat(time.RFC822), FlagErrorHandling(flag.ContinueOnError)) if err := c.Init(); err != nil { t.Fatalf("init failed: %v", err) } diff --git a/options.go b/options.go index ae904dc..02a84e3 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,8 @@ package flag import ( + "flag" + "go.unistack.org/micro/v3/config" ) @@ -24,3 +26,31 @@ type timeFormatKey struct{} func TimeFormat(s string) config.Option { return config.SetOption(timeFormatKey{}, s) } + +type flagSetKey struct{} + +// FlagSet set flag set name +func FlagSet(f *flag.FlagSet) config.Option { + return config.SetOption(flagSetKey{}, f) +} + +type flagSetNameKey struct{} + +// FlagSetName set flag set name +func FlagSetName(n string) config.Option { + return config.SetOption(flagSetNameKey{}, n) +} + +type flagSetErrorHandlingKey struct{} + +// FlagErrorHandling set flag set error handling +func FlagErrorHandling(eh flag.ErrorHandling) config.Option { + return config.SetOption(flagSetErrorHandlingKey{}, eh) +} + +type flagSetUsageKey struct{} + +// FlagUsage set flag set usage func +func FlagUsage(fn func()) config.Option { + return config.SetOption(flagSetUsageKey{}, fn) +}