config: seperate AssertValid and AssertStructValid

Added an error structure to make it possible to get the specifics of the failure.
This commit is contained in:
Alex Crawford 2014-10-28 20:04:43 -07:00
parent 6e2db882e6
commit 88e8265cd6
4 changed files with 41 additions and 28 deletions

View File

@ -90,10 +90,20 @@ func IsZero(c interface{}) bool {
return isZero(reflect.ValueOf(c)) return isZero(reflect.ValueOf(c))
} }
// AssertValid checks the fields in the structure and makes sure that they type ErrorValid struct {
// contain valid values as specified by the 'valid' flag. Empty fields are Value string
Valid []string
Field string
}
func (e ErrorValid) Error() string {
return fmt.Sprintf("invalid value %q for option %q (valid options: %q)", e.Value, e.Field, e.Valid)
}
// AssertStructValid checks the fields in the structure and makes sure that
// they contain valid values as specified by the 'valid' flag. Empty fields are
// implicitly valid. // implicitly valid.
func AssertValid(c interface{}) error { func AssertStructValid(c interface{}) error {
ct := reflect.TypeOf(c) ct := reflect.TypeOf(c)
cv := reflect.ValueOf(c) cv := reflect.ValueOf(c)
for i := 0; i < ct.NumField(); i++ { for i := 0; i < ct.NumField(); i++ {
@ -102,15 +112,33 @@ func AssertValid(c interface{}) error {
continue continue
} }
valid := ft.Tag.Get("valid") if err := AssertValid(cv.Field(i), ft.Tag.Get("valid")); err != nil {
val := cv.Field(i) err.Field = ft.Name
if !isValid(val, valid) { return err
return fmt.Errorf("invalid value \"%v\" for option %q (valid options: %q)", val.Interface(), ft.Name, valid)
} }
} }
return nil return nil
} }
// AssertValid checks to make sure that the given value is in the list of
// valid values. Zero values are implicitly valid.
func AssertValid(value reflect.Value, valid string) *ErrorValid {
if valid == "" || isZero(value) {
return nil
}
vs := fmt.Sprintf("%v", value.Interface())
valids := strings.Split(valid, ",")
for _, valid := range valids {
if vs == valid {
return nil
}
}
return &ErrorValid{
Value: vs,
Valid: valids,
}
}
func isZero(v reflect.Value) bool { func isZero(v reflect.Value) bool {
switch v.Kind() { switch v.Kind() {
case reflect.Struct: case reflect.Struct:
@ -130,19 +158,6 @@ func isFieldExported(f reflect.StructField) bool {
return f.PkgPath == "" return f.PkgPath == ""
} }
func isValid(v reflect.Value, valid string) bool {
if valid == "" || isZero(v) {
return true
}
vs := fmt.Sprintf("%v", v.Interface())
for _, valid := range strings.Split(valid, ",") {
if vs == valid {
return true
}
}
return false
}
type warner func(format string, v ...interface{}) type warner func(format string, v ...interface{})
// warnOnUnrecognizedKeys parses the contents of a cloud-config file and calls // warnOnUnrecognizedKeys parses the contents of a cloud-config file and calls

View File

@ -17,7 +17,6 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -43,7 +42,7 @@ func TestIsZero(t *testing.T) {
} }
} }
func TestAssertValid(t *testing.T) { func TestAssertStructValid(t *testing.T) {
for _, tt := range []struct { for _, tt := range []struct {
c interface{} c interface{}
err error err error
@ -60,7 +59,7 @@ func TestAssertValid(t *testing.T) {
}{A: "1", b: "hello"}, nil}, }{A: "1", b: "hello"}, nil},
{struct { {struct {
A, b string `valid:"1,2"` A, b string `valid:"1,2"`
}{A: "hello", b: "2"}, errors.New("invalid value \"hello\" for option \"A\" (valid options: \"1,2\")")}, }{A: "hello", b: "2"}, &ErrorValid{Value: "hello", Field: "A", Valid: []string{"1", "2"}}},
{struct { {struct {
A, b int `valid:"1,2"` A, b int `valid:"1,2"`
}{}, nil}, }{}, nil},
@ -72,9 +71,9 @@ func TestAssertValid(t *testing.T) {
}{A: 1, b: 9}, nil}, }{A: 1, b: 9}, nil},
{struct { {struct {
A, b int `valid:"1,2"` A, b int `valid:"1,2"`
}{A: 9, b: 2}, errors.New("invalid value \"9\" for option \"A\" (valid options: \"1,2\")")}, }{A: 9, b: 2}, &ErrorValid{Value: "9", Field: "A", Valid: []string{"1", "2"}}},
} { } {
if err := AssertValid(tt.c); !reflect.DeepEqual(tt.err, err) { if err := AssertStructValid(tt.c); !reflect.DeepEqual(tt.err, err) {
t.Errorf("bad result (%q): want %q, got %q", tt.c, tt.err, err) t.Errorf("bad result (%q): want %q, got %q", tt.c, tt.err, err)
} }
} }

View File

@ -61,7 +61,7 @@ func (uc Update) File() (*File, error) {
if config.IsZero(uc.Update) { if config.IsZero(uc.Update) {
return nil, nil return nil, nil
} }
if err := config.AssertValid(uc.Update); err != nil { if err := config.AssertStructValid(uc.Update); err != nil {
return nil, err return nil, err
} }

View File

@ -17,7 +17,6 @@
package system package system
import ( import (
"errors"
"io" "io"
"reflect" "reflect"
"strings" "strings"
@ -101,7 +100,7 @@ func TestUpdateFile(t *testing.T) {
}, },
{ {
config: config.Update{RebootStrategy: "wizzlewazzle"}, config: config.Update{RebootStrategy: "wizzlewazzle"},
err: errors.New("invalid value \"wizzlewazzle\" for option \"RebootStrategy\" (valid options: \"best-effort,etcd-lock,reboot,false\")"), err: &config.ErrorValid{Value: "wizzlewazzle", Field: "RebootStrategy", Valid: []string{"best-effort", "etcd-lock", "reboot", "false"}},
}, },
{ {
config: config.Update{Group: "master", Server: "http://foo.com"}, config: config.Update{Group: "master", Server: "http://foo.com"},