diff --git a/datasource/metadata_service.go b/datasource/metadata_service.go index 06d2412..956d2b9 100644 --- a/datasource/metadata_service.go +++ b/datasource/metadata_service.go @@ -32,7 +32,7 @@ const ( type metadataService struct{} type getter interface { - Get(string) ([]byte, error) + GetRetry(string) ([]byte, error) } func NewMetadataService() *metadataService { @@ -49,12 +49,12 @@ func (ms *metadataService) FetchMetadata() ([]byte, error) { func (ms *metadataService) FetchUserdata() ([]byte, error) { client := pkg.NewHttpClient() - if data, err := client.Get(Ec2UserdataUrl); err == nil { + if data, err := client.GetRetry(Ec2UserdataUrl); err == nil { return data, err } else if _, ok := err.(pkg.ErrTimeout); ok { return data, err } - return client.Get(OpenstackUserdataUrl) + return client.GetRetry(OpenstackUserdataUrl) } func (ms *metadataService) Type() string { @@ -62,7 +62,7 @@ func (ms *metadataService) Type() string { } func fetchMetadata(client getter) ([]byte, error) { - if metadata, err := client.Get(OpenstackMetadataUrl); err == nil { + if metadata, err := client.GetRetry(OpenstackMetadataUrl); err == nil { return metadata, nil } else if _, ok := err.(pkg.ErrTimeout); ok { return nil, err @@ -76,7 +76,7 @@ func fetchMetadata(client getter) ([]byte, error) { } func fetchAttributes(client getter, url string) ([]string, error) { - resp, err := client.Get(url) + resp, err := client.GetRetry(url) if err != nil { return nil, err } diff --git a/datasource/metadata_service_test.go b/datasource/metadata_service_test.go index 104f97e..58ee298 100644 --- a/datasource/metadata_service_test.go +++ b/datasource/metadata_service_test.go @@ -14,7 +14,7 @@ type TestHttpClient struct { err error } -func (t *TestHttpClient) Get(url string) ([]byte, error) { +func (t *TestHttpClient) GetRetry(url string) ([]byte, error) { if t.err != nil { return nil, t.err } diff --git a/datasource/proc_cmdline.go b/datasource/proc_cmdline.go index ce181ab..b91ea7d 100644 --- a/datasource/proc_cmdline.go +++ b/datasource/proc_cmdline.go @@ -43,7 +43,7 @@ func (c *procCmdline) FetchUserdata() ([]byte, error) { } client := pkg.NewHttpClient() - cfg, err := client.Get(url) + cfg, err := client.GetRetry(url) if err != nil { return nil, err } diff --git a/datasource/url.go b/datasource/url.go index 9fcc788..80f6fb1 100644 --- a/datasource/url.go +++ b/datasource/url.go @@ -20,7 +20,7 @@ func (f *remoteFile) FetchMetadata() ([]byte, error) { func (f *remoteFile) FetchUserdata() ([]byte, error) { client := pkg.NewHttpClient() - return client.Get(f.url) + return client.GetRetry(f.url) } func (f *remoteFile) Type() string { diff --git a/initialize/ssh_keys.go b/initialize/ssh_keys.go index 67870d8..3caf039 100644 --- a/initialize/ssh_keys.go +++ b/initialize/ssh_keys.go @@ -25,7 +25,7 @@ func SSHImportKeysFromURL(system_user string, url string) error { func fetchUserKeys(url string) ([]string, error) { client := pkg.NewHttpClient() - data, err := client.Get(url) + data, err := client.GetRetry(url) if err != nil { return nil, err } diff --git a/pkg/http_client.go b/pkg/http_client.go index 5152fef..095f64a 100644 --- a/pkg/http_client.go +++ b/pkg/http_client.go @@ -20,15 +20,23 @@ const ( type Err error -type ErrTimeout struct{ +type ErrTimeout struct { Err } -type ErrNotFound struct{ +type ErrNotFound struct { Err } -type ErrInvalid struct{ +type ErrInvalid struct { + Err +} + +type ErrServer struct { + Err +} + +type ErrNetwork struct { Err } @@ -45,15 +53,39 @@ type HttpClient struct { // Whether or not to skip TLS verification. Defaults to false SkipTLS bool + + client *http.Client } func NewHttpClient() *HttpClient { - return &HttpClient{ + hc := &HttpClient{ MaxBackoff: time.Second * 5, MaxRetries: 15, Timeout: time.Duration(2) * time.Second, SkipTLS: false, } + + // We need to create our own client in order to add timeout support. + // TODO(c4milo) Replace it once Go 1.3 is officially used by CoreOS + // More info: https://code.google.com/p/go/source/detail?r=ada6f2d5f99f + hc.client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: hc.SkipTLS, + }, + Dial: func(network, addr string) (net.Conn, error) { + deadline := time.Now().Add(hc.Timeout) + c, err := net.DialTimeout(network, addr, hc.Timeout) + if err != nil { + return nil, err + } + c.SetDeadline(deadline) + return c, nil + }, + }, + } + + return hc } func expBackoff(interval, max time.Duration) time.Duration { @@ -64,8 +96,8 @@ func expBackoff(interval, max time.Duration) time.Duration { return interval } -// Fetches a given URL with support for exponential backoff and maximum retries -func (h *HttpClient) Get(rawurl string) ([]byte, error) { +// GetRetry fetches a given URL with support for exponential backoff and maximum retries +func (h *HttpClient) GetRetry(rawurl string) ([]byte, error) { if rawurl == "" { return nil, ErrInvalid{errors.New("URL is empty. Skipping.")} } @@ -83,49 +115,20 @@ func (h *HttpClient) Get(rawurl string) ([]byte, error) { dataURL := url.String() - // We need to create our own client in order to add timeout support. - // TODO(c4milo) Replace it once Go 1.3 is officially used by CoreOS - // More info: https://code.google.com/p/go/source/detail?r=ada6f2d5f99f - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: h.SkipTLS, - }, - Dial: func(network, addr string) (net.Conn, error) { - deadline := time.Now().Add(h.Timeout) - c, err := net.DialTimeout(network, addr, h.Timeout) - if err != nil { - return nil, err - } - c.SetDeadline(deadline) - return c, nil - }, - } - - client := &http.Client{ - Transport: transport, - } - duration := 50 * time.Millisecond for retry := 1; retry <= h.MaxRetries; retry++ { log.Printf("Fetching data from %s. Attempt #%d", dataURL, retry) - resp, err := client.Get(dataURL) - - if err == nil { - defer resp.Body.Close() - status := resp.StatusCode / 100 - - if status == HTTP_2xx { - return ioutil.ReadAll(resp.Body) - } - - if status == HTTP_4xx { - return nil, ErrNotFound{fmt.Errorf("Not found. HTTP status code: %d", resp.StatusCode)} - } - - log.Printf("Server error. HTTP status code: %d", resp.StatusCode) - } else { - log.Printf("Unable to fetch data: %s", err.Error()) + data, err := h.Get(dataURL) + switch err.(type) { + case ErrNetwork: + log.Printf(err.Error()) + case ErrServer: + log.Printf(err.Error()) + case ErrNotFound: + return data, err + default: + return data, err } duration = expBackoff(duration, h.MaxBackoff) @@ -135,3 +138,19 @@ func (h *HttpClient) Get(rawurl string) ([]byte, error) { return nil, ErrTimeout{fmt.Errorf("Unable to fetch data. Maximum retries reached: %d", h.MaxRetries)} } + +func (h *HttpClient) Get(dataURL string) ([]byte, error) { + if resp, err := h.client.Get(dataURL); err == nil { + defer resp.Body.Close() + switch resp.StatusCode / 100 { + case HTTP_2xx: + return ioutil.ReadAll(resp.Body) + case HTTP_4xx: + return nil, ErrNotFound{fmt.Errorf("Not found. HTTP status code: %d", resp.StatusCode)} + default: + return nil, ErrServer{fmt.Errorf("Server error. HTTP status code: %d", resp.StatusCode)} + } + } else { + return nil, ErrNetwork{fmt.Errorf("Unable to fetch data: %s", err.Error())} + } +} diff --git a/pkg/http_client_test.go b/pkg/http_client_test.go index 4c246a0..2af87b4 100644 --- a/pkg/http_client_test.go +++ b/pkg/http_client_test.go @@ -51,7 +51,7 @@ func TestGetURLExpBackOff(t *testing.T) { ts := httptest.NewServer(mux) defer ts.Close() - data, err := client.Get(ts.URL) + data, err := client.GetRetry(ts.URL) if err != nil { t.Errorf("Test case %d produced error: %v", i, err) } @@ -76,7 +76,7 @@ func TestGetURL4xx(t *testing.T) { })) defer ts.Close() - _, err := client.Get(ts.URL) + _, err := client.GetRetry(ts.URL) if err == nil { t.Errorf("Incorrect result\ngot: %s\nwant: %s", err.Error(), "Not found. HTTP status code: 404") } @@ -107,7 +107,7 @@ coreos: })) defer ts.Close() - data, err := client.Get(ts.URL) + data, err := client.GetRetry(ts.URL) if err != nil { t.Errorf("Incorrect result\ngot: %v\nwant: %v", err, nil) } @@ -132,7 +132,7 @@ func TestGetMalformedURL(t *testing.T) { } for _, test := range tests { - _, err := client.Get(test.url) + _, err := client.GetRetry(test.url) if err == nil || err.Error() != test.want { t.Errorf("Incorrect result\ngot: %v\nwant: %v", err, test.want) }