userdata: gzip autodetection

look for the gzip magic number at the beginning of a data source, and,
if found, decompress the input before further processing.

Addresses coreos/bugs#741.
This commit is contained in:
Seth Pellegrino 2015-09-20 18:11:26 -07:00
parent 86909e5bcb
commit 778a47b957
2 changed files with 80 additions and 0 deletions

View File

@ -15,8 +15,11 @@
package main
import (
"bytes"
"compress/gzip"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"sync"
@ -178,6 +181,11 @@ func main() {
log.Printf("Failed fetching user-data from datasource: %v. Continuing...\n", err)
failure = true
}
userdataBytes, err = decompressIfGzip(userdataBytes)
if err != nil {
log.Printf("Failed decompressing user-data from datasource: %v. Continuing...\n", err)
failure = true
}
if report, err := validate.Validate(userdataBytes); err == nil {
ret := 0
@ -392,3 +400,17 @@ func runScript(script config.Script, env *initialize.Environment) error {
}
return err
}
const gzipMagicBytes = "\x1f\x8b"
func decompressIfGzip(userdataBytes []byte) ([]byte, error) {
if !bytes.HasPrefix(userdataBytes, []byte(gzipMagicBytes)) {
return userdataBytes, nil
}
gzr, err := gzip.NewReader(bytes.NewReader(userdataBytes))
if err != nil {
return nil, err
}
defer gzr.Close()
return ioutil.ReadAll(gzr)
}

View File

@ -15,6 +15,9 @@
package main
import (
"bytes"
"encoding/base64"
"errors"
"reflect"
"testing"
@ -87,3 +90,58 @@ func TestMergeConfigs(t *testing.T) {
}
}
}
func mustDecode(in string) []byte {
out, err := base64.StdEncoding.DecodeString(in)
if err != nil {
panic(err)
}
return out
}
func TestDecompressIfGzip(t *testing.T) {
tests := []struct {
in []byte
out []byte
err error
}{
{
in: nil,
out: nil,
err: nil,
},
{
in: []byte{},
out: []byte{},
err: nil,
},
{
in: mustDecode("H4sIAJWV/VUAA1NOzskvTdFNzs9Ly0wHABt6mQENAAAA"),
out: []byte("#cloud-config"),
err: nil,
},
{
in: []byte("#cloud-config"),
out: []byte("#cloud-config"),
err: nil,
},
{
in: mustDecode("H4sCORRUPT=="),
out: nil,
err: errors.New("any error"),
},
}
for i, tt := range tests {
out, err := decompressIfGzip(tt.in)
if !bytes.Equal(out, tt.out) || (tt.err != nil && err == nil) {
t.Errorf("bad gzip (%d): want (%s, %#v), got (%s, %#v)", i, string(tt.out), tt.err, string(out), err)
}
}
}