Don't recall vals everywhere

This commit is contained in:
Asim Aslam 2019-10-23 22:51:08 +01:00
parent fb3d729681
commit 3ce71e12ff

View File

@ -25,6 +25,13 @@ const apiBaseURL = "https://api.cloudflare.com/client/v4/"
type workersKV struct { type workersKV struct {
options.Options options.Options
// cf account id
account string
// cf api token
token string
// cf kv namespace
namespace string
// http client to use
httpClient *http.Client httpClient *http.Client
} }
@ -35,32 +42,37 @@ type workersKV struct {
func NewStore(opts ...options.Option) (store.Store, error) { func NewStore(opts ...options.Option) (store.Store, error) {
// Validate Options // Validate Options
options := options.NewOptions(opts...) options := options.NewOptions(opts...)
var account, token, namespace string
apiToken, ok := options.Values().Get("CF_API_TOKEN") apiToken, ok := options.Values().Get("CF_API_TOKEN")
if !ok { if !ok {
log.Fatal("Store: No CF_API_TOKEN passed as an option") log.Fatal("Store: No CF_API_TOKEN passed as an option")
} }
_, ok = apiToken.(string) if token, ok = apiToken.(string); !ok {
if !ok {
log.Fatal("Store: Option CF_API_TOKEN contains a non-string") log.Fatal("Store: Option CF_API_TOKEN contains a non-string")
} }
accountID, ok := options.Values().Get("CF_ACCOUNT_ID") accountID, ok := options.Values().Get("CF_ACCOUNT_ID")
if !ok { if !ok {
log.Fatal("Store: No CF_ACCOUNT_ID passed as an option") log.Fatal("Store: No CF_ACCOUNT_ID passed as an option")
} }
_, ok = accountID.(string) if account, ok = accountID.(string); !ok {
if !ok {
log.Fatal("Store: Option CF_ACCOUNT_ID contains a non-string") log.Fatal("Store: Option CF_ACCOUNT_ID contains a non-string")
} }
uuid, ok := options.Values().Get("KV_NAMESPACE_ID") uuid, ok := options.Values().Get("KV_NAMESPACE_ID")
if !ok { if !ok {
log.Fatal("Store: No KV_NAMESPACE_ID passed as an option") log.Fatal("Store: No KV_NAMESPACE_ID passed as an option")
} }
_, ok = uuid.(string) if namespace, ok = uuid.(string); !ok {
if !ok {
log.Fatal("Store: Option KV_NAMESPACE_ID contains a non-string") log.Fatal("Store: Option KV_NAMESPACE_ID contains a non-string")
} }
return &workersKV{ return &workersKV{
account: account,
namespace: namespace,
token: token,
Options: options, Options: options,
httpClient: &http.Client{}, httpClient: &http.Client{},
}, nil }, nil
@ -72,18 +84,18 @@ func (w *workersKV) List() ([]*store.Record, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
accountID, _ := w.Options.Values().Get("CF_ACCOUNT_ID") path := fmt.Sprintf("accounts/%s/storage/kv/namespaces/%s/keys", w.account, w.namespace)
kvID, _ := w.Options.Values().Get("KV_NAMESPACE_ID")
path := fmt.Sprintf("accounts/%s/storage/kv/namespaces/%s/keys", accountID.(string), kvID.(string))
response, _, _, err := w.request(ctx, http.MethodGet, path, nil, make(http.Header)) response, _, _, err := w.request(ctx, http.MethodGet, path, nil, make(http.Header))
if err != nil { if err != nil {
return nil, err return nil, err
} }
a := &apiResponse{} a := &apiResponse{}
if err := json.Unmarshal(response, a); err != nil { if err := json.Unmarshal(response, a); err != nil {
return nil, err return nil, err
} }
if !a.Success { if !a.Success {
messages := "" messages := ""
for _, m := range a.Errors { for _, m := range a.Errors {
@ -93,9 +105,11 @@ func (w *workersKV) List() ([]*store.Record, error) {
} }
var keys []string var keys []string
for _, r := range a.Result { for _, r := range a.Result {
keys = append(keys, r.Name) keys = append(keys, r.Name)
} }
return w.Read(keys...) return w.Read(keys...)
} }
@ -103,12 +117,10 @@ func (w *workersKV) Read(keys ...string) ([]*store.Record, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
accountID, _ := w.Options.Values().Get("CF_ACCOUNT_ID")
kvID, _ := w.Options.Values().Get("KV_NAMESPACE_ID")
var records []*store.Record var records []*store.Record
for _, k := range keys { for _, k := range keys {
path := fmt.Sprintf("accounts/%s/storage/kv/namespaces/%s/values/%s", accountID.(string), kvID.(string), url.PathEscape(k)) path := fmt.Sprintf("accounts/%s/storage/kv/namespaces/%s/values/%s", w.account, w.namespace, url.PathEscape(k))
response, headers, status, err := w.request(ctx, http.MethodGet, path, nil, make(http.Header)) response, headers, status, err := w.request(ctx, http.MethodGet, path, nil, make(http.Header))
if err != nil { if err != nil {
return records, err return records, err
@ -129,6 +141,7 @@ func (w *workersKV) Read(keys ...string) ([]*store.Record, error) {
} }
records = append(records, record) records = append(records, record)
} }
return records, nil return records, nil
} }
@ -136,25 +149,26 @@ func (w *workersKV) Write(records ...*store.Record) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
accountID, _ := w.Options.Values().Get("CF_ACCOUNT_ID")
kvID, _ := w.Options.Values().Get("KV_NAMESPACE_ID")
for _, r := range records { for _, r := range records {
path := fmt.Sprintf("accounts/%s/storage/kv/namespaces/%s/values/%s", accountID.(string), kvID.(string), url.PathEscape(r.Key)) path := fmt.Sprintf("accounts/%s/storage/kv/namespaces/%s/values/%s", w.account, w.namespace, url.PathEscape(r.Key))
if r.Expiry != 0 { if r.Expiry != 0 {
// Minimum cloudflare TTL is 60 Seconds // Minimum cloudflare TTL is 60 Seconds
exp := int(math.Max(60, math.Round(r.Expiry.Seconds()))) exp := int(math.Max(60, math.Round(r.Expiry.Seconds())))
path = path + "?expiration_ttl=" + strconv.Itoa(exp) path = path + "?expiration_ttl=" + strconv.Itoa(exp)
} }
headers := make(http.Header) headers := make(http.Header)
resp, _, _, err := w.request(ctx, http.MethodPut, path, r.Value, headers) resp, _, _, err := w.request(ctx, http.MethodPut, path, r.Value, headers)
if err != nil { if err != nil {
return err return err
} }
a := &apiResponse{} a := &apiResponse{}
if err := json.Unmarshal(resp, a); err != nil { if err := json.Unmarshal(resp, a); err != nil {
return err return err
} }
if !a.Success { if !a.Success {
messages := "" messages := ""
for _, m := range a.Errors { for _, m := range a.Errors {
@ -163,6 +177,7 @@ func (w *workersKV) Write(records ...*store.Record) error {
return errors.New(messages) return errors.New(messages)
} }
} }
return nil return nil
} }
@ -170,11 +185,8 @@ func (w *workersKV) Delete(keys ...string) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
accountID, _ := w.Options.Values().Get("CF_ACCOUNT_ID")
kvID, _ := w.Options.Values().Get("KV_NAMESPACE_ID")
for _, k := range keys { for _, k := range keys {
path := fmt.Sprintf("accounts/%s/storage/kv/namespaces/%s/values/%s", accountID.(string), kvID.(string), url.PathEscape(k)) path := fmt.Sprintf("accounts/%s/storage/kv/namespaces/%s/values/%s", w.account, w.namespace, url.PathEscape(k))
resp, _, _, err := w.request(ctx, http.MethodDelete, path, nil, make(http.Header)) resp, _, _, err := w.request(ctx, http.MethodDelete, path, nil, make(http.Header))
if err != nil { if err != nil {
return err return err
@ -184,6 +196,7 @@ func (w *workersKV) Delete(keys ...string) error {
if err := json.Unmarshal(resp, a); err != nil { if err := json.Unmarshal(resp, a); err != nil {
return err return err
} }
if !a.Success { if !a.Success {
messages := "" messages := ""
for _, m := range a.Errors { for _, m := range a.Errors {
@ -192,6 +205,7 @@ func (w *workersKV) Delete(keys ...string) error {
return errors.New(messages) return errors.New(messages)
} }
} }
return nil return nil
} }
@ -211,29 +225,39 @@ func (w *workersKV) request(ctx context.Context, method, path string, body inter
} else { } else {
jsonBody = nil jsonBody = nil
} }
var reqBody io.Reader var reqBody io.Reader
if jsonBody != nil { if jsonBody != nil {
reqBody = bytes.NewReader(jsonBody) reqBody = bytes.NewReader(jsonBody)
} }
req, err := http.NewRequestWithContext(ctx, method, apiBaseURL+path, reqBody) req, err := http.NewRequestWithContext(ctx, method, apiBaseURL+path, reqBody)
for key, value := range headers { for key, value := range headers {
req.Header[key] = value req.Header[key] = value
} }
if token, found := w.Options.Values().Get("CF_API_TOKEN"); found {
req.Header.Set("Authorization", "Bearer "+token.(string)) // set token if it exists
if len(w.token) > 0 {
req.Header.Set("Authorization", "Bearer "+w.token)
} }
// set the user agent to micro
req.Header.Set("User-Agent", "micro/1.0 (https://micro.mu)") req.Header.Set("User-Agent", "micro/1.0 (https://micro.mu)")
// Official cloudflare client does exponential backoff here // Official cloudflare client does exponential backoff here
// TODO: retry and use util/backoff
resp, err := w.httpClient.Do(req) resp, err := w.httpClient.Do(req)
if err != nil { if err != nil {
return nil, nil, 0, err return nil, nil, 0, err
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := ioutil.ReadAll(resp.Body) respBody, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return respBody, resp.Header, resp.StatusCode, err return respBody, resp.Header, resp.StatusCode, err
} }
return respBody, resp.Header, resp.StatusCode, nil return respBody, resp.Header, resp.StatusCode, nil
} }