diff --git a/errors/errors.go b/errors/errors.go index da47598f..e9f0e177 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -356,6 +356,106 @@ func Retryable(err error) error { return &retryableError{err: err} } +type IsRetryableFunc func(error) bool + +var ( + RetrayableOracleErrors = []IsRetryableFunc{ + func(err error) bool { + errmsg := err.Error() + switch { + case strings.Contains(errmsg, `ORA-`): + return true + case strings.Contains(errmsg, `can not assign`): + return true + case strings.Contains(errmsg, `can't assign`): + return true + } + return false + }, + } + RetrayablePostgresErrors = []IsRetryableFunc{ + func(err error) bool { + errmsg := err.Error() + switch { + case strings.Contains(errmsg, `number of field descriptions must equal number of`): + return true + case strings.Contains(errmsg, `not a pointer`): + return true + case strings.Contains(errmsg, `values, but dst struct has only`): + return true + case strings.Contains(errmsg, `struct doesn't have corresponding row field`): + return true + case strings.Contains(errmsg, `cannot find field`): + return true + case strings.Contains(errmsg, `cannot scan`) || strings.Contains(errmsg, `cannot convert`): + return true + case strings.Contains(errmsg, `failed to connect to`): + return true + } + return false + }, + } + RetryableMicroErrors = []IsRetryableFunc{ + func(err error) bool { + switch verr := err.(type) { + case *Error: + switch verr.Code { + case 401, 403, 408, 500, 501, 502, 503, 504: + return true + default: + return false + } + case *retryableError: + return true + } + return false + }, + } + RetryableGoErrors = []IsRetryableFunc{ + func(err error) bool { + switch verr := err.(type) { + case interface{ SafeToRetry() bool }: + return verr.SafeToRetry() + case interface{ Timeout() bool }: + return verr.Timeout() + } + switch { + case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF): + return true + case errors.Is(err, context.DeadlineExceeded): + return true + case errors.Is(err, io.ErrClosedPipe), errors.Is(err, io.ErrShortBuffer), errors.Is(err, io.ErrShortWrite): + return true + } + return false + }, + } + RetryableGrpcErrors = []IsRetryableFunc{ + func(err error) bool { + st, ok := status.FromError(err) + if !ok { + return false + } + switch st.Code() { + case codes.Unavailable, codes.ResourceExhausted: + return true + case codes.DeadlineExceeded: + return true + case codes.Internal: + switch { + case strings.Contains(st.Message(), `transport: received the unexpected content-type "text/html; charset=UTF-8"`): + return true + case strings.Contains(st.Message(), io.ErrUnexpectedEOF.Error()): + return true + case strings.Contains(st.Message(), `stream terminated by RST_STREAM with error code: INTERNAL_ERROR`): + return true + } + } + return false + }, + } +) + // Unwrap provides error wrapping func (e *retryableError) Unwrap() error { return e.err @@ -370,77 +470,11 @@ func (e *retryableError) Error() string { } // IsRetryable checks error for ability to retry later -func IsRetryable(err error) bool { - switch verr := err.(type) { - case *Error: - switch verr.Code { - case 401, 403, 408, 500, 501, 502, 503, 504: +func IsRetryable(err error, fns ...IsRetryableFunc) bool { + for _, fn := range fns { + if ok := fn(err); ok { return true - default: - return false - } - case *retryableError: - return true - case interface{ SafeToRetry() bool }: - return verr.SafeToRetry() - case interface{ Timeout() bool }: - return verr.Timeout() - } - - switch { - case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF): - return true - case errors.Is(err, context.DeadlineExceeded): - return true - case errors.Is(err, io.ErrClosedPipe), errors.Is(err, io.ErrShortBuffer), errors.Is(err, io.ErrShortWrite): - return true - default: - st, ok := status.FromError(err) - if !ok { - errmsg := err.Error() - if strings.Contains(errmsg, `number of field descriptions must equal number of`) { - return true - } - if strings.Contains(errmsg, `not a pointer`) { - return true - } - if strings.Contains(errmsg, `values, but dst struct has only`) { - return true - } - if strings.Contains(errmsg, `struct doesn't have corresponding row field`) { - return true - } - if strings.Contains(errmsg, `cannot find field`) { - return true - } - if strings.Contains(errmsg, `cannot scan`) || strings.Contains(errmsg, `cannot convert`) { - return true - } - if strings.Contains(errmsg, `failed to connect to`) { - return true - } - - return false - } - switch st.Code() { - case codes.Unavailable, codes.ResourceExhausted: - return true - case codes.DeadlineExceeded: - return true - case codes.Internal: - if strings.Contains(st.Message(), `transport: received the unexpected content-type "text/html; charset=UTF-8"`) { - return true - } - if strings.Contains(st.Message(), io.ErrUnexpectedEOF.Error()) { - return true - } - if strings.Contains(st.Message(), `stream terminated by RST_STREAM with error code: INTERNAL_ERROR`) { - return true - } - default: - return false } } - return false } diff --git a/errors/errors_test.go b/errors/errors_test.go index 62374c84..03c264c2 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -8,6 +8,13 @@ import ( "testing" ) +func TestIsRetrayable(t *testing.T) { + err := fmt.Errorf("ORA-") + if !IsRetryable(err, RetrayableOracleErrors...) { + t.Fatalf("IsRetrayable not works") + } +} + func TestMarshalJSON(t *testing.T) { e := InternalServerError("id", "err: %v", fmt.Errorf("err: %v", `xxx: "UNIX_TIMESTAMP": invalid identifier`)) _, err := json.Marshal(e)