Skip to content

Commit

Permalink
Flake8 for LambdaCallbackTest
Browse files Browse the repository at this point in the history
  • Loading branch information
Faisal-Alsrheed committed Sep 19, 2023
1 parent 93782de commit aef9b2a
Showing 1 changed file with 23 additions and 27 deletions.
50 changes: 23 additions & 27 deletions keras_core/callbacks/lambda_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

class LambdaCallbackTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_LambdaCallback(self):
def test_lambda_callback(self):
"""Test standard LambdaCallback functionalities with training."""
BATCH_SIZE = 4
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)]
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
Expand All @@ -35,27 +35,23 @@ def test_LambdaCallback(self):
model.fit(
x,
y,
batch_size=BATCH_SIZE,
batch_size=batch_size,
validation_split=0.2,
callbacks=[lambda_log_callback],
epochs=5,
verbose=0,
)
self.assertTrue
(any("on_train_begin" in log for log in logs.output))
self.assertTrue
(any("on_epoch_begin" in log for log in logs.output))
self.assertTrue
(any("on_epoch_end" in log for log in logs.output))
self.assertTrue
(any("on_train_end" in log for log in logs.output))
self.assertTrue(any("on_train_begin" in log for log in logs.output))
self.assertTrue(any("on_epoch_begin" in log for log in logs.output))
self.assertTrue(any("on_epoch_end" in log for log in logs.output))
self.assertTrue(any("on_train_end" in log for log in logs.output))

@pytest.mark.requires_trainable_backend
def test_LambdaCallback_with_batches(self):
def test_lambda_callback_with_batches(self):
"""Test LambdaCallback's behavior with batch-level callbacks."""
BATCH_SIZE = 4
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)]
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
Expand All @@ -74,7 +70,7 @@ def test_LambdaCallback_with_batches(self):
model.fit(
x,
y,
batch_size=BATCH_SIZE,
batch_size=batch_size,
validation_split=0.2,
callbacks=[lambda_log_callback],
epochs=5,
Expand All @@ -88,19 +84,19 @@ def test_LambdaCallback_with_batches(self):
)

@pytest.mark.requires_trainable_backend
def test_LambdaCallback_with_kwargs(self):
def test_lambda_callback_with_kwargs(self):
"""Test LambdaCallback's behavior with custom defined callback."""
BATCH_SIZE = 4
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)]
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
)
x = np.random.randn(16, 2)
y = np.random.randn(16, 1)
model.fit(
x, y, batch_size=BATCH_SIZE, epochs=1, verbose=0
x, y, batch_size=batch_size, epochs=1, verbose=0
) # Train briefly for evaluation to work.

def custom_on_test_begin(logs):
Expand All @@ -113,7 +109,7 @@ def custom_on_test_begin(logs):
model.evaluate(
x,
y,
batch_size=BATCH_SIZE,
batch_size=batch_size,
callbacks=[lambda_log_callback],
verbose=0,
)
Expand All @@ -125,13 +121,13 @@ def custom_on_test_begin(logs):
)

@pytest.mark.requires_trainable_backend
def test_LambdaCallback_no_args(self):
def test_lambda_callback_no_args(self):
"""Test initializing LambdaCallback without any arguments."""
lambda_callback = callbacks.LambdaCallback()
self.assertIsInstance(lambda_callback, callbacks.LambdaCallback)

@pytest.mark.requires_trainable_backend
def test_LambdaCallback_with_additional_kwargs(self):
def test_lambda_callback_with_additional_kwargs(self):
"""Test initializing LambdaCallback with non-predefined kwargs."""

def custom_callback(logs):
Expand All @@ -143,11 +139,11 @@ def custom_callback(logs):
self.assertTrue(hasattr(lambda_callback, "custom_method"))

@pytest.mark.requires_trainable_backend
def test_LambdaCallback_during_prediction(self):
def test_lambda_callback_during_prediction(self):
"""Test LambdaCallback's functionality during model prediction."""
BATCH_SIZE = 4
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)]
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
Expand All @@ -162,7 +158,7 @@ def custom_on_predict_begin(logs):
)
with self.assertLogs(level="WARNING") as logs:
model.predict(
x, batch_size=BATCH_SIZE, callbacks=[lambda_callback], verbose=0
x, batch_size=batch_size, callbacks=[lambda_callback], verbose=0
)
self.assertTrue(
any("on_predict_begin_executed" in log for log in logs.output)
Expand Down

0 comments on commit aef9b2a

Please sign in to comment.