diff --git a/datasource/datasource.go b/datasource/datasource.go index f6175a9..42d9551 100644 --- a/datasource/datasource.go +++ b/datasource/datasource.go @@ -2,26 +2,26 @@ package datasource import ( "io/ioutil" + "log" + "math" "net/http" + "time" ) +const maxTimeout = time.Second * 5 + type Datasource interface { Fetch() ([]byte, error) - Type() string + Type() string } func fetchURL(url string) ([]byte, error) { - client := http.Client{} - resp, err := client.Get(url) + resp, err := getWithExponentialBackoff(url) if err != nil { return []byte{}, err } defer resp.Body.Close() - if resp.StatusCode / 100 != 2 { - return []byte{}, nil - } - respBytes, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err @@ -29,3 +29,25 @@ func fetchURL(url string) ([]byte, error) { return respBytes, nil } + +// getWithExponentialBackoff issues a GET to the specified URL. If the +// response is a non-2xx or produces an error, retry the GET forever using +// an exponential backoff. +func getWithExponentialBackoff(url string) (*http.Response, error) { + var err error + var resp *http.Response + for i := 0; ; i++ { + resp, err = http.Get(url) + if err == nil && resp.StatusCode/100 == 2 { + return resp, nil + } + duration := time.Millisecond * time.Duration((math.Pow(float64(2), float64(i)) * 100)) + if duration > maxTimeout { + duration = maxTimeout + } + + log.Printf("unable to fetch user-data from %s, try again in %s", url, duration) + time.Sleep(duration) + } + return resp, err +} diff --git a/datasource/datasource_test.go b/datasource/datasource_test.go new file mode 100644 index 0000000..19ccd53 --- /dev/null +++ b/datasource/datasource_test.go @@ -0,0 +1,45 @@ +package datasource + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +var expBackoffTests = []struct { + count int + body string +}{ + {0, "number of attempts: 0"}, + {1, "number of attempts: 1"}, + {2, "number of attempts: 2"}, +} + +func TestGetWithExponentialBackoff(t *testing.T) { + for i, tt := range expBackoffTests { + mux := http.NewServeMux() + count := 0 + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if count == tt.count { + io.WriteString(w, fmt.Sprintf("number of attempts: %d", count)) + return + } + count++ + http.Error(w, "", 500) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + data, err := fetchURL(ts.URL) + if err != nil { + t.Errorf("Test case %d produced error: %v", i, err) + } + if count != tt.count { + t.Errorf("Test case %d failed: %d != %d", i, count, tt.count) + } + if string(data) != tt.body { + t.Errorf("Test case %d failed: %s != %s", i, tt.body, data) + } + } +}