diff --git a/coreos-cloudinit.go b/coreos-cloudinit.go index 909e8b8..d82a72d 100644 --- a/coreos-cloudinit.go +++ b/coreos-cloudinit.go @@ -15,8 +15,11 @@ package main import ( + "bytes" + "compress/gzip" "flag" "fmt" + "io/ioutil" "log" "os" "runtime" @@ -185,6 +188,11 @@ func main() { log.Printf("Failed fetching user-data from datasource: %v. Continuing...\n", err) failure = true } + userdataBytes, err = decompressIfGzip(userdataBytes) + if err != nil { + log.Printf("Failed decompressing user-data from datasource: %v. Continuing...\n", err) + failure = true + } if report, err := validate.Validate(userdataBytes); err == nil { ret := 0 @@ -399,3 +407,17 @@ func runScript(script config.Script, env *initialize.Environment) error { } return err } + +const gzipMagicBytes = "\x1f\x8b" + +func decompressIfGzip(userdataBytes []byte) ([]byte, error) { + if !bytes.HasPrefix(userdataBytes, []byte(gzipMagicBytes)) { + return userdataBytes, nil + } + gzr, err := gzip.NewReader(bytes.NewReader(userdataBytes)) + if err != nil { + return nil, err + } + defer gzr.Close() + return ioutil.ReadAll(gzr) +} diff --git a/coreos-cloudinit_test.go b/coreos-cloudinit_test.go index cd87d5f..cd5c6da 100644 --- a/coreos-cloudinit_test.go +++ b/coreos-cloudinit_test.go @@ -15,6 +15,9 @@ package main import ( + "bytes" + "encoding/base64" + "errors" "reflect" "testing" @@ -87,3 +90,58 @@ func TestMergeConfigs(t *testing.T) { } } } + +func mustDecode(in string) []byte { + out, err := base64.StdEncoding.DecodeString(in) + if err != nil { + panic(err) + } + return out +} + +func TestDecompressIfGzip(t *testing.T) { + tests := []struct { + in []byte + + out []byte + err error + }{ + { + in: nil, + + out: nil, + err: nil, + }, + { + in: []byte{}, + + out: []byte{}, + err: nil, + }, + { + in: mustDecode("H4sIAJWV/VUAA1NOzskvTdFNzs9Ly0wHABt6mQENAAAA"), + + out: []byte("#cloud-config"), + err: nil, + }, + { + in: []byte("#cloud-config"), + + out: []byte("#cloud-config"), + err: nil, + }, + { + in: mustDecode("H4sCORRUPT=="), + + out: nil, + err: errors.New("any error"), + }, + } + for i, tt := range tests { + out, err := decompressIfGzip(tt.in) + if !bytes.Equal(out, tt.out) || (tt.err != nil && err == nil) { + t.Errorf("bad gzip (%d): want (%s, %#v), got (%s, %#v)", i, string(tt.out), tt.err, string(out), err) + } + } + +} diff --git a/test b/test index 95daad5..645e967 100755 --- a/test +++ b/test @@ -21,6 +21,7 @@ SRC=" network pkg system + . " echo "Checking gofix..."