Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase test coverage + Fix save_model_to_hdf5 + Improve is_remote_path + Fix is_remote_path #900

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
77629fe
Increase test coverage in `saving`
Faisal-Alsrheed Sep 16, 2023
68478f0
Add FAILED tests TODO
Faisal-Alsrheed Sep 16, 2023
656d80d
Add tests for `LambdaCallback`
Faisal-Alsrheed Sep 16, 2023
99e8418
Add tests for `LambdaCallback`
Faisal-Alsrheed Sep 16, 2023
e0a76fa
Add test for saving_api.py#L96
Faisal-Alsrheed Sep 16, 2023
4f51acf
Increase test coverage in `saving`
Faisal-Alsrheed Sep 16, 2023
f969e82
Increase test coverage
Faisal-Alsrheed Sep 16, 2023
dde85cf
refines the logic `os.makedirs` +Increase tests
Faisal-Alsrheed Sep 17, 2023
d5c9f44
Increase test coverage
Faisal-Alsrheed Sep 17, 2023
1b9a759
Increase test coverage
Faisal-Alsrheed Sep 17, 2023
7c6dbf3
More tests file_utils_test.py+fix bug `rmtree`
Faisal-Alsrheed Sep 17, 2023
bce3dc6
More tests `file_utils_test` + fix bug `rmtree`
Faisal-Alsrheed Sep 17, 2023
50d2f9a
More tests file_utils_test + fix bug rmtree
Faisal-Alsrheed Sep 17, 2023
fe0e387
Increase test coverage
Faisal-Alsrheed Sep 17, 2023
c72e412
add tests to `lambda_callback_test`
Faisal-Alsrheed Sep 17, 2023
995a336
Add tests in file_utils_test.py
Faisal-Alsrheed Sep 17, 2023
1d03842
Add tests in file_utils_test.py
Faisal-Alsrheed Sep 17, 2023
c12e933
Add more tests `file_utils_test`
Faisal-Alsrheed Sep 17, 2023
b8d7626
add class TestValidateFile
Faisal-Alsrheed Sep 17, 2023
ca4e721
Add tests for `TestIsRemotePath`
Faisal-Alsrheed Sep 17, 2023
19f5f6f
Add tests in file_utils_test.py
Faisal-Alsrheed Sep 18, 2023
dd4ee82
Add tests in file_utils_test.py
Faisal-Alsrheed Sep 18, 2023
2213296
Add tests in file_utils_test.py
Faisal-Alsrheed Sep 18, 2023
038021f
Add tests in `file_utils_test.py`
Faisal-Alsrheed Sep 18, 2023
1070aa7
fix `is_remote_path`
Faisal-Alsrheed Sep 18, 2023
9012d15
improve `is_remote_path`
Faisal-Alsrheed Sep 18, 2023
e1eb4e4
Add test for `raise_if_no_gfile_raises`
Faisal-Alsrheed Sep 18, 2023
5138512
Add tests for file_utils.py
Faisal-Alsrheed Sep 18, 2023
a567239
Add tests in `saving_api_test.py`
Faisal-Alsrheed Sep 18, 2023
81079ab
Add tests `saving_api_test.py`
Faisal-Alsrheed Sep 18, 2023
32ec488
Add tests saving_api_test.py
Faisal-Alsrheed Sep 18, 2023
8747618
Add tests in `saving_api_test.py`
Faisal-Alsrheed Sep 18, 2023
f48f07d
Add test `test_directory_creation_on_save`
Faisal-Alsrheed Sep 18, 2023
93782de
Add test `legacy_h5_format_test.py`
Faisal-Alsrheed Sep 18, 2023
aef9b2a
Flake8 for `LambdaCallbackTest`
Faisal-Alsrheed Sep 19, 2023
fe8bd58
use `get_model` and `self.get_temp_dir`
Faisal-Alsrheed Sep 20, 2023
7d755fe
Fix format
Faisal-Alsrheed Sep 20, 2023
9b3b36a
Improve `is_remote_path` + Add tests
Faisal-Alsrheed Sep 20, 2023
021068b
Fix `is_remote_path`
Faisal-Alsrheed Sep 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 123 additions & 4 deletions keras_core/callbacks/lambda_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

class LambdaCallbackTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_LambdaCallback(self):
BATCH_SIZE = 4
def test_lambda_callback(self):
"""Test standard LambdaCallback functionalities with training."""
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)]
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
Expand All @@ -34,7 +35,7 @@ def test_LambdaCallback(self):
model.fit(
x,
y,
batch_size=BATCH_SIZE,
batch_size=batch_size,
validation_split=0.2,
callbacks=[lambda_log_callback],
epochs=5,
Expand All @@ -44,3 +45,121 @@ def test_LambdaCallback(self):
self.assertTrue(any("on_epoch_begin" in log for log in logs.output))
self.assertTrue(any("on_epoch_end" in log for log in logs.output))
self.assertTrue(any("on_train_end" in log for log in logs.output))

@pytest.mark.requires_trainable_backend
def test_lambda_callback_with_batches(self):
"""Test LambdaCallback's behavior with batch-level callbacks."""
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
)
x = np.random.randn(16, 2)
y = np.random.randn(16, 1)
lambda_log_callback = callbacks.LambdaCallback(
on_train_batch_begin=lambda batch, logs: logging.warning(
"on_train_batch_begin"
),
on_train_batch_end=lambda batch, logs: logging.warning(
"on_train_batch_end"
),
)
with self.assertLogs(level="WARNING") as logs:
model.fit(
x,
y,
batch_size=batch_size,
validation_split=0.2,
callbacks=[lambda_log_callback],
epochs=5,
verbose=0,
)
self.assertTrue(
any("on_train_batch_begin" in log for log in logs.output)
)
self.assertTrue(
any("on_train_batch_end" in log for log in logs.output)
)

@pytest.mark.requires_trainable_backend
def test_lambda_callback_with_kwargs(self):
"""Test LambdaCallback's behavior with custom defined callback."""
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
)
x = np.random.randn(16, 2)
y = np.random.randn(16, 1)
model.fit(
x, y, batch_size=batch_size, epochs=1, verbose=0
) # Train briefly for evaluation to work.

def custom_on_test_begin(logs):
logging.warning("custom_on_test_begin_executed")

lambda_log_callback = callbacks.LambdaCallback(
on_test_begin=custom_on_test_begin
)
with self.assertLogs(level="WARNING") as logs:
model.evaluate(
x,
y,
batch_size=batch_size,
callbacks=[lambda_log_callback],
verbose=0,
)
self.assertTrue(
any(
"custom_on_test_begin_executed" in log
for log in logs.output
)
)

@pytest.mark.requires_trainable_backend
def test_lambda_callback_no_args(self):
"""Test initializing LambdaCallback without any arguments."""
lambda_callback = callbacks.LambdaCallback()
self.assertIsInstance(lambda_callback, callbacks.LambdaCallback)

@pytest.mark.requires_trainable_backend
def test_lambda_callback_with_additional_kwargs(self):
"""Test initializing LambdaCallback with non-predefined kwargs."""

def custom_callback(logs):
pass

lambda_callback = callbacks.LambdaCallback(
custom_method=custom_callback
)
self.assertTrue(hasattr(lambda_callback, "custom_method"))

@pytest.mark.requires_trainable_backend
def test_lambda_callback_during_prediction(self):
"""Test LambdaCallback's functionality during model prediction."""
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
)
x = np.random.randn(16, 2)

def custom_on_predict_begin(logs):
logging.warning("on_predict_begin_executed")

lambda_callback = callbacks.LambdaCallback(
on_predict_begin=custom_on_predict_begin
)
with self.assertLogs(level="WARNING") as logs:
model.predict(
x, batch_size=batch_size, callbacks=[lambda_callback], verbose=0
)
self.assertTrue(
any("on_predict_begin_executed" in log for log in logs.output)
)
5 changes: 2 additions & 3 deletions keras_core/legacy/saving/legacy_h5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
if not proceed:
return

# Try creating dir if not exist
dirpath = os.path.dirname(filepath)
if not os.path.exists(dirpath):
os.path.makedirs(dirpath)
if dirpath and not os.path.exists(dirpath):
os.makedirs(dirpath, exist_ok=True)

f = h5py.File(filepath, mode="w")
opened_new_file = True
Expand Down
16 changes: 16 additions & 0 deletions keras_core/legacy/saving/legacy_h5_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,19 @@ def call(self, x):

# Compare output
self.assertAllClose(ref_output, output, atol=1e-5)


@pytest.mark.requires_trainable_backend
class DirectoryCreationTest(testing.TestCase):
def test_directory_creation_on_save(self):
"""Test if directory is created on model save."""
model = get_sequential_model(keras_core)
nested_dirpath = os.path.join(
self.get_temp_dir(), "dir1", "dir2", "dir3"
)
filepath = os.path.join(nested_dirpath, "model.h5")
self.assertFalse(os.path.exists(nested_dirpath))
legacy_h5_format.save_model_to_hdf5(model, filepath)
self.assertTrue(os.path.exists(nested_dirpath))
loaded_model = legacy_h5_format.load_model_from_hdf5(filepath)
self.assertEqual(model.to_json(), loaded_model.to_json())
178 changes: 178 additions & 0 deletions keras_core/saving/saving_api_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import os
import unittest.mock as mock

import numpy as np
from absl import logging

from keras_core import layers
from keras_core.models import Sequential
from keras_core.saving import saving_api
from keras_core.testing import test_case


class SaveModelTests(test_case.TestCase):
def get_model(self):
return Sequential(
[
layers.Dense(5, input_shape=(3,)),
layers.Softmax(),
]
)

