Commit

Fixed: keras-team#1722 Run out of memory
Update utils.py and test by using `break` instead of `raise`
Anselmoo committed May 14, 2022
1 parent c51da2d commit f477215
Showing 2 changed files with 18 additions and 12 deletions.
9 changes: 7 additions & 2 deletions autokeras/utils/utils.py
```diff
@@ -94,15 +94,20 @@ def fit_with_adaptive_batch_size(model, batch_size, **fit_kwargs):
 def run_with_adaptive_batch_size(batch_size, func, **fit_kwargs):
     x = fit_kwargs.pop("x")
     validation_data = None
+    history = None
     if "validation_data" in fit_kwargs:
         validation_data = fit_kwargs.pop("validation_data")
     while batch_size > 0:
         try:
             history = func(x=x, validation_data=validation_data, **fit_kwargs)
             break
-        except tf.errors.ResourceExhaustedError as e:
+        except tf.errors.ResourceExhaustedError:
             if batch_size == 1:
-                raise e
+                print(
+                    "Not enough memory, reduced batch size is already set to 1. "
+                    "Current model will be skipped."
+                )
+                break
             batch_size //= 2
             print(
                 "Not enough memory, reduce batch size to {batch_size}.".format(
```
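For readers skimming the diff, here is a minimal, standalone sketch of the retry pattern the patched `run_with_adaptive_batch_size` now follows. `OutOfMemory` and `train_once` are hypothetical stand-ins for `tf.errors.ResourceExhaustedError` and the real training call; only the control flow mirrors the change (halve the batch size on OOM, then print a warning and skip the model once the batch size reaches 1, instead of re-raising).

```python
# Minimal sketch of the adaptive-batch-size retry loop after this commit.
# OutOfMemory and train_once are stand-ins, not AutoKeras APIs.


class OutOfMemory(Exception):
    pass


def train_once(batch_size):
    # Pretend any batch size above 8 exhausts GPU memory.
    if batch_size > 8:
        raise OutOfMemory(f"batch_size={batch_size} is too large")
    return {"batch_size": batch_size}


def run_with_adaptive_batch_size(batch_size, func):
    history = None
    while batch_size > 0:
        try:
            history = func(batch_size)
            break
        except OutOfMemory:
            if batch_size == 1:
                # After the fix: report and skip the model instead of re-raising.
                print(
                    "Not enough memory, reduced batch size is already set to 1. "
                    "Current model will be skipped."
                )
                break
            batch_size //= 2
            print(f"Not enough memory, reduce batch size to {batch_size}.")
    return history


if __name__ == "__main__":
    # Retries 64 -> 32 -> 16 -> 8, then succeeds.
    print(run_with_adaptive_batch_size(64, train_once))
```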
21 changes: 11 additions & 10 deletions autokeras/utils/utils_test.py
```diff
@@ -53,19 +53,20 @@ def test_check_kt_version_error():
     )
 
 
-def test_run_with_adaptive_batch_size_raise_error():
+def test_run_with_adaptive_batch_size_raise_error(capfd):
     def func(**kwargs):
         raise tf.errors.ResourceExhaustedError(0, "", None)
 
-    with pytest.raises(tf.errors.ResourceExhaustedError):
-        utils.run_with_adaptive_batch_size(
-            batch_size=64,
-            func=func,
-            x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64),
-            validation_data=tf.data.Dataset.from_tensor_slices(
-                np.random.rand(100, 1)
-            ).batch(64),
-        )
+    utils.run_with_adaptive_batch_size(
+        batch_size=64,
+        func=func,
+        x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64),
+        validation_data=tf.data.Dataset.from_tensor_slices(
+            np.random.rand(100, 1)
+        ).batch(64),
+    )
+    std, _ = capfd.readouterr()
+    assert "Not enough memory" in std
 
 
 def test_get_hyperparameter_with_none_return_hp():
```
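The rewritten test no longer expects an exception; it uses pytest's built-in `capfd` fixture to capture the printed warning instead. Below is a small, self-contained illustration of that fixture (the `warn` helper is hypothetical): `capfd.readouterr()` returns the text written to stdout and stderr, which the test then asserts on.

```python
# Illustration of pytest's capfd fixture as used by the updated test.
# `warn` is a hypothetical stand-in for the code under test.


def warn():
    print("Not enough memory, reduce batch size to 32.")


def test_warn_is_printed(capfd):
    warn()
    out, _ = capfd.readouterr()  # (stdout, stderr) captured at the FD level
    assert "Not enough memory" in out
```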
