diff --git a/client/client.go b/client/client.go index 723af2e0..e7a8536d 100644 --- a/client/client.go +++ b/client/client.go @@ -68,6 +68,8 @@ var ( DefaultClient Client = newRpcClient() // DefaultBackoff is the default backoff function for retries DefaultBackoff = exponentialBackoff + // DefaultRetry is the default check-for-retry function for retries + DefaultRetry = alwaysRetry // DefaultRetries is the default number of times a request is tried DefaultRetries = 1 // DefaultRequestTimeout is the default request timeout diff --git a/client/options.go b/client/options.go index 7751d879..0f53c8db 100644 --- a/client/options.go +++ b/client/options.go @@ -43,6 +43,8 @@ type CallOptions struct { // Backoff func Backoff BackoffFunc + // Check if retriable func + Retry RetryFunc // Transport Dial Timeout DialTimeout time.Duration // Number of Call attempts @@ -74,6 +76,7 @@ func newOptions(options ...Option) Options { Codecs: make(map[string]codec.NewCodec), CallOptions: CallOptions{ Backoff: DefaultBackoff, + Retry: DefaultRetry, Retries: DefaultRetries, RequestTimeout: DefaultRequestTimeout, DialTimeout: transport.DefaultDialTimeout, @@ -221,6 +224,14 @@ func WithBackoff(fn BackoffFunc) CallOption { } } +// WithRetry is a CallOption which overrides that which +// set in Options.CallOptions +func WithRetry(fn RetryFunc) CallOption { + return func(o *CallOptions) { + o.Retry = fn + } +} + // WithRetries is a CallOption which overrides that which // set in Options.CallOptions func WithRetries(i int) CallOption { diff --git a/client/retry.go b/client/retry.go new file mode 100644 index 00000000..06e2d40c --- /dev/null +++ b/client/retry.go @@ -0,0 +1,11 @@ +package client + +import "golang.org/x/net/context" + +// note that returning either true or a non-nil error will result in the call not being retried +type RetryFunc func(ctx context.Context, req Request, retryCount int, err error) (bool, error) + +// always retry on error +func alwaysRetry(ctx context.Context, req Request, retryCount int, err error) (bool, error) { + return true, nil +} diff --git a/client/rpc_client.go b/client/rpc_client.go index 6e78db4a..861ddc53 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -299,6 +299,16 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac if err == nil { return nil } + + retry, rerr := callOpts.Retry(ctx, request, i, err) + if rerr != nil { + return rerr + } + + if !retry { + return err + } + gerr = err } } @@ -400,6 +410,16 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt if rsp.err == nil { return rsp.stream, nil } + + retry, rerr := callOpts.Retry(ctx, request, i, rsp.err) + if rerr != nil { + return nil, rerr + } + + if !retry { + return nil, rsp.err + } + grr = rsp.err } }