def test_basic_saving(self):
"""Test basic model saving and loading."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
saving_api.save_model(model, filepath)

loaded_model = saving_api.load_model(filepath)
x = np.random.uniform(size=(10, 3))
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))

def test_invalid_save_format(self):
"""Test deprecated save_format argument."""
model = self.get_model()
with self.assertRaisesRegex(
ValueError, "The `save_format` argument is deprecated"
):
saving_api.save_model(model, "model.txt", save_format=True)

def test_unsupported_arguments(self):
"""Test unsupported argument during model save."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
with self.assertRaisesRegex(
ValueError, r"The following argument\(s\) are not supported"
):
saving_api.save_model(model, filepath, random_arg=True)

def test_save_h5_format(self):
"""Test saving model in h5 format."""
model = self.get_model()
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5")
saving_api.save_model(model, filepath_h5)
self.assertTrue(os.path.exists(filepath_h5))
os.remove(filepath_h5)

def test_save_unsupported_extension(self):
"""Test saving model with unsupported extension."""
model = self.get_model()
with self.assertRaisesRegex(
ValueError, "Invalid filepath extension for saving"
):
saving_api.save_model(model, "model.png")


class LoadModelTests(test_case.TestCase):
def get_model(self):
return Sequential(
[
layers.Dense(5, input_shape=(3,)),
layers.Softmax(),
]
)

def test_basic_load(self):
"""Test basic model loading."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
saving_api.save_model(model, filepath)

loaded_model = saving_api.load_model(filepath)
x = np.random.uniform(size=(10, 3))
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))

def test_load_unsupported_format(self):
"""Test loading model with unsupported format."""
with self.assertRaisesRegex(ValueError, "File format not supported"):
saving_api.load_model("model.pkl")

def test_load_keras_not_zip(self):
"""Test loading keras file that's not a zip."""
with self.assertRaisesRegex(ValueError, "File not found"):
saving_api.load_model("not_a_zip.keras")

def test_load_h5_format(self):
"""Test loading model in h5 format."""
model = self.get_model()
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5")
saving_api.save_model(model, filepath_h5)
loaded_model = saving_api.load_model(filepath_h5)
x = np.random.uniform(size=(10, 3))
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))
os.remove(filepath_h5)

def test_load_model_with_custom_objects(self):
"""Test loading model with custom objects."""

class CustomLayer(layers.Layer):
def call(self, inputs):
return inputs

model = Sequential([CustomLayer(input_shape=(3,))])
filepath = os.path.join(self.get_temp_dir(), "custom_model.keras")
model.save(filepath)
loaded_model = saving_api.load_model(
filepath, custom_objects={"CustomLayer": CustomLayer}
)
self.assertIsInstance(loaded_model.layers[0], CustomLayer)
os.remove(filepath)


class LoadWeightsTests(test_case.TestCase):
def get_model(self):
return Sequential(
[
layers.Dense(5, input_shape=(3,)),
layers.Softmax(),
]
)

def test_load_keras_weights(self):
"""Test loading keras weights."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
model.save_weights(filepath)
original_weights = model.get_weights()
model.load_weights(filepath)
loaded_weights = model.get_weights()
for orig, loaded in zip(original_weights, loaded_weights):
self.assertTrue(np.array_equal(orig, loaded))

def test_load_h5_weights_by_name(self):
"""Test loading h5 weights by name."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
model.save_weights(filepath)
with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"):
model.load_weights(filepath, by_name=True)

def test_load_weights_invalid_extension(self):
"""Test loading weights with unsupported extension."""
model = self.get_model()
with self.assertRaisesRegex(ValueError, "File format not supported"):
model.load_weights("invalid_extension.pkl")


class SaveModelTestsWarning(test_case.TestCase):
def get_model(self):
return Sequential(
[
layers.Dense(5, input_shape=(3,)),
layers.Softmax(),
]
)

def test_h5_deprecation_warning(self):
"""Test deprecation warning for h5 format."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_model.h5")

with mock.patch.object(logging, "warning") as mock_warn:
saving_api.save_model(model, filepath)
mock_warn.assert_called_once_with(
"You are saving your model as an HDF5 file via `model.save()`. "
"This file format is considered legacy. "
"We recommend using instead the native Keras format, "
"e.g. `model.save('my_model.keras')`."
)
18 changes: 14 additions & 4 deletions keras_core/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,19 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535):


def is_remote_path(filepath):
"""Returns `True` for paths that represent a remote GCS location."""
# TODO: improve generality.
if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)):
"""
Determines if a given filepath indicates a remote location.

This function checks if the filepath represents a known remote pattern
such as GCS (`/gcs`), CNS (`/cns`), CFS (`/cfs`), HDFS (`/hdfs`)

Args:
filepath (str): The path to be checked.

Returns:
bool: True if the filepath is a recognized remote path, otherwise False
"""
if re.match(r"^(/cns|/cfs|/gcs|/hdfs|.*://).*$", str(filepath)):
return True
return False

Expand Down Expand Up @@ -445,7 +455,7 @@ def rmtree(path):
return gfile.rmtree(path)
else:
_raise_if_no_gfile(path)
return shutil.rmtree
return shutil.rmtree(path)


def listdir(path):
Expand Down
Loading