diff --git a/Documentation/vmware-backdoor.md b/Documentation/vmware-backdoor.md index 4994238..46fd68b 100644 --- a/Documentation/vmware-backdoor.md +++ b/Documentation/vmware-backdoor.md @@ -19,6 +19,7 @@ are supported by coreos-cloudinit: | `dns.server.` | `IP address` | | `coreos.config.data` | `string` | | `coreos.config.data.encoding` | `{"", "base64", "gzip+base64"}` | +| `coreos.config.url` | `URL` | Note: "n", "m", "l", and "x" are 0-indexed, incrementing integers. The identifier for the interfaces does not correspond to anything outside of this diff --git a/datasource/vmware/vmware.go b/datasource/vmware/vmware.go index 9eb7c05..4fb382a 100644 --- a/datasource/vmware/vmware.go +++ b/datasource/vmware/vmware.go @@ -21,17 +21,25 @@ import ( "github.com/coreos/coreos-cloudinit/config" "github.com/coreos/coreos-cloudinit/datasource" + "github.com/coreos/coreos-cloudinit/pkg" "github.com/coreos/coreos-cloudinit/Godeps/_workspace/src/github.com/sigma/vmw-guestinfo/rpcvmx" "github.com/coreos/coreos-cloudinit/Godeps/_workspace/src/github.com/sigma/vmw-guestinfo/vmcheck" ) +type readConfigFunction func(key string) (string, error) +type urlDownloadFunction func(url string) ([]byte, error) + type vmware struct { - readConfig func(key string) (string, error) + readConfig readConfigFunction + urlDownload urlDownloadFunction } func NewDatasource() *vmware { - return &vmware{readConfig} + return &vmware{ + readConfig: readConfig, + urlDownload: urlDownload, + } } func (v vmware) IsAvailable() bool { @@ -133,6 +141,22 @@ func (v vmware) FetchUserdata() ([]byte, error) { return nil, err } + // Try to fallback to url if no explicit data + if data == "" { + url, err := v.readConfig("coreos.config.url") + if err != nil { + return nil, err + } + + if url != "" { + rawData, err := v.urlDownload(url) + if err != nil { + return nil, err + } + data = string(rawData) + } + } + if encoding != "" { return config.DecodeContent(data, encoding) } @@ -143,6 +167,11 @@ func (v vmware) Type() string { return "vmware" } +func urlDownload(url string) ([]byte, error) { + client := pkg.NewHttpClient() + return client.GetRetry(url) +} + func readConfig(key string) (string, error) { data, err := rpcvmx.NewConfig().GetString(key, "") if err == nil { diff --git a/datasource/vmware/vmware_test.go b/datasource/vmware/vmware_test.go index 8ef5399..19cdd31 100644 --- a/datasource/vmware/vmware_test.go +++ b/datasource/vmware/vmware_test.go @@ -115,7 +115,7 @@ func TestFetchMetadata(t *testing.T) { } for i, tt := range tests { - v := vmware{tt.variables.ReadConfig} + v := vmware{readConfig: tt.variables.ReadConfig} metadata, err := v.FetchMetadata() if !reflect.DeepEqual(tt.err, err) { t.Errorf("bad error (#%d): want %v, got %v", i, tt.err, err) @@ -165,10 +165,37 @@ func TestFetchUserdata(t *testing.T) { }, err: errors.New(`Unsupported encoding "test encoding"`), }, + { + variables: map[string]string{ + "coreos.config.url": "http://good.example.com", + }, + userdata: "test config", + }, + { + variables: map[string]string{ + "coreos.config.url": "http://bad.example.com", + }, + err: errors.New("Not found"), + }, + } + + var downloader urlDownloadFunction = func(url string) ([]byte, error) { + mapping := map[string]struct { + data []byte + err error + }{ + "http://good.example.com": {[]byte("test config"), nil}, + "http://bad.example.com": {nil, errors.New("Not found")}, + } + val := mapping[url] + return val.data, val.err } for i, tt := range tests { - v := vmware{tt.variables.ReadConfig} + v := vmware{ + readConfig: tt.variables.ReadConfig, + urlDownload: downloader, + } userdata, err := v.FetchUserdata() if !reflect.DeepEqual(tt.err, err) { t.Errorf("bad error (#%d): want %v, got %v", i, nil, err) @@ -181,7 +208,7 @@ func TestFetchUserdata(t *testing.T) { func TestFetchUserdataError(t *testing.T) { testErr := errors.New("test error") - _, err := vmware{func(_ string) (string, error) { return "", testErr }}.FetchUserdata() + _, err := vmware{readConfig: func(_ string) (string, error) { return "", testErr }}.FetchUserdata() if testErr != err { t.Errorf("bad error: want %v, got %v", testErr, err)