diff --git a/core/google/api/core/retry.py b/core/google/api/core/retry.py index a0ffe8542b9b..d772b333556c 100644 --- a/core/google/api/core/retry.py +++ b/core/google/api/core/retry.py @@ -135,7 +135,7 @@ def exponential_sleep_generator( delay = delay * multiplier -def retry_target(target, predicate, sleep_generator, deadline): +def retry_target(target, predicate, sleep_generator, deadline, on_error=None): """Call a function and retry if it fails. This is the lowest-level retry helper. Generally, you'll use the @@ -150,6 +150,9 @@ def retry_target(target, predicate, sleep_generator, deadline): sleep_generator (Iterator[float]): An infinite iterator that determines how long to sleep between retries. deadline (float): How long to keep retrying the target. + on_error (Callable): A function to call while processing a retryable + exception. Any error raised by this function will *not* be + caught. Returns: Any: the return value of the target function. @@ -177,6 +180,8 @@ def retry_target(target, predicate, sleep_generator, deadline): if not predicate(exc): raise last_exc = exc + if on_error is not None: + on_error(exc) now = datetime_helpers.utcnow() if deadline_datetime is not None and deadline_datetime < now: @@ -226,11 +231,14 @@ def __init__( self._maximum = maximum self._deadline = deadline - def __call__(self, func): + def __call__(self, func, on_error=None): """Wrap a callable with retry behavior. Args: func (Callable): The callable to add retry behavior to. + on_error (Callable): A function to call while processing a + retryable exception. Any error raised by this function will + *not* be caught. Returns: Callable: A callable that will invoke ``func`` with retry @@ -246,7 +254,9 @@ def retry_wrapped_func(*args, **kwargs): target, self._predicate, sleep_generator, - self._deadline) + self._deadline, + on_error=on_error, + ) return retry_wrapped_func diff --git a/core/tests/unit/api_core/test_retry.py b/core/tests/unit/api_core/test_retry.py index 71569137b94f..432ccb1438db 100644 --- a/core/tests/unit/api_core/test_retry.py +++ b/core/tests/unit/api_core/test_retry.py @@ -77,6 +77,34 @@ def target(): sleep.assert_has_calls([mock.call(0), mock.call(1)]) +@mock.patch('time.sleep', autospec=True) +@mock.patch( + 'google.api.core.helpers.datetime_helpers.utcnow', + return_value=datetime.datetime.min, + autospec=True) +def test_retry_target_w_on_error(utcnow, sleep): + predicate = retry.if_exception_type(ValueError) + call_count = {'target': 0} + to_raise = ValueError() + + def target(): + call_count['target'] += 1 + if call_count['target'] < 3: + raise to_raise + return 42 + + on_error = mock.Mock() + + result = retry.retry_target( + target, predicate, range(10), None, on_error=on_error) + + assert result == 42 + assert call_count['target'] == 3 + + on_error.assert_has_calls([mock.call(to_raise), mock.call(to_raise)]) + sleep.assert_has_calls([mock.call(0), mock.call(1)]) + + @mock.patch('time.sleep', autospec=True) @mock.patch( 'google.api.core.helpers.datetime_helpers.utcnow', @@ -139,7 +167,8 @@ def test_constructor_options(self): initial=1, maximum=2, multiplier=3, - deadline=4) + deadline=4, + ) assert retry_._predicate == mock.sentinel.predicate assert retry_._initial == 1 assert retry_._maximum == 2 @@ -204,12 +233,17 @@ def test___call___and_execute_success(self, sleep): 'random.uniform', autospec=True, side_effect=lambda m, n: n/2.0) @mock.patch('time.sleep', autospec=True) def test___call___and_execute_retry(self, sleep, uniform): - retry_ = retry.Retry(predicate=retry.if_exception_type(ValueError)) + + on_error = mock.Mock(spec=['__call__'], side_effect=[None]) + retry_ = retry.Retry( + predicate=retry.if_exception_type(ValueError), + ) + target = mock.Mock(spec=['__call__'], side_effect=[ValueError(), 42]) # __name__ is needed by functools.partial. target.__name__ = 'target' - decorated = retry_(target) + decorated = retry_(target, on_error=on_error) target.assert_not_called() result = decorated('meep') @@ -218,3 +252,4 @@ def test___call___and_execute_retry(self, sleep, uniform): assert target.call_count == 2 target.assert_has_calls([mock.call('meep'), mock.call('meep')]) sleep.assert_called_once_with(retry_._initial) + assert on_error.call_count == 1