diff --git a/config/config.go b/config/config.go index d2b5a9a..6dedc2b 100644 --- a/config/config.go +++ b/config/config.go @@ -90,10 +90,20 @@ func IsZero(c interface{}) bool { return isZero(reflect.ValueOf(c)) } -// AssertValid checks the fields in the structure and makes sure that they -// contain valid values as specified by the 'valid' flag. Empty fields are +type ErrorValid struct { + Value string + Valid []string + Field string +} + +func (e ErrorValid) Error() string { + return fmt.Sprintf("invalid value %q for option %q (valid options: %q)", e.Value, e.Field, e.Valid) +} + +// AssertStructValid checks the fields in the structure and makes sure that +// they contain valid values as specified by the 'valid' flag. Empty fields are // implicitly valid. -func AssertValid(c interface{}) error { +func AssertStructValid(c interface{}) error { ct := reflect.TypeOf(c) cv := reflect.ValueOf(c) for i := 0; i < ct.NumField(); i++ { @@ -102,15 +112,33 @@ func AssertValid(c interface{}) error { continue } - valid := ft.Tag.Get("valid") - val := cv.Field(i) - if !isValid(val, valid) { - return fmt.Errorf("invalid value \"%v\" for option %q (valid options: %q)", val.Interface(), ft.Name, valid) + if err := AssertValid(cv.Field(i), ft.Tag.Get("valid")); err != nil { + err.Field = ft.Name + return err } } return nil } +// AssertValid checks to make sure that the given value is in the list of +// valid values. Zero values are implicitly valid. +func AssertValid(value reflect.Value, valid string) *ErrorValid { + if valid == "" || isZero(value) { + return nil + } + vs := fmt.Sprintf("%v", value.Interface()) + valids := strings.Split(valid, ",") + for _, valid := range valids { + if vs == valid { + return nil + } + } + return &ErrorValid{ + Value: vs, + Valid: valids, + } +} + func isZero(v reflect.Value) bool { switch v.Kind() { case reflect.Struct: @@ -130,19 +158,6 @@ func isFieldExported(f reflect.StructField) bool { return f.PkgPath == "" } -func isValid(v reflect.Value, valid string) bool { - if valid == "" || isZero(v) { - return true - } - vs := fmt.Sprintf("%v", v.Interface()) - for _, valid := range strings.Split(valid, ",") { - if vs == valid { - return true - } - } - return false -} - type warner func(format string, v ...interface{}) // warnOnUnrecognizedKeys parses the contents of a cloud-config file and calls diff --git a/config/config_test.go b/config/config_test.go index 1d776d6..ea68947 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -17,7 +17,6 @@ package config import ( - "errors" "fmt" "reflect" "strings" @@ -43,7 +42,7 @@ func TestIsZero(t *testing.T) { } } -func TestAssertValid(t *testing.T) { +func TestAssertStructValid(t *testing.T) { for _, tt := range []struct { c interface{} err error @@ -60,7 +59,7 @@ func TestAssertValid(t *testing.T) { }{A: "1", b: "hello"}, nil}, {struct { A, b string `valid:"1,2"` - }{A: "hello", b: "2"}, errors.New("invalid value \"hello\" for option \"A\" (valid options: \"1,2\")")}, + }{A: "hello", b: "2"}, &ErrorValid{Value: "hello", Field: "A", Valid: []string{"1", "2"}}}, {struct { A, b int `valid:"1,2"` }{}, nil}, @@ -72,9 +71,9 @@ func TestAssertValid(t *testing.T) { }{A: 1, b: 9}, nil}, {struct { A, b int `valid:"1,2"` - }{A: 9, b: 2}, errors.New("invalid value \"9\" for option \"A\" (valid options: \"1,2\")")}, + }{A: 9, b: 2}, &ErrorValid{Value: "9", Field: "A", Valid: []string{"1", "2"}}}, } { - if err := AssertValid(tt.c); !reflect.DeepEqual(tt.err, err) { + if err := AssertStructValid(tt.c); !reflect.DeepEqual(tt.err, err) { t.Errorf("bad result (%q): want %q, got %q", tt.c, tt.err, err) } } diff --git a/system/update.go b/system/update.go index 2038def..5d6b066 100644 --- a/system/update.go +++ b/system/update.go @@ -61,7 +61,7 @@ func (uc Update) File() (*File, error) { if config.IsZero(uc.Update) { return nil, nil } - if err := config.AssertValid(uc.Update); err != nil { + if err := config.AssertStructValid(uc.Update); err != nil { return nil, err } diff --git a/system/update_test.go b/system/update_test.go index a75348e..7b68172 100644 --- a/system/update_test.go +++ b/system/update_test.go @@ -17,7 +17,6 @@ package system import ( - "errors" "io" "reflect" "strings" @@ -101,7 +100,7 @@ func TestUpdateFile(t *testing.T) { }, { config: config.Update{RebootStrategy: "wizzlewazzle"}, - err: errors.New("invalid value \"wizzlewazzle\" for option \"RebootStrategy\" (valid options: \"best-effort,etcd-lock,reboot,false\")"), + err: &config.ErrorValid{Value: "wizzlewazzle", Field: "RebootStrategy", Valid: []string{"best-effort", "etcd-lock", "reboot", "false"}}, }, { config: config.Update{Group: "master", Server: "http://foo.com"},