diff --git a/raiutils/raiutils/common/retries.py b/raiutils/raiutils/common/retries.py index 690da259b0..ee169d3b48 100644 --- a/raiutils/raiutils/common/retries.py +++ b/raiutils/raiutils/common/retries.py @@ -33,7 +33,11 @@ def retry_function(function, action_name, err_msg, print(e) if i + 1 != max_retries: print("Will retry after {0} seconds".format(retry_delay)) - time.sleep(retry_delay) + try: + for _ in range(retry_delay): + time.sleep(1) + except TypeError: + time.sleep(retry_delay) retry_delay = retry_delay * 2 else: raise RuntimeError(err_msg) diff --git a/raiutils/tests/test_retry_func.py b/raiutils/tests/test_retry_func.py new file mode 100644 index 0000000000..b518dc61fd --- /dev/null +++ b/raiutils/tests/test_retry_func.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation +# Licensed under the MIT License. +import time + +import pytest + +from raiutils.common.retries import retry_function + + +class TestRetryFunction: + _DELTA = 1.0 + + def test_no_error(self): + x = 5 + + def func(): + return x + 1 + + result = retry_function(func, 'test', 'test failed') + assert result == 6 + + result = retry_function(func, 'test', 'test failed', retry_delay=1) + assert result == 6 + + result = retry_function(func, 'test', 'test failed', retry_delay=1.1) + assert result == 6 + + result = retry_function(func, 'test', 'test failed', retry_delay=0) + assert result == 6 + + def test_error_with_int_delay(self): + x = 'a' + + def func(): + return x + 1 + + t_start = time.time() + with pytest.raises(RuntimeError, match='test failed'): + retry_function(func, 'test', 'test failed', + max_retries=4, retry_delay=1) + time_taken = time.time() - t_start + expected_time_taken = 1 + 2 + 4 + assert abs(time_taken - expected_time_taken) < self._DELTA + + def test_error_with_zero_delay(self): + x = 'a' + + def func(): + return x + 1 + + t_start = time.time() + with pytest.raises(RuntimeError, match='test failed'): + retry_function(func, 'test', 'test failed', + max_retries=4, retry_delay=0) + time_taken = time.time() - t_start + expected_time_taken = 0 + assert abs(time_taken - expected_time_taken) < self._DELTA + + def test_error_with_float_delay(self): + x = 'a' + + def func(): + return x + 1 + + t_start = time.time() + with pytest.raises(RuntimeError, match='test failed'): + retry_function(func, 'test', 'test failed', + max_retries=4, retry_delay=0.5) + time_taken = time.time() - t_start + expected_time_taken = 0.5 + 1.0 + 2.0 + assert abs(time_taken - expected_time_taken) < self._DELTA