diff --git a/util/time/duration.go b/util/time/duration.go index 36ba97a1..49550e5e 100644 --- a/util/time/duration.go +++ b/util/time/duration.go @@ -6,6 +6,8 @@ import ( "fmt" "strconv" "time" + + "gopkg.in/yaml.v3" ) type Duration int64 @@ -53,6 +55,31 @@ loop: return time.ParseDuration(fmt.Sprintf("%dh%s", hours, s[p:])) } +func (d Duration) MarshalYAML() (interface{}, error) { + return time.Duration(d).String(), nil +} + +func (d *Duration) UnmarshalYAML(n *yaml.Node) error { + var v interface{} + if err := yaml.Unmarshal([]byte(n.Value), &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + *d = Duration(time.Duration(value)) + return nil + case string: + dv, err := ParseDuration(value) + if err != nil { + return err + } + *d = Duration(dv) + return nil + default: + return fmt.Errorf("invalid duration") + } +} + func (d Duration) MarshalJSON() ([]byte, error) { return json.Marshal(time.Duration(d).String()) } diff --git a/util/time/duration_test.go b/util/time/duration_test.go index 12447510..0324e68a 100644 --- a/util/time/duration_test.go +++ b/util/time/duration_test.go @@ -5,8 +5,44 @@ import ( "encoding/json" "testing" "time" + + "gopkg.in/yaml.v3" ) +func TestMarshalYAML(t *testing.T) { + d := Duration(10000000) + buf, err := yaml.Marshal(d) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, []byte(`10ms +`)) { + t.Fatalf("invalid duration: %s != %s", buf, `10ms`) + } +} + +func TestUnmarshalYAML(t *testing.T) { + type str struct { + TTL Duration `yaml:"ttl"` + } + v := &str{} + var err error + + err = yaml.Unmarshal([]byte(`{"ttl":"10ms"}`), v) + if err != nil { + t.Fatal(err) + } else if v.TTL != 10000000 { + t.Fatalf("invalid duration %v != 10000000", v.TTL) + } + + err = yaml.Unmarshal([]byte(`{"ttl":"1y"}`), v) + if err != nil { + t.Fatal(err) + } else if v.TTL != 31622400000000000 { + t.Fatalf("invalid duration %v != 31622400000000000", v.TTL) + } +} + func TestMarshalJSON(t *testing.T) { d := Duration(10000000) buf, err := json.Marshal(d)