diff --git a/client/retry.go b/client/retry.go index df74abec..59743980 100644 --- a/client/retry.go +++ b/client/retry.go @@ -1,8 +1,10 @@ package client -type RetryFunc func(err error) bool +import "context" + +type RetryFunc func(ctx context.Context, req Request, retryCount int, err error) (bool, error) // always retry on error -func alwaysRetry(err error) bool { - return true +func alwaysRetry(ctx context.Context, req Request, retryCount int, err error) (bool, error) { + return true, err } diff --git a/client/rpc_client.go b/client/rpc_client.go index 0ad84805..ede9e899 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -300,7 +300,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac return nil } - if !callOpts.Retry(err) { + if retry, err := callOpts.Retry(ctx, request, i, err); !retry { return err } @@ -405,7 +405,12 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt if rsp.err == nil { return rsp.stream, nil } - grr = rsp.err + + if retry, err := callOpts.Retry(ctx, request, i, rsp.err); !retry { + return nil, err + } + + grr = err } }