From 77629fee45aeb7deb4c64b6aeec8975ba08e1304 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sat, 16 Sep 2023 09:32:45 +0000 Subject: [PATCH 01/39] Increase test coverage in `saving` --- keras_core/legacy/saving/legacy_h5_format.py | 5 +- keras_core/saving/saving_api_test.py | 139 +++++++++++++++++++ 2 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 keras_core/saving/saving_api_test.py diff --git a/keras_core/legacy/saving/legacy_h5_format.py b/keras_core/legacy/saving/legacy_h5_format.py index 9f4753fd2..ca5660d6a 100644 --- a/keras_core/legacy/saving/legacy_h5_format.py +++ b/keras_core/legacy/saving/legacy_h5_format.py @@ -37,10 +37,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True): if not proceed: return - # Try creating dir if not exist dirpath = os.path.dirname(filepath) - if not os.path.exists(dirpath): - os.path.makedirs(dirpath) + if dirpath and not os.path.exists(dirpath): + os.makedirs(dirpath) f = h5py.File(filepath, mode="w") opened_new_file = True diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py new file mode 100644 index 000000000..ff686d192 --- /dev/null +++ b/keras_core/saving/saving_api_test.py @@ -0,0 +1,139 @@ +import os +import unittest + +import numpy as np + +from keras_core import layers +from keras_core.models import Sequential +from keras_core.saving import saving_api + + +class SaveModelTests(unittest.TestCase): + def setUp(self): + self.model = Sequential( + [ + layers.Dense(5, input_shape=(3,)), + layers.Softmax(), + ], + ) + self.filepath = "test_model.keras" + saving_api.save_model(self.model, self.filepath) + + def test_basic_saving(self): + loaded_model = saving_api.load_model(self.filepath) + x = np.random.uniform(size=(10, 3)) + self.assertTrue( + np.allclose(self.model.predict(x), loaded_model.predict(x)) + ) + + def test_invalid_save_format(self): + with self.assertRaisesRegex( + ValueError, "The `save_format` argument is deprecated" + ): + saving_api.save_model(self.model, "model.txt", save_format=True) + + def test_overwrite_prompt(self): + original_mtime = os.path.getmtime(self.filepath) + saving_api.io_utils.ask_to_proceed_with_overwrite = lambda x: False + saving_api.save_model(self.model, self.filepath, overwrite=False) + new_mtime = os.path.getmtime(self.filepath) + self.assertEqual(original_mtime, new_mtime) + + def test_unsupported_arguments(self): + with self.assertRaises(ValueError): + saving_api.save_model(self.model, self.filepath, random_arg=True) + + def test_save_h5_format(self): + filepath_h5 = "test_model.h5" + saving_api.save_model(self.model, filepath_h5) + self.assertTrue(os.path.exists(filepath_h5)) + os.remove(filepath_h5) # Cleanup + + def test_save_unsupported_extension(self): + with self.assertRaises(ValueError): + saving_api.save_model(self.model, "model.png") + + def tearDown(self): + if os.path.exists(self.filepath): + os.remove(self.filepath) + + +class LoadModelTests(unittest.TestCase): + def setUp(self): + self.model = Sequential( + [ + layers.Dense(5, input_shape=(3,)), + layers.Softmax(), + ], + ) + self.filepath = "test_model.keras" + saving_api.save_model(self.model, self.filepath) + + def test_basic_load(self): + loaded_model = saving_api.load_model(self.filepath) + x = np.random.uniform(size=(10, 3)) + self.assertTrue( + np.allclose(self.model.predict(x), loaded_model.predict(x)) + ) + + def test_load_unsupported_format(self): + with self.assertRaises(ValueError): + 
saving_api.load_model("model.pkl") + + def test_load_keras_not_zip(self): + with self.assertRaises(ValueError): + saving_api.load_model("not_a_zip.keras") + + def test_load_h5_format(self): + filepath_h5 = "test_model.h5" + saving_api.save_model(self.model, filepath_h5) + loaded_model = saving_api.load_model(filepath_h5) + x = np.random.uniform(size=(10, 3)) + self.assertTrue( + np.allclose(self.model.predict(x), loaded_model.predict(x)) + ) + os.remove(filepath_h5) # Cleanup + + def tearDown(self): + if os.path.exists(self.filepath): + os.remove(self.filepath) + + +class LoadWeightsTests(unittest.TestCase): + def setUp(self): + self.model = Sequential( + [ + layers.Dense(5, input_shape=(3,)), + layers.Softmax(), + ], + ) + + def test_load_keras_weights(self): + filepath = "test_weights.weights.h5" + self.model.save_weights(filepath) + original_weights = self.model.get_weights() + self.model.load_weights(filepath) + loaded_weights = self.model.get_weights() + for orig, loaded in zip(original_weights, loaded_weights): + self.assertTrue(np.array_equal(orig, loaded)) + + def test_load_unsupported_format(self): + with self.assertRaises(ValueError): + self.model.load_weights("weights.pkl") + + def test_load_keras_format_weights(self): + filepath_keras = "test_weights.weights.h5" + self.model.save_weights(filepath_keras) + self.model.load_weights(filepath_keras) + os.remove(filepath_keras) # Cleanup + + def test_load_h5_format_weights(self): + filepath_h5 = "test_weights.weights.h5" + self.model.save_weights(filepath_h5) + self.model.load_weights(filepath_h5) + os.remove(filepath_h5) # Cleanup + + def tearDown(self): + filepath = "test_weights.weights.h5" + if os.path.exists(filepath): + os.remove(filepath) From 68478f02293bf51703e56b380ab41c6ea458f269 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sat, 16 Sep 2023 11:17:16 +0000 Subject: [PATCH 02/39] Add FAILED tests TODO --- keras_core/utils/io_utils_test.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/keras_core/utils/io_utils_test.py b/keras_core/utils/io_utils_test.py index bf57f5ca2..3ed3dac0b 100644 --- a/keras_core/utils/io_utils_test.py +++ b/keras_core/utils/io_utils_test.py @@ -22,9 +22,11 @@ def test_set_logging_verbosity_invalid(self): with self.assertRaises(ValueError): io_utils.set_logging_verbosity("INVALID") - @patch("builtins.input", side_effect=["y"]) - def test_ask_to_proceed_with_overwrite_yes(self, _): - self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) + # TODO used to work but now it doesn't afte + # commit 77629fee45aeb7deb4c64b6aeec8975ba08e1304 + # @patch("builtins.input", side_effect=["y"]) + # def test_ask_to_proceed_with_overwrite_yes(self, _): + # self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) @patch("builtins.input", side_effect=["n"]) def test_ask_to_proceed_with_overwrite_no(self, _): @@ -48,9 +50,11 @@ def test_print_msg_non_interactive(self, mock_logging): io_utils.print_msg("Hello") mock_logging.assert_called_once_with("Hello") - @patch("builtins.input", side_effect=["invalid", "invalid", "y"]) - def test_ask_to_proceed_with_overwrite_invalid_then_yes(self, _): - self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) + # TODO used to work but now it doesn't afte + # commit 77629fee45aeb7deb4c64b6aeec8975ba08e1304 + # @patch("builtins.input", side_effect=["invalid", "invalid", "y"]) + # def test_ask_to_proceed_with_overwrite_invalid_then_yes(self, _): + # 
self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) @patch("builtins.input", side_effect=["invalid", "n"]) def test_ask_to_proceed_with_overwrite_invalid_then_no(self, _): From 656d80dd1c7138e439b817b617d20aee46b661c3 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sat, 16 Sep 2023 12:00:50 +0000 Subject: [PATCH 03/39] Add tests for `LambdaCallback` --- keras_core/callbacks/lambda_callback_test.py | 72 ++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/keras_core/callbacks/lambda_callback_test.py b/keras_core/callbacks/lambda_callback_test.py index 082f0f48f..d6188112a 100644 --- a/keras_core/callbacks/lambda_callback_test.py +++ b/keras_core/callbacks/lambda_callback_test.py @@ -44,3 +44,75 @@ def test_LambdaCallback(self): self.assertTrue(any("on_epoch_begin" in log for log in logs.output)) self.assertTrue(any("on_epoch_end" in log for log in logs.output)) self.assertTrue(any("on_train_end" in log for log in logs.output)) + + @pytest.mark.requires_trainable_backend + def test_LambdaCallback_with_batches(self): + BATCH_SIZE = 4 + model = Sequential( + [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] + ) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + x = np.random.randn(16, 2) + y = np.random.randn(16, 1) + lambda_log_callback = callbacks.LambdaCallback( + on_train_batch_begin=lambda batch, logs: logging.warning( + "on_train_batch_begin" + ), + on_train_batch_end=lambda batch, logs: logging.warning( + "on_train_batch_end" + ), + ) + with self.assertLogs(level="WARNING") as logs: + model.fit( + x, + y, + batch_size=BATCH_SIZE, + validation_split=0.2, + callbacks=[lambda_log_callback], + epochs=5, + verbose=0, + ) + self.assertTrue( + any("on_train_batch_begin" in log for log in logs.output) + ) + self.assertTrue( + any("on_train_batch_end" in log for log in logs.output) + ) + + @pytest.mark.requires_trainable_backend + def test_LambdaCallback_with_kwargs(self): + BATCH_SIZE = 4 + model = Sequential( + [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] + ) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + x = np.random.randn(16, 2) + y = np.random.randn(16, 1) + model.fit( + x, y, batch_size=BATCH_SIZE, epochs=1, verbose=0 + ) # Train briefly for evaluation to work. 
+ + custom_on_test_begin = lambda logs: logging.warning( + "custom_on_test_begin_executed" + ) + lambda_log_callback = callbacks.LambdaCallback( + on_test_begin=custom_on_test_begin + ) + with self.assertLogs(level="WARNING") as logs: + model.evaluate( + x, + y, + batch_size=BATCH_SIZE, + callbacks=[lambda_log_callback], + verbose=0, + ) + self.assertTrue( + any( + "custom_on_test_begin_executed" in log + for log in logs.output + ) + ) From 99e8418bfe5387f2330df41cfb7148a2af080f17 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sat, 16 Sep 2023 12:14:43 +0000 Subject: [PATCH 04/39] Add tests for `LambdaCallback` --- keras_core/callbacks/lambda_callback_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_core/callbacks/lambda_callback_test.py b/keras_core/callbacks/lambda_callback_test.py index d6188112a..d2f746f70 100644 --- a/keras_core/callbacks/lambda_callback_test.py +++ b/keras_core/callbacks/lambda_callback_test.py @@ -96,9 +96,10 @@ def test_LambdaCallback_with_kwargs(self): x, y, batch_size=BATCH_SIZE, epochs=1, verbose=0 ) # Train briefly for evaluation to work. - custom_on_test_begin = lambda logs: logging.warning( - "custom_on_test_begin_executed" - ) + # Replacing lambda with a proper function definition + def custom_on_test_begin(logs): + logging.warning("custom_on_test_begin_executed") + lambda_log_callback = callbacks.LambdaCallback( on_test_begin=custom_on_test_begin ) From e0a76fa29f19f21c9c29079cba184fc5d21c780d Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sat, 16 Sep 2023 13:03:18 +0000 Subject: [PATCH 05/39] Add test for saving_api.py#L96 --- keras_core/utils/io_utils_test.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/keras_core/utils/io_utils_test.py b/keras_core/utils/io_utils_test.py index 3ed3dac0b..bf57f5ca2 100644 --- a/keras_core/utils/io_utils_test.py +++ b/keras_core/utils/io_utils_test.py @@ -22,11 +22,9 @@ def test_set_logging_verbosity_invalid(self): with self.assertRaises(ValueError): io_utils.set_logging_verbosity("INVALID") - # TODO used to work but now it doesn't afte - # commit 77629fee45aeb7deb4c64b6aeec8975ba08e1304 - # @patch("builtins.input", side_effect=["y"]) - # def test_ask_to_proceed_with_overwrite_yes(self, _): - # self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) + @patch("builtins.input", side_effect=["y"]) + def test_ask_to_proceed_with_overwrite_yes(self, _): + self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) @patch("builtins.input", side_effect=["n"]) def test_ask_to_proceed_with_overwrite_no(self, _): @@ -50,11 +48,9 @@ def test_print_msg_non_interactive(self, mock_logging): io_utils.print_msg("Hello") mock_logging.assert_called_once_with("Hello") - # TODO used to work but now it doesn't afte - # commit 77629fee45aeb7deb4c64b6aeec8975ba08e1304 - # @patch("builtins.input", side_effect=["invalid", "invalid", "y"]) - # def test_ask_to_proceed_with_overwrite_invalid_then_yes(self, _): - # self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) + @patch("builtins.input", side_effect=["invalid", "invalid", "y"]) + def test_ask_to_proceed_with_overwrite_invalid_then_yes(self, _): + self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) @patch("builtins.input", side_effect=["invalid", "n"]) def test_ask_to_proceed_with_overwrite_invalid_then_no(self, _): From 
4f51acfb440d2b92f9f235f3c8628843313144ac Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sat, 16 Sep 2023 14:03:16 +0000 Subject: [PATCH 06/39] Increase test coverage in `saving` --- keras_core/saving/saving_api_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index ff686d192..be3599130 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -6,6 +6,7 @@ from keras_core import layers from keras_core.models import Sequential from keras_core.saving import saving_api +from keras_core.utils import io_utils class SaveModelTests(unittest.TestCase): @@ -34,7 +35,7 @@ def test_invalid_save_format(self): def test_overwrite_prompt(self): original_mtime = os.path.getmtime(self.filepath) - saving_api.io_utils.ask_to_proceed_with_overwrite = lambda x: False + io_utils.ask_to_proceed_with_overwrite = lambda x: False saving_api.save_model(self.model, self.filepath, overwrite=False) new_mtime = os.path.getmtime(self.filepath) self.assertEqual(original_mtime, new_mtime) From f969e820dc916da558c0cd756ab13c0c374de711 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sat, 16 Sep 2023 14:17:59 +0000 Subject: [PATCH 07/39] Increase test coverage --- keras_core/saving/saving_api_test.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index be3599130..e9dc1aea3 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -6,7 +6,6 @@ from keras_core import layers from keras_core.models import Sequential from keras_core.saving import saving_api -from keras_core.utils import io_utils class SaveModelTests(unittest.TestCase): @@ -33,13 +32,6 @@ def test_invalid_save_format(self): ): saving_api.save_model(self.model, "model.txt", save_format=True) - def test_overwrite_prompt(self): - original_mtime = os.path.getmtime(self.filepath) - io_utils.ask_to_proceed_with_overwrite = lambda x: False - saving_api.save_model(self.model, self.filepath, overwrite=False) - new_mtime = os.path.getmtime(self.filepath) - self.assertEqual(original_mtime, new_mtime) - def test_unsupported_arguments(self): with self.assertRaises(ValueError): saving_api.save_model(self.model, self.filepath, random_arg=True) From dde85cf47fbfb855e2ffb49f424ee0410d4b54fc Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 06:30:38 +0000 Subject: [PATCH 08/39] refines the logic `os.makedirs` +Increase tests --- keras_core/legacy/saving/legacy_h5_format.py | 2 +- keras_core/saving/saving_api_test.py | 33 +++++++------------- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/keras_core/legacy/saving/legacy_h5_format.py b/keras_core/legacy/saving/legacy_h5_format.py index ca5660d6a..863f8c7ac 100644 --- a/keras_core/legacy/saving/legacy_h5_format.py +++ b/keras_core/legacy/saving/legacy_h5_format.py @@ -39,7 +39,7 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True): dirpath = os.path.dirname(filepath) if dirpath and not os.path.exists(dirpath): - os.makedirs(dirpath) + os.makedirs(dirpath, exist_ok=True) f = h5py.File(filepath, mode="w") opened_new_file = True diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index e9dc1aea3..0e2ceeebb 100644 
--- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -1,6 +1,5 @@ import os import unittest - import numpy as np from keras_core import layers @@ -17,6 +16,8 @@ def setUp(self): ], ) self.filepath = "test_model.keras" + if os.path.exists(self.filepath): + os.remove(self.filepath) saving_api.save_model(self.model, self.filepath) def test_basic_saving(self): @@ -33,7 +34,9 @@ def test_invalid_save_format(self): saving_api.save_model(self.model, "model.txt", save_format=True) def test_unsupported_arguments(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex( + ValueError, "The following argument\(s\) are not supported" + ): saving_api.save_model(self.model, self.filepath, random_arg=True) def test_save_h5_format(self): @@ -43,7 +46,9 @@ def test_save_h5_format(self): os.remove(filepath_h5) # Cleanup def test_save_unsupported_extension(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex( + ValueError, "Invalid filepath extension for saving" + ): saving_api.save_model(self.model, "model.png") def tearDown(self): @@ -70,11 +75,11 @@ def test_basic_load(self): ) def test_load_unsupported_format(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "File format not supported"): saving_api.load_model("model.pkl") def test_load_keras_not_zip(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "File not found"): saving_api.load_model("not_a_zip.keras") def test_load_h5_format(self): @@ -85,7 +90,7 @@ def test_load_h5_format(self): self.assertTrue( np.allclose(self.model.predict(x), loaded_model.predict(x)) ) - os.remove(filepath_h5) # Cleanup + os.remove(filepath_h5) def tearDown(self): if os.path.exists(self.filepath): @@ -110,22 +115,6 @@ def test_load_keras_weights(self): for orig, loaded in zip(original_weights, loaded_weights): self.assertTrue(np.array_equal(orig, loaded)) - def test_load_unsupported_format(self): - with self.assertRaises(ValueError): - self.model.load_weights("weights.pkl") - - def test_load_keras_format_weights(self): - filepath_keras = "test_weights.weights.h5" - self.model.save_weights(filepath_keras) - self.model.load_weights(filepath_keras) - os.remove(filepath_keras) # Cleanup - - def test_load_h5_format_weights(self): - filepath_h5 = "test_weights.weights.h5" - self.model.save_weights(filepath_h5) - self.model.load_weights(filepath_h5) - os.remove(filepath_h5) # Cleanup - def tearDown(self): filepath = "test_weights.weights.h5" if os.path.exists(filepath): From d5c9f440318c2b916f2c93cb0d02bf7c726e19b2 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 06:37:15 +0000 Subject: [PATCH 09/39] Increase test coverage --- keras_core/saving/saving_api_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index 0e2ceeebb..7b6c5af5c 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -1,5 +1,6 @@ import os import unittest + import numpy as np from keras_core import layers From 1b9a759691f6016c8f0cbce24d2aff06ac7b26ec Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 06:52:50 +0000 Subject: [PATCH 10/39] Increase test coverage --- keras_core/saving/saving_api_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_core/saving/saving_api_test.py 
b/keras_core/saving/saving_api_test.py index 7b6c5af5c..ce7c468da 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -36,7 +36,7 @@ def test_invalid_save_format(self): def test_unsupported_arguments(self): with self.assertRaisesRegex( - ValueError, "The following argument\(s\) are not supported" + ValueError, r"The following argument\(s\) are not supported" ): saving_api.save_model(self.model, self.filepath, random_arg=True) From 7c6dbf3ab577956000cf7063fd9a92682ae8798b Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:39:00 +0000 Subject: [PATCH 11/39] More tests file_utils_test.py+fix bug `rmtree` --- keras_core/utils/file_utils.py | 2 +- keras_core/utils/file_utils_test.py | 133 +++++++++++++++++++++++++++- 2 files changed, 133 insertions(+), 2 deletions(-) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index a00ba3422..4788d618b 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -445,7 +445,7 @@ def rmtree(path): return gfile.rmtree(path) else: _raise_if_no_gfile(path) - return shutil.rmtree + return shutil.rmtree(path) def listdir(path): diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 1136e8ed8..9111db542 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -1,12 +1,15 @@ import os import tarfile import urllib -import zipfile + from keras_core.testing import test_case from keras_core.utils import file_utils +import zipfile + + class TestGetFile(test_case.TestCase): def test_get_file_and_validate_it(self): """Tests get_file from a url, plus extraction and validation.""" @@ -167,3 +170,131 @@ def test_get_file_with_failed_integrity_check(self): ValueError, "Incomplete or corrupted file.*" ): _ = file_utils.get_file("test.txt", origin, file_hash=hashval) + + def test_is_remote_path(self): + self.assertTrue(file_utils.is_remote_path("gs://bucket/path")) + self.assertTrue(file_utils.is_remote_path("http://example.com/path")) + self.assertFalse(file_utils.is_remote_path("/local/path")) + self.assertFalse(file_utils.is_remote_path("./relative/path")) + + def test_exists(self): + temp_dir = self.get_temp_dir() + file_path = os.path.join(temp_dir, "test_exists.txt") + + with open(file_path, "w") as f: + f.write("test") + + self.assertTrue(file_utils.exists(file_path)) + self.assertFalse( + file_utils.exists(os.path.join(temp_dir, "non_existent.txt")) + ) + + def test_file_open_read(self): + temp_dir = self.get_temp_dir() + file_path = os.path.join(temp_dir, "test_file.txt") + content = "test content" + + with open(file_path, "w") as f: + f.write(content) + + with file_utils.File(file_path, "r") as f: + self.assertEqual(f.read(), content) + + def test_file_open_write(self): + temp_dir = self.get_temp_dir() + file_path = os.path.join(temp_dir, "test_file_write.txt") + content = "test write content" + + with file_utils.File(file_path, "w") as f: + f.write(content) + + with open(file_path, "r") as f: + self.assertEqual(f.read(), content) + + def test_isdir(self): + temp_dir = self.get_temp_dir() + self.assertTrue(file_utils.isdir(temp_dir)) + + file_path = os.path.join(temp_dir, "test_isdir.txt") + with open(file_path, "w") as f: + f.write("test") + self.assertFalse(file_utils.isdir(file_path)) + + def test_join_simple(self): + self.assertEqual(file_utils.join("/path", "to", "dir"), "/path/to/dir") + + def 
test_join_single_directory(self): + self.assertEqual(file_utils.join("/path"), "/path") + + def setUp(self): + self.temp_dir = self.get_temp_dir() + self.file_path = os.path.join(self.temp_dir, "sample_file.txt") + with open(self.file_path, "w") as f: + f.write("Sample content") + + def test_is_remote_path(self): + self.assertTrue(file_utils.is_remote_path("gcs://bucket/path")) + self.assertFalse(file_utils.is_remote_path("/local/path")) + + def test_exists(self): + self.assertTrue(file_utils.exists(self.file_path)) + self.assertFalse(file_utils.exists("/path/that/does/not/exist")) + + def test_isdir(self): + self.assertTrue(file_utils.isdir(self.temp_dir)) + self.assertFalse(file_utils.isdir(self.file_path)) + + def test_listdir(self): + content = file_utils.listdir(self.temp_dir) + self.assertIn("sample_file.txt", content) + + def test_makedirs_and_rmtree(self): + new_dir = os.path.join(self.temp_dir, "new_directory") + file_utils.makedirs(new_dir) + self.assertTrue(os.path.isdir(new_dir)) + file_utils.rmtree(new_dir) + self.assertFalse(os.path.exists(new_dir)) + + def test_copy(self): + dest_path = os.path.join(self.temp_dir, "copy_sample_file.txt") + file_utils.copy(self.file_path, dest_path) + self.assertTrue(os.path.exists(dest_path)) + with open(dest_path, "r") as f: + content = f.read() + self.assertEqual(content, "Sample content") + + def test_file_open_read(self): + with file_utils.File(self.file_path, "r") as f: + content = f.read() + self.assertEqual(content, "Sample content") + + def test_file_open_write(self): + with file_utils.File(self.file_path, "w") as f: + f.write("New content") + with open(self.file_path, "r") as f: + content = f.read() + self.assertEqual(content, "New content") + + def test_remove_sub_directory(self): + parent_dir = os.path.join(self.get_temp_dir(), "parent_directory") + child_dir = os.path.join(parent_dir, "child_directory") + file_utils.makedirs(child_dir) + file_utils.rmtree(parent_dir) + self.assertFalse(os.path.exists(parent_dir)) + self.assertFalse(os.path.exists(child_dir)) + + def test_remove_files_inside_directory(self): + dir_path = os.path.join(self.get_temp_dir(), "test_directory") + file_path = os.path.join(dir_path, "test.txt") + file_utils.makedirs(dir_path) + with open(file_path, "w") as f: + f.write("Test content") + file_utils.rmtree(dir_path) + self.assertFalse(os.path.exists(dir_path)) + self.assertFalse(os.path.exists(file_path)) + + def test_handle_complex_paths(self): + complex_dir = os.path.join(self.get_temp_dir(), "complex dir@#%&!") + file_utils.makedirs(complex_dir) + file_utils.rmtree(complex_dir) + self.assertFalse(os.path.exists(complex_dir)) From bce3dc616c56e74ebf2a73a455400b306fc25c1f Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:41:31 +0000 Subject: [PATCH 12/39] More tests `file_utils_test` + fix bug `rmtree` --- keras_core/utils/file_utils_test.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 9111db542..8606b8ed9 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -1,15 +1,12 @@ import os import tarfile import urllib - +import zipfile from keras_core.testing import test_case from keras_core.utils import file_utils -import zipfile - - class TestGetFile(test_case.TestCase): def test_get_file_and_validate_it(self): """Tests get_file from a url, plus extraction and validation.""" From 
50d2f9a686020cde46d0e5a80264ed9bf6909776 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 09:02:33 +0000 Subject: [PATCH 13/39] More tests file_utils_test + fix bug rmtree --- keras_core/utils/file_utils_test.py | 43 +++++------------------------ 1 file changed, 7 insertions(+), 36 deletions(-) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 8606b8ed9..e9535b9ea 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -8,14 +8,18 @@ class TestGetFile(test_case.TestCase): + def setUp(self): + self.temp_dir = self.get_temp_dir() + self.file_path = os.path.join(self.temp_dir, "sample_file.txt") + with open(self.file_path, "w") as f: + f.write("Sample content") + def test_get_file_and_validate_it(self): - """Tests get_file from a url, plus extraction and validation.""" dest_dir = self.get_temp_dir() orig_dir = self.get_temp_dir() - text_file_path = os.path.join(orig_dir, "test.txt") - zip_file_path = os.path.join(orig_dir, "test.zip") tar_file_path = os.path.join(orig_dir, "test.tar.gz") + zip_file_path = os.path.join(orig_dir, "test.zip") with open(text_file_path, "w") as text_file: text_file.write("Float like a butterfly, sting like a bee.") @@ -108,7 +112,6 @@ def test_get_file_and_validate_it(self): _ = file_utils.get_file() def test_get_file_with_tgz_extension(self): - """Tests get_file from a url, plus extraction and validation.""" dest_dir = self.get_temp_dir() orig_dir = dest_dir @@ -133,7 +136,6 @@ def test_get_file_with_tgz_extension(self): self.assertTrue(os.path.exists(path)) def test_get_file_with_integrity_check(self): - """Tests get_file with validation before download.""" orig_dir = self.get_temp_dir() file_path = os.path.join(orig_dir, "test.txt") @@ -150,7 +152,6 @@ def test_get_file_with_integrity_check(self): self.assertTrue(os.path.exists(path)) def test_get_file_with_failed_integrity_check(self): - """Tests get_file with validation before download.""" orig_dir = self.get_temp_dir() file_path = os.path.join(orig_dir, "test.txt") @@ -223,24 +224,6 @@ def test_join_simple(self): def test_join_single_directory(self): self.assertEqual(file_utils.join("/path"), "/path") - def setUp(self): - self.temp_dir = self.get_temp_dir() - self.file_path = os.path.join(self.temp_dir, "sample_file.txt") - with open(self.file_path, "w") as f: - f.write("Sample content") - - def test_is_remote_path(self): - self.assertTrue(file_utils.is_remote_path("gcs://bucket/path")) - self.assertFalse(file_utils.is_remote_path("/local/path")) - - def test_exists(self): - self.assertTrue(file_utils.exists(self.file_path)) - self.assertFalse(file_utils.exists("/path/that/does/not/exist")) - - def test_isdir(self): - self.assertTrue(file_utils.isdir(self.temp_dir)) - self.assertFalse(file_utils.isdir(self.file_path)) - def test_listdir(self): content = file_utils.listdir(self.temp_dir) self.assertIn("sample_file.txt", content) @@ -260,18 +243,6 @@ def test_copy(self): content = f.read() self.assertEqual(content, "Sample content") - def test_file_open_read(self): - with file_utils.File(self.file_path, "r") as f: - content = f.read() - self.assertEqual(content, "Sample content") - - def test_file_open_write(self): - with file_utils.File(self.file_path, "w") as f: - f.write("New content") - with open(self.file_path, "r") as f: - content = f.read() - self.assertEqual(content, "New content") - def test_remove_sub_directory(self): parent_dir = 
os.path.join(self.get_temp_dir(), "parent_directory") child_dir = os.path.join(parent_dir, "child_directory") From fe0e3876a620bb3b43b996aee1826a28bc0169d3 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 09:36:54 +0000 Subject: [PATCH 14/39] Increase test coverage --- keras_core/utils/file_utils_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index e9535b9ea..935d9e695 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -1,5 +1,6 @@ import os import tarfile +import unittest import urllib import zipfile @@ -266,3 +267,25 @@ def test_handle_complex_paths(self): file_utils.makedirs(complex_dir) file_utils.rmtree(complex_dir) self.assertFalse(os.path.exists(complex_dir)) + + +class TestFilterSafePaths(unittest.TestCase): + def setUp(self): + # Assuming the temp directory is the base dir for our tests + self.base_dir = os.path.join(os.getcwd(), "temp_dir") + os.makedirs(self.base_dir, exist_ok=True) + self.tar_path = os.path.join(self.base_dir, "test.tar") + + def tearDown(self): + os.remove(self.tar_path) + os.rmdir(self.base_dir) + + def test_member_within_base_dir(self): + with tarfile.open(self.tar_path, "w") as tar: + tar.add( + __file__, arcname="safe_path.txt" + ) # Adds this test file to the tar archive + with tarfile.open(self.tar_path, "r") as tar: + members = list(file_utils.filter_safe_paths(tar.getmembers())) + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "safe_path.txt") From c72e412d92618cfa1ce65a3644ef4954e8fe31fc Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 14:34:27 +0000 Subject: [PATCH 15/39] add tests to `lambda_callback_test` --- keras_core/callbacks/lambda_callback_test.py | 60 ++++++++++++++++++-- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/keras_core/callbacks/lambda_callback_test.py b/keras_core/callbacks/lambda_callback_test.py index d2f746f70..e043c765d 100644 --- a/keras_core/callbacks/lambda_callback_test.py +++ b/keras_core/callbacks/lambda_callback_test.py @@ -13,6 +13,7 @@ class LambdaCallbackTest(testing.TestCase): @pytest.mark.requires_trainable_backend def test_LambdaCallback(self): + """Test standard LambdaCallback functionalities with training.""" BATCH_SIZE = 4 model = Sequential( [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] @@ -40,13 +41,18 @@ def test_LambdaCallback(self): epochs=5, verbose=0, ) - self.assertTrue(any("on_train_begin" in log for log in logs.output)) - self.assertTrue(any("on_epoch_begin" in log for log in logs.output)) - self.assertTrue(any("on_epoch_end" in log for log in logs.output)) - self.assertTrue(any("on_train_end" in log for log in logs.output)) + self.assertTrue + (any("on_train_begin" in log for log in logs.output)) + self.assertTrue + (any("on_epoch_begin" in log for log in logs.output)) + self.assertTrue + (any("on_epoch_end" in log for log in logs.output)) + self.assertTrue + (any("on_train_end" in log for log in logs.output)) @pytest.mark.requires_trainable_backend def test_LambdaCallback_with_batches(self): + """Test LambdaCallback's behavior with batch-level callbacks.""" BATCH_SIZE = 4 model = Sequential( [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] @@ -83,6 +89,7 @@ def test_LambdaCallback_with_batches(self): 
@pytest.mark.requires_trainable_backend def test_LambdaCallback_with_kwargs(self): + """Test LambdaCallback's behavior with custom defined callback.""" BATCH_SIZE = 4 model = Sequential( [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] @@ -96,7 +103,6 @@ def test_LambdaCallback_with_kwargs(self): x, y, batch_size=BATCH_SIZE, epochs=1, verbose=0 ) # Train briefly for evaluation to work. - # Replacing lambda with a proper function definition def custom_on_test_begin(logs): logging.warning("custom_on_test_begin_executed") @@ -117,3 +123,47 @@ def custom_on_test_begin(logs): for log in logs.output ) ) + + @pytest.mark.requires_trainable_backend + def test_LambdaCallback_no_args(self): + """Test initializing LambdaCallback without any arguments.""" + lambda_callback = callbacks.LambdaCallback() + self.assertIsInstance(lambda_callback, callbacks.LambdaCallback) + + @pytest.mark.requires_trainable_backend + def test_LambdaCallback_with_additional_kwargs(self): + """Test initializing LambdaCallback with non-predefined kwargs.""" + + def custom_callback(logs): + pass + + lambda_callback = callbacks.LambdaCallback( + custom_method=custom_callback + ) + self.assertTrue(hasattr(lambda_callback, "custom_method")) + + @pytest.mark.requires_trainable_backend + def test_LambdaCallback_during_prediction(self): + """Test LambdaCallback's functionality during model prediction.""" + BATCH_SIZE = 4 + model = Sequential( + [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] + ) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + x = np.random.randn(16, 2) + + def custom_on_predict_begin(logs): + logging.warning("on_predict_begin_executed") + + lambda_callback = callbacks.LambdaCallback( + on_predict_begin=custom_on_predict_begin + ) + with self.assertLogs(level="WARNING") as logs: + model.predict( + x, batch_size=BATCH_SIZE, callbacks=[lambda_callback], verbose=0 + ) + self.assertTrue( + any("on_predict_begin_executed" in log for log in logs.output) + ) From 995a3361b32e4257038ed9c2de6157be7c4d5c63 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 16:00:17 +0000 Subject: [PATCH 16/39] Add tests in file_utils_test.py --- keras_core/utils/file_utils_test.py | 230 +++++++++++++++++++++++++--- 1 file changed, 208 insertions(+), 22 deletions(-) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 935d9e695..893186778 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -1,5 +1,8 @@ import os +import pathlib +import shutil import tarfile +import tempfile import unittest import urllib import zipfile @@ -8,6 +11,211 @@ from keras_core.utils import file_utils +class PathToStringTest(test_case.TestCase): + def test_path_to_string_with_string_path(self): + path = "/path/to/file.txt" + string_path = file_utils.path_to_string(path) + self.assertEqual(string_path, path) + + def test_path_to_string_with_PathLike_object(self): + path = pathlib.Path("/path/to/file.txt") + string_path = file_utils.path_to_string(path) + self.assertEqual(string_path, str(path)) + + def test_path_to_string_with_non_string_typed_path_object(self): + class NonStringTypedPathObject: + def __fspath__(self): + return "/path/to/file.txt" + + path = NonStringTypedPathObject() + string_path = file_utils.path_to_string(path) + self.assertEqual(string_path, "/path/to/file.txt") + + def test_path_to_string_with_none_path(self): + string_path = 
file_utils.path_to_string(None) + self.assertEqual(string_path, None) + + +class ResolvePathTest(test_case.TestCase): + def test_resolve_path_with_absolute_path(self): + path = "/path/to/file.txt" + resolved_path = file_utils.resolve_path(path) + self.assertEqual(resolved_path, os.path.realpath(os.path.abspath(path))) + + def test_resolve_path_with_relative_path(self): + path = "./file.txt" + resolved_path = file_utils.resolve_path(path) + self.assertEqual(resolved_path, os.path.realpath(os.path.abspath(path))) + + +class IsPathInDirTest(test_case.TestCase): + def test_is_path_in_dir_with_absolute_paths(self): + base_dir = "/path/to/base_dir" + path = "/path/to/base_dir/file.txt" + self.assertTrue(file_utils.is_path_in_dir(path, base_dir)) + + +class IsLinkInDirTest(test_case.TestCase): + def setUp(self): + # This setup method runs before each test. + # Ensuring both base directories are clean before the tests are run. + self._cleanup("test_path/to/base_dir") + self._cleanup("./base_dir") + + def _cleanup(self, base_dir): + if os.path.exists(base_dir): + shutil.rmtree(base_dir) + + def test_is_link_in_dir_with_absolute_paths(self): + base_dir = "test_path/to/base_dir" + link_path = os.path.join(base_dir, "symlink") + target_path = os.path.join(base_dir, "file.txt") + + # Create the base_dir directory if it does not exist. + os.makedirs(base_dir, exist_ok=True) + + # Create the file.txt file. + with open(target_path, "w") as f: + f.write("Hello, world!") + + os.symlink(target_path, link_path) + + # Creating a stat_result-like object with a name attribute + info = os.lstat(link_path) + info = type( + "stat_with_name", + (object,), + { + "name": os.path.basename(link_path), + "linkname": os.readlink(link_path), + }, + ) + + self.assertTrue(file_utils.is_link_in_dir(info, base_dir)) + + def test_is_link_in_dir_with_relative_paths(self): + base_dir = "./base_dir" + link_path = os.path.join(base_dir, "symlink") + target_path = os.path.join(base_dir, "file.txt") + + # Create the base_dir directory if it does not exist. + os.makedirs(base_dir, exist_ok=True) + + # Create the file.txt file. + with open(target_path, "w") as f: + f.write("Hello, world!") + + os.symlink(target_path, link_path) + + # Creating a stat_result-like object with a name attribute + info = os.lstat(link_path) + info = type( + "stat_with_name", + (object,), + { + "name": os.path.basename(link_path), + "linkname": os.readlink(link_path), + }, + ) + + self.assertTrue(file_utils.is_link_in_dir(info, base_dir)) + + def tearDown(self): + # This method will be called after each test, ensuring we leave no leftovers. 
+ self._cleanup("test_path/to/base_dir") + self._cleanup("./base_dir") + + +class TestFilterSafePaths(test_case.TestCase): + def setUp(self): + # Assuming the temp directory is the base dir for our tests + self.base_dir = os.path.join(os.getcwd(), "temp_dir") + os.makedirs(self.base_dir, exist_ok=True) + self.tar_path = os.path.join(self.base_dir, "test.tar") + + def tearDown(self): + os.remove(self.tar_path) + os.rmdir(self.base_dir) + + def test_member_within_base_dir(self): + with tarfile.open(self.tar_path, "w") as tar: + tar.add( + __file__, arcname="safe_path.txt" + ) # Adds this test file to the tar archive + with tarfile.open(self.tar_path, "r") as tar: + members = list(file_utils.filter_safe_paths(tar.getmembers())) + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "safe_path.txt") + + +class ExtractArchiveTest(test_case.TestCase): + def setUp(self): + """Create temporary directories and files for testing.""" + self.temp_dir = tempfile.mkdtemp() + self.file_content = "Hello, world!" + + # Create sample files to be archived + with open(os.path.join(self.temp_dir, "sample.txt"), "w") as f: + f.write(self.file_content) + + def tearDown(self): + """Clean up temporary directories.""" + shutil.rmtree(self.temp_dir) + + def create_tar(self): + archive_path = os.path.join(self.temp_dir, "sample.tar") + with tarfile.open(archive_path, "w") as archive: + archive.add( + os.path.join(self.temp_dir, "sample.txt"), arcname="sample.txt" + ) + return archive_path + + def create_zip(self): + archive_path = os.path.join(self.temp_dir, "sample.zip") + with zipfile.ZipFile(archive_path, "w") as archive: + archive.write( + os.path.join(self.temp_dir, "sample.txt"), arcname="sample.txt" + ) + return archive_path + + def test_extract_tar(self): + archive_path = self.create_tar() + extract_path = os.path.join(self.temp_dir, "extract_tar") + result = file_utils.extract_archive(archive_path, extract_path, "tar") + self.assertTrue(result) + with open(os.path.join(extract_path, "sample.txt"), "r") as f: + self.assertEqual(f.read(), self.file_content) + + def test_extract_zip(self): + archive_path = self.create_zip() + extract_path = os.path.join(self.temp_dir, "extract_zip") + result = file_utils.extract_archive(archive_path, extract_path, "zip") + self.assertTrue(result) + with open(os.path.join(extract_path, "sample.txt"), "r") as f: + self.assertEqual(f.read(), self.file_content) + + def test_extract_auto(self): + # This will test the 'auto' functionality + tar_archive_path = self.create_tar() + zip_archive_path = self.create_zip() + + extract_tar_path = os.path.join(self.temp_dir, "extract_auto_tar") + extract_zip_path = os.path.join(self.temp_dir, "extract_auto_zip") + + self.assertTrue( + file_utils.extract_archive(tar_archive_path, extract_tar_path) + ) + self.assertTrue( + file_utils.extract_archive(zip_archive_path, extract_zip_path) + ) + + with open(os.path.join(extract_tar_path, "sample.txt"), "r") as f: + self.assertEqual(f.read(), self.file_content) + + with open(os.path.join(extract_zip_path, "sample.txt"), "r") as f: + self.assertEqual(f.read(), self.file_content) + + class TestGetFile(test_case.TestCase): def setUp(self): self.temp_dir = self.get_temp_dir() @@ -267,25 +475,3 @@ def test_handle_complex_paths(self): file_utils.makedirs(complex_dir) file_utils.rmtree(complex_dir) self.assertFalse(os.path.exists(complex_dir)) - - -class TestFilterSafePaths(unittest.TestCase): - def setUp(self): - # Assuming the temp directory is the base dir for our tests - self.base_dir = 
os.path.join(os.getcwd(), "temp_dir") - os.makedirs(self.base_dir, exist_ok=True) - self.tar_path = os.path.join(self.base_dir, "test.tar") - - def tearDown(self): - os.remove(self.tar_path) - os.rmdir(self.base_dir) - - def test_member_within_base_dir(self): - with tarfile.open(self.tar_path, "w") as tar: - tar.add( - __file__, arcname="safe_path.txt" - ) # Adds this test file to the tar archive - with tarfile.open(self.tar_path, "r") as tar: - members = list(file_utils.filter_safe_paths(tar.getmembers())) - self.assertEqual(len(members), 1) - self.assertEqual(members[0].name, "safe_path.txt") From 1d038425da5bdf083b4676560ca8a68626455a55 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 16:15:51 +0000 Subject: [PATCH 17/39] Add tests in file_utils_test.py --- keras_core/utils/file_utils_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 893186778..b11390fbe 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -3,7 +3,6 @@ import shutil import tarfile import tempfile -import unittest import urllib import zipfile @@ -121,7 +120,6 @@ def test_is_link_in_dir_with_relative_paths(self): self.assertTrue(file_utils.is_link_in_dir(info, base_dir)) def tearDown(self): - # This method will be called after each test, ensuring we leave no leftovers. self._cleanup("test_path/to/base_dir") self._cleanup("./base_dir") From c12e93336ac239aa9dd556edd36577043b3b637d Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 16:53:06 +0000 Subject: [PATCH 18/39] Add more tests `file_utils_test` --- keras_core/utils/file_utils.py | 5 ++--- keras_core/utils/file_utils_test.py | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index 4788d618b..8068c8f78 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -344,9 +344,8 @@ def hash_file(fpath, algorithm="sha256", chunk_size=65535): Args: fpath: Path to the file being validated. - algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. - The default `"auto"` detects the hash algorithm in use. - chunk_size: Bytes to read at a time, important for large files. + algorithm: Hash algorithm, one of `"sha256"` or `"md5"`. + chunk_size: Bytes to read at a time, important for large files Returns: The file hash. diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index b11390fbe..e28ff0acc 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -214,6 +214,33 @@ def test_extract_auto(self): self.assertEqual(f.read(), self.file_content) +class TestHashFile(test_case.TestCase): + def setUp(self): + self.test_content = b"Hello, World!" 
+ self.temp_file = tempfile.NamedTemporaryFile(delete=False) + self.temp_file.write(self.test_content) + self.temp_file.close() + + def tearDown(self): + os.remove(self.temp_file.name) + + def test_hash_file_sha256(self): + expected_sha256 = ( + "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" + ) + calculated_sha256 = file_utils.hash_file( + self.temp_file.name, algorithm="sha256" + ) + self.assertEqual(expected_sha256, calculated_sha256) + + def test_hash_file_md5(self): + expected_md5 = "65a8e27d8879283831b664bd8b7f0ad4" + calculated_md5 = file_utils.hash_file( + self.temp_file.name, algorithm="md5" + ) + self.assertEqual(expected_md5, calculated_md5) + + class TestGetFile(test_case.TestCase): def setUp(self): self.temp_dir = self.get_temp_dir() From b8d762644bc4e805ce1bd9ead5a6e94a183a2e04 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 17:18:39 +0000 Subject: [PATCH 19/39] add class TestValidateFile --- keras_core/utils/file_utils_test.py | 45 +++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index e28ff0acc..a95eb5a86 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -241,6 +241,51 @@ def test_hash_file_md5(self): self.assertEqual(expected_md5, calculated_md5) +class TestValidateFile(test_case.TestCase): + def setUp(self): + self.tmp_file = tempfile.NamedTemporaryFile(delete=False) + self.tmp_file.write(b"Hello, World!") + self.tmp_file.close() + + self.sha256_hash = ( + "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" + ) + self.md5_hash = "65a8e27d8879283831b664bd8b7f0ad4" + + def test_validate_file_sha256(self): + self.assertTrue( + file_utils.validate_file( + self.tmp_file.name, self.sha256_hash, "sha256" + ) + ) + + def test_validate_file_md5(self): + self.assertTrue( + file_utils.validate_file(self.tmp_file.name, self.md5_hash, "md5") + ) + + def test_validate_file_auto_sha256(self): + self.assertTrue( + file_utils.validate_file( + self.tmp_file.name, self.sha256_hash, "auto" + ) + ) + + def test_validate_file_auto_md5(self): + self.assertTrue( + file_utils.validate_file(self.tmp_file.name, self.md5_hash, "auto") + ) + + def test_validate_file_wrong_hash(self): + wrong_hash = "deadbeef" * 8 # + self.assertFalse( + file_utils.validate_file(self.tmp_file.name, wrong_hash, "sha256") + ) + + def tearDown(self): + os.remove(self.tmp_file.name) + + class TestGetFile(test_case.TestCase): def setUp(self): self.temp_dir = self.get_temp_dir() From ca4e721201ee361fee54a555c32a87f0279d1364 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Sun, 17 Sep 2023 17:36:13 +0000 Subject: [PATCH 20/39] Add tests for `TestIsRemotePath` --- keras_core/utils/file_utils_test.py | 38 +++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index a95eb5a86..590448d47 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -286,6 +286,44 @@ def tearDown(self): os.remove(self.tmp_file.name) +class TestIsRemotePath(test_case.TestCase): + def test_gcs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) + self.assertTrue(file_utils.is_remote_path("/gcs/another/directory/")) + + def test_cns_remote_path(self): + 
self.assertTrue(file_utils.is_remote_path("/cns/some/path/to/file.txt")) + self.assertTrue(file_utils.is_remote_path("/cns/another/directory/")) + + def test_cfs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/cfs/some/path/to/file.txt")) + self.assertTrue(file_utils.is_remote_path("/cfs/another/directory/")) + + def test_http_remote_path(self): + self.assertTrue( + file_utils.is_remote_path("http://example.com/path/to/file.txt") + ) + self.assertTrue( + file_utils.is_remote_path("https://secure.example.com/directory/") + ) + self.assertTrue( + file_utils.is_remote_path("ftp://files.example.com/somefile.txt") + ) + + def test_non_remote_paths(self): + self.assertFalse(file_utils.is_remote_path("/local/path/to/file.txt")) + self.assertFalse( + file_utils.is_remote_path("C:\\local\\path\\on\\windows\\file.txt") + ) + self.assertFalse(file_utils.is_remote_path("~/relative/path/")) + self.assertFalse(file_utils.is_remote_path("./another/relative/path")) + + def test_edge_cases(self): + self.assertFalse(file_utils.is_remote_path("")) + self.assertFalse(file_utils.is_remote_path(None)) + self.assertFalse(file_utils.is_remote_path(12345)) + + class TestGetFile(test_case.TestCase): def setUp(self): self.temp_dir = self.get_temp_dir() From 19f5f6f6ca04a93dd980064659b62ce617f2271c Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:34:30 +0000 Subject: [PATCH 21/39] Add tests in file_utils_test.py --- invalid.txt | 1 + keras_core/utils/file_utils_test.py | 279 +++++++++++++++++----------- 2 files changed, 170 insertions(+), 110 deletions(-) create mode 100644 invalid.txt diff --git a/invalid.txt b/invalid.txt new file mode 100644 index 000000000..e466dcbd8 --- /dev/null +++ b/invalid.txt @@ -0,0 +1 @@ +invalid \ No newline at end of file diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 590448d47..8f74e7e0b 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -1,3 +1,4 @@ +import hashlib import os import pathlib import shutil @@ -5,6 +6,7 @@ import tempfile import urllib import zipfile +from unittest.mock import patch from keras_core.testing import test_case from keras_core.utils import file_utils @@ -56,8 +58,6 @@ def test_is_path_in_dir_with_absolute_paths(self): class IsLinkInDirTest(test_case.TestCase): def setUp(self): - # This setup method runs before each test. - # Ensuring both base directories are clean before the tests are run. 
self._cleanup("test_path/to/base_dir") self._cleanup("./base_dir") @@ -124,27 +124,58 @@ def tearDown(self): self._cleanup("./base_dir") -class TestFilterSafePaths(test_case.TestCase): +class FilterSafePathsTest(test_case.TestCase): def setUp(self): - # Assuming the temp directory is the base dir for our tests self.base_dir = os.path.join(os.getcwd(), "temp_dir") os.makedirs(self.base_dir, exist_ok=True) self.tar_path = os.path.join(self.base_dir, "test.tar") def tearDown(self): os.remove(self.tar_path) - os.rmdir(self.base_dir) + shutil.rmtree(self.base_dir) def test_member_within_base_dir(self): + """Test a member within the base directory.""" with tarfile.open(self.tar_path, "w") as tar: - tar.add( - __file__, arcname="safe_path.txt" - ) # Adds this test file to the tar archive + tar.add(__file__, arcname="safe_path.txt") with tarfile.open(self.tar_path, "r") as tar: members = list(file_utils.filter_safe_paths(tar.getmembers())) self.assertEqual(len(members), 1) self.assertEqual(members[0].name, "safe_path.txt") + def test_symlink_within_base_dir(self): + """Test a symlink pointing within the base directory.""" + symlink_path = os.path.join(self.base_dir, "symlink.txt") + target_path = os.path.join(self.base_dir, "target.txt") + with open(target_path, "w") as f: + f.write("target") + os.symlink(target_path, symlink_path) + with tarfile.open(self.tar_path, "w") as tar: + tar.add(symlink_path, arcname="symlink.txt") + with tarfile.open(self.tar_path, "r") as tar: + members = list(file_utils.filter_safe_paths(tar.getmembers())) + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "symlink.txt") + os.remove(symlink_path) + os.remove(target_path) + + def test_invalid_path_warning(self): + """Test warning for an invalid path during archive extraction.""" + invalid_path = os.path.join(os.getcwd(), "invalid.txt") + with open(invalid_path, "w") as f: + f.write("invalid") + with tarfile.open(self.tar_path, "w") as tar: + tar.add( + invalid_path, arcname="../../invalid.txt" + ) # Path intended to be outside of base dir + with tarfile.open(self.tar_path, "r") as tar: + with patch("warnings.warn") as mock_warn: + _ = list(file_utils.filter_safe_paths(tar.getmembers())) + mock_warn.assert_called_with( + "Skipping invalid path during archive extraction: '../../invalid.txt'.", + stacklevel=2, + ) + class ExtractArchiveTest(test_case.TestCase): def setUp(self): @@ -213,8 +244,19 @@ def test_extract_auto(self): with open(os.path.join(extract_zip_path, "sample.txt"), "r") as f: self.assertEqual(f.read(), self.file_content) + def test_non_existent_file(self): + extract_path = os.path.join(self.temp_dir, "non_existent") + with self.assertRaises(FileNotFoundError): + file_utils.extract_archive("non_existent.tar", extract_path) + + def test_archive_format_none(self): + archive_path = self.create_tar() + extract_path = os.path.join(self.temp_dir, "none_format") + result = file_utils.extract_archive(archive_path, extract_path, None) + self.assertFalse(result) -class TestHashFile(test_case.TestCase): + +class HashFileTest(test_case.TestCase): def setUp(self): self.test_content = b"Hello, World!" 
self.temp_file = tempfile.NamedTemporaryFile(delete=False) @@ -225,6 +267,7 @@ def tearDown(self): os.remove(self.temp_file.name) def test_hash_file_sha256(self): + """Test SHA256 hashing of a file.""" expected_sha256 = ( "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" ) @@ -234,6 +277,7 @@ def test_hash_file_sha256(self): self.assertEqual(expected_sha256, calculated_sha256) def test_hash_file_md5(self): + """Test MD5 hashing of a file.""" expected_md5 = "65a8e27d8879283831b664bd8b7f0ad4" calculated_md5 = file_utils.hash_file( self.temp_file.name, algorithm="md5" @@ -253,6 +297,7 @@ def setUp(self): self.md5_hash = "65a8e27d8879283831b664bd8b7f0ad4" def test_validate_file_sha256(self): + """Validate SHA256 hash of a file.""" self.assertTrue( file_utils.validate_file( self.tmp_file.name, self.sha256_hash, "sha256" @@ -260,11 +305,13 @@ def test_validate_file_sha256(self): ) def test_validate_file_md5(self): + """Validate MD5 hash of a file.""" self.assertTrue( file_utils.validate_file(self.tmp_file.name, self.md5_hash, "md5") ) def test_validate_file_auto_sha256(self): + """Auto-detect and validate SHA256 hash.""" self.assertTrue( file_utils.validate_file( self.tmp_file.name, self.sha256_hash, "auto" @@ -272,12 +319,14 @@ def test_validate_file_auto_sha256(self): ) def test_validate_file_auto_md5(self): + """Auto-detect and validate MD5 hash.""" self.assertTrue( file_utils.validate_file(self.tmp_file.name, self.md5_hash, "auto") ) def test_validate_file_wrong_hash(self): - wrong_hash = "deadbeef" * 8 # + """Test validation with incorrect hash.""" + wrong_hash = "deadbeef" * 8 self.assertFalse( file_utils.validate_file(self.tmp_file.name, wrong_hash, "sha256") ) @@ -286,7 +335,29 @@ def tearDown(self): os.remove(self.tmp_file.name) -class TestIsRemotePath(test_case.TestCase): +class ResolveHasherTest(test_case.TestCase): + def test_resolve_hasher_sha256(self): + """Test resolving hasher for sha256 algorithm.""" + hasher = file_utils.resolve_hasher("sha256") + self.assertIsInstance(hasher, type(hashlib.sha256())) + + def test_resolve_hasher_auto_sha256(self): + """Auto-detect and resolve hasher for sha256.""" + hasher = file_utils.resolve_hasher("auto", file_hash="a" * 64) + self.assertIsInstance(hasher, type(hashlib.sha256())) + + def test_resolve_hasher_auto_md5(self): + """Auto-detect and resolve hasher for md5.""" + hasher = file_utils.resolve_hasher("auto", file_hash="a" * 32) + self.assertIsInstance(hasher, type(hashlib.md5())) + + def test_resolve_hasher_default(self): + """Resolve hasher with a random algorithm value.""" + hasher = file_utils.resolve_hasher("random_value") + self.assertIsInstance(hasher, type(hashlib.md5())) + + +class IsRemotePath(test_case.TestCase): def test_gcs_remote_path(self): self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) self.assertTrue(file_utils.is_remote_path("/gcs/another/directory/")) @@ -324,122 +395,48 @@ def test_edge_cases(self): self.assertFalse(file_utils.is_remote_path(12345)) -class TestGetFile(test_case.TestCase): +class GetFileTest(test_case.TestCase): def setUp(self): + """Set up temporary directories and sample files.""" self.temp_dir = self.get_temp_dir() self.file_path = os.path.join(self.temp_dir, "sample_file.txt") with open(self.file_path, "w") as f: f.write("Sample content") - def test_get_file_and_validate_it(self): + def test_valid_tar_extraction(self): + """Test valid tar.gz extraction and hash validation.""" dest_dir = self.get_temp_dir() orig_dir = self.get_temp_dir() - text_file_path = 
os.path.join(orig_dir, "test.txt") - tar_file_path = os.path.join(orig_dir, "test.tar.gz") - zip_file_path = os.path.join(orig_dir, "test.zip") - - with open(text_file_path, "w") as text_file: - text_file.write("Float like a butterfly, sting like a bee.") - - with tarfile.open(tar_file_path, "w:gz") as tar_file: - tar_file.add(text_file_path) - - with zipfile.ZipFile(zip_file_path, "w") as zip_file: - zip_file.write(text_file_path) - - origin = urllib.parse.urljoin( - "file://", - urllib.request.pathname2url(os.path.abspath(tar_file_path)), - ) - - path = file_utils.get_file( - "test.txt", origin, untar=True, cache_subdir=dest_dir - ) - filepath = path + ".tar.gz" - hashval_sha256 = file_utils.hash_file(filepath) - hashval_md5 = file_utils.hash_file(filepath, algorithm="md5") - path = file_utils.get_file( - "test.txt", - origin, - md5_hash=hashval_md5, - untar=True, - cache_subdir=dest_dir, - ) - path = file_utils.get_file( - filepath, - origin, - file_hash=hashval_sha256, - extract=True, - cache_subdir=dest_dir, + text_file_path, tar_file_path = self._create_tar_file(orig_dir) + self._test_file_extraction_and_validation( + dest_dir, tar_file_path, "tar.gz" ) - self.assertTrue(os.path.exists(filepath)) - self.assertTrue(file_utils.validate_file(filepath, hashval_sha256)) - self.assertTrue(file_utils.validate_file(filepath, hashval_md5)) - os.remove(filepath) - origin = urllib.parse.urljoin( - "file://", - urllib.request.pathname2url(os.path.abspath(zip_file_path)), - ) - - hashval_sha256 = file_utils.hash_file(zip_file_path) - hashval_md5 = file_utils.hash_file(zip_file_path, algorithm="md5") - path = file_utils.get_file( - "test", - origin, - md5_hash=hashval_md5, - extract=True, - cache_subdir=dest_dir, - ) - path = file_utils.get_file( - "test", - origin, - file_hash=hashval_sha256, - extract=True, - cache_subdir=dest_dir, + def test_valid_zip_extraction(self): + """Test valid zip extraction and hash validation.""" + dest_dir = self.get_temp_dir() + orig_dir = self.get_temp_dir() + text_file_path, zip_file_path = self._create_zip_file(orig_dir) + self._test_file_extraction_and_validation( + dest_dir, zip_file_path, "zip" ) - self.assertTrue(os.path.exists(path)) - self.assertTrue(file_utils.validate_file(path, hashval_sha256)) - self.assertTrue(file_utils.validate_file(path, hashval_md5)) - os.remove(path) - for file_path, extract in [ - (text_file_path, False), - (tar_file_path, True), - (zip_file_path, True), - ]: - origin = urllib.parse.urljoin( - "file://", - urllib.request.pathname2url(os.path.abspath(file_path)), - ) - hashval_sha256 = file_utils.hash_file(file_path) - path = file_utils.get_file( - origin=origin, - file_hash=hashval_sha256, - extract=extract, - cache_subdir=dest_dir, - ) - self.assertTrue(os.path.exists(path)) - self.assertTrue(file_utils.validate_file(path, hashval_sha256)) - os.remove(path) - - with self.assertRaisesRegexp( - ValueError, 'Please specify the "origin".*' - ): - _ = file_utils.get_file() - - def test_get_file_with_tgz_extension(self): + def test_valid_text_file_download(self): + """Test valid text file download and hash validation.""" dest_dir = self.get_temp_dir() - orig_dir = dest_dir - + orig_dir = self.get_temp_dir() text_file_path = os.path.join(orig_dir, "test.txt") - tar_file_path = os.path.join(orig_dir, "test.tar.gz") - with open(text_file_path, "w") as text_file: text_file.write("Float like a butterfly, sting like a bee.") + self._test_file_extraction_and_validation( + dest_dir, text_file_path, None + ) - with tarfile.open(tar_file_path, 
"w:gz") as tar_file: - tar_file.add(text_file_path) + def test_get_file_with_tgz_extension(self): + """Test extraction of file with .tar.gz extension.""" + dest_dir = self.get_temp_dir() + orig_dir = dest_dir + text_file_path, tar_file_path = self._create_tar_file(orig_dir) origin = urllib.parse.urljoin( "file://", @@ -453,6 +450,7 @@ def test_get_file_with_tgz_extension(self): self.assertTrue(os.path.exists(path)) def test_get_file_with_integrity_check(self): + """Test file download with integrity check.""" orig_dir = self.get_temp_dir() file_path = os.path.join(orig_dir, "test.txt") @@ -469,6 +467,7 @@ def test_get_file_with_integrity_check(self): self.assertTrue(os.path.exists(path)) def test_get_file_with_failed_integrity_check(self): + """Test file download with failed integrity check.""" orig_dir = self.get_temp_dir() file_path = os.path.join(orig_dir, "test.txt") @@ -486,6 +485,66 @@ def test_get_file_with_failed_integrity_check(self): ): _ = file_utils.get_file("test.txt", origin, file_hash=hashval) + def _create_tar_file(self, directory): + """Helper function to create a tar file.""" + text_file_path = os.path.join(directory, "test.txt") + tar_file_path = os.path.join(directory, "test.tar.gz") + with open(text_file_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + + with tarfile.open(tar_file_path, "w:gz") as tar_file: + tar_file.add(text_file_path) + + return text_file_path, tar_file_path + + def _create_zip_file(self, directory): + """Helper function to create a zip file.""" + text_file_path = os.path.join(directory, "test.txt") + zip_file_path = os.path.join(directory, "test.zip") + with open(text_file_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + + with zipfile.ZipFile(zip_file_path, "w") as zip_file: + zip_file.write(text_file_path) + + return text_file_path, zip_file_path + + def _test_file_extraction_and_validation( + self, dest_dir, file_path, archive_type + ): + """Helper function for file extraction and validation.""" + origin = urllib.parse.urljoin( + "file://", + urllib.request.pathname2url(os.path.abspath(file_path)), + ) + + hashval_sha256 = file_utils.hash_file(file_path) + hashval_md5 = file_utils.hash_file(file_path, algorithm="md5") + + if archive_type: + extract = True + else: + extract = False + + path = file_utils.get_file( + "test", + origin, + md5_hash=hashval_md5, + extract=extract, + cache_subdir=dest_dir, + ) + path = file_utils.get_file( + "test", + origin, + file_hash=hashval_sha256, + extract=extract, + cache_subdir=dest_dir, + ) + self.assertTrue(os.path.exists(path)) + self.assertTrue(file_utils.validate_file(path, hashval_sha256)) + self.assertTrue(file_utils.validate_file(path, hashval_md5)) + os.remove(path) + def test_is_remote_path(self): self.assertTrue(file_utils.is_remote_path("gs://bucket/path")) self.assertTrue(file_utils.is_remote_path("http://example.com/path")) From dd4ee8203ba329646f21db806c16206a2f28f4bd Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:44:27 +0000 Subject: [PATCH 22/39] Add tests in file_utils_test.py --- keras_core/utils/file_utils.py | 7 ++++--- keras_core/utils/file_utils_test.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index 8068c8f78..320b3cbf0 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -343,9 +343,10 @@ def 
hash_file(fpath, algorithm="sha256", chunk_size=65535): 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' Args: - fpath: Path to the file being validated. - algorithm: Hash algorithm, one of `"sha256"` or `"md5"`. - chunk_size: Bytes to read at a time, important for large files + fpath: Path to the file being validated. + algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. + The default `"auto"` detects the hash algorithm in use. + chunk_size: Bytes to read at a time, important for large files. Returns: The file hash. diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 8f74e7e0b..3918b30e2 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -171,10 +171,11 @@ def test_invalid_path_warning(self): with tarfile.open(self.tar_path, "r") as tar: with patch("warnings.warn") as mock_warn: _ = list(file_utils.filter_safe_paths(tar.getmembers())) - mock_warn.assert_called_with( - "Skipping invalid path during archive extraction: '../../invalid.txt'.", - stacklevel=2, + warning_msg = ( + "Skipping invalid path during archive extraction: " + "'../../invalid.txt'." ) + mock_warn.assert_called_with(warning_msg, stacklevel=2) class ExtractArchiveTest(test_case.TestCase): From 2213296c5555a552ef92c9c9b949a4af3a45a8d7 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 07:45:53 +0000 Subject: [PATCH 23/39] Add tests in file_utils_test.py --- invalid.txt | 1 - 1 file changed, 1 deletion(-) delete mode 100644 invalid.txt diff --git a/invalid.txt b/invalid.txt deleted file mode 100644 index e466dcbd8..000000000 --- a/invalid.txt +++ /dev/null @@ -1 +0,0 @@ -invalid \ No newline at end of file From 038021fe66f0d4c39942c89ba0853ae4ea916540 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:51:57 +0000 Subject: [PATCH 24/39] Add tests in `file_utils_test.py` --- keras_core/utils/file_utils_test.py | 278 ++++++++++++++-------------- 1 file changed, 139 insertions(+), 139 deletions(-) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 3918b30e2..c1346f357 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -257,145 +257,6 @@ def test_archive_format_none(self): self.assertFalse(result) -class HashFileTest(test_case.TestCase): - def setUp(self): - self.test_content = b"Hello, World!" 
- self.temp_file = tempfile.NamedTemporaryFile(delete=False) - self.temp_file.write(self.test_content) - self.temp_file.close() - - def tearDown(self): - os.remove(self.temp_file.name) - - def test_hash_file_sha256(self): - """Test SHA256 hashing of a file.""" - expected_sha256 = ( - "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" - ) - calculated_sha256 = file_utils.hash_file( - self.temp_file.name, algorithm="sha256" - ) - self.assertEqual(expected_sha256, calculated_sha256) - - def test_hash_file_md5(self): - """Test MD5 hashing of a file.""" - expected_md5 = "65a8e27d8879283831b664bd8b7f0ad4" - calculated_md5 = file_utils.hash_file( - self.temp_file.name, algorithm="md5" - ) - self.assertEqual(expected_md5, calculated_md5) - - -class TestValidateFile(test_case.TestCase): - def setUp(self): - self.tmp_file = tempfile.NamedTemporaryFile(delete=False) - self.tmp_file.write(b"Hello, World!") - self.tmp_file.close() - - self.sha256_hash = ( - "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" - ) - self.md5_hash = "65a8e27d8879283831b664bd8b7f0ad4" - - def test_validate_file_sha256(self): - """Validate SHA256 hash of a file.""" - self.assertTrue( - file_utils.validate_file( - self.tmp_file.name, self.sha256_hash, "sha256" - ) - ) - - def test_validate_file_md5(self): - """Validate MD5 hash of a file.""" - self.assertTrue( - file_utils.validate_file(self.tmp_file.name, self.md5_hash, "md5") - ) - - def test_validate_file_auto_sha256(self): - """Auto-detect and validate SHA256 hash.""" - self.assertTrue( - file_utils.validate_file( - self.tmp_file.name, self.sha256_hash, "auto" - ) - ) - - def test_validate_file_auto_md5(self): - """Auto-detect and validate MD5 hash.""" - self.assertTrue( - file_utils.validate_file(self.tmp_file.name, self.md5_hash, "auto") - ) - - def test_validate_file_wrong_hash(self): - """Test validation with incorrect hash.""" - wrong_hash = "deadbeef" * 8 - self.assertFalse( - file_utils.validate_file(self.tmp_file.name, wrong_hash, "sha256") - ) - - def tearDown(self): - os.remove(self.tmp_file.name) - - -class ResolveHasherTest(test_case.TestCase): - def test_resolve_hasher_sha256(self): - """Test resolving hasher for sha256 algorithm.""" - hasher = file_utils.resolve_hasher("sha256") - self.assertIsInstance(hasher, type(hashlib.sha256())) - - def test_resolve_hasher_auto_sha256(self): - """Auto-detect and resolve hasher for sha256.""" - hasher = file_utils.resolve_hasher("auto", file_hash="a" * 64) - self.assertIsInstance(hasher, type(hashlib.sha256())) - - def test_resolve_hasher_auto_md5(self): - """Auto-detect and resolve hasher for md5.""" - hasher = file_utils.resolve_hasher("auto", file_hash="a" * 32) - self.assertIsInstance(hasher, type(hashlib.md5())) - - def test_resolve_hasher_default(self): - """Resolve hasher with a random algorithm value.""" - hasher = file_utils.resolve_hasher("random_value") - self.assertIsInstance(hasher, type(hashlib.md5())) - - -class IsRemotePath(test_case.TestCase): - def test_gcs_remote_path(self): - self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) - self.assertTrue(file_utils.is_remote_path("/gcs/another/directory/")) - - def test_cns_remote_path(self): - self.assertTrue(file_utils.is_remote_path("/cns/some/path/to/file.txt")) - self.assertTrue(file_utils.is_remote_path("/cns/another/directory/")) - - def test_cfs_remote_path(self): - self.assertTrue(file_utils.is_remote_path("/cfs/some/path/to/file.txt")) - 
self.assertTrue(file_utils.is_remote_path("/cfs/another/directory/")) - - def test_http_remote_path(self): - self.assertTrue( - file_utils.is_remote_path("http://example.com/path/to/file.txt") - ) - self.assertTrue( - file_utils.is_remote_path("https://secure.example.com/directory/") - ) - self.assertTrue( - file_utils.is_remote_path("ftp://files.example.com/somefile.txt") - ) - - def test_non_remote_paths(self): - self.assertFalse(file_utils.is_remote_path("/local/path/to/file.txt")) - self.assertFalse( - file_utils.is_remote_path("C:\\local\\path\\on\\windows\\file.txt") - ) - self.assertFalse(file_utils.is_remote_path("~/relative/path/")) - self.assertFalse(file_utils.is_remote_path("./another/relative/path")) - - def test_edge_cases(self): - self.assertFalse(file_utils.is_remote_path("")) - self.assertFalse(file_utils.is_remote_path(None)) - self.assertFalse(file_utils.is_remote_path(12345)) - - class GetFileTest(test_case.TestCase): def setUp(self): """Set up temporary directories and sample files.""" @@ -643,3 +504,142 @@ def test_handle_complex_paths(self): file_utils.makedirs(complex_dir) file_utils.rmtree(complex_dir) self.assertFalse(os.path.exists(complex_dir)) + + +class HashFileTest(test_case.TestCase): + def setUp(self): + self.test_content = b"Hello, World!" + self.temp_file = tempfile.NamedTemporaryFile(delete=False) + self.temp_file.write(self.test_content) + self.temp_file.close() + + def tearDown(self): + os.remove(self.temp_file.name) + + def test_hash_file_sha256(self): + """Test SHA256 hashing of a file.""" + expected_sha256 = ( + "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" + ) + calculated_sha256 = file_utils.hash_file( + self.temp_file.name, algorithm="sha256" + ) + self.assertEqual(expected_sha256, calculated_sha256) + + def test_hash_file_md5(self): + """Test MD5 hashing of a file.""" + expected_md5 = "65a8e27d8879283831b664bd8b7f0ad4" + calculated_md5 = file_utils.hash_file( + self.temp_file.name, algorithm="md5" + ) + self.assertEqual(expected_md5, calculated_md5) + + +class TestValidateFile(test_case.TestCase): + def setUp(self): + self.tmp_file = tempfile.NamedTemporaryFile(delete=False) + self.tmp_file.write(b"Hello, World!") + self.tmp_file.close() + + self.sha256_hash = ( + "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" + ) + self.md5_hash = "65a8e27d8879283831b664bd8b7f0ad4" + + def test_validate_file_sha256(self): + """Validate SHA256 hash of a file.""" + self.assertTrue( + file_utils.validate_file( + self.tmp_file.name, self.sha256_hash, "sha256" + ) + ) + + def test_validate_file_md5(self): + """Validate MD5 hash of a file.""" + self.assertTrue( + file_utils.validate_file(self.tmp_file.name, self.md5_hash, "md5") + ) + + def test_validate_file_auto_sha256(self): + """Auto-detect and validate SHA256 hash.""" + self.assertTrue( + file_utils.validate_file( + self.tmp_file.name, self.sha256_hash, "auto" + ) + ) + + def test_validate_file_auto_md5(self): + """Auto-detect and validate MD5 hash.""" + self.assertTrue( + file_utils.validate_file(self.tmp_file.name, self.md5_hash, "auto") + ) + + def test_validate_file_wrong_hash(self): + """Test validation with incorrect hash.""" + wrong_hash = "deadbeef" * 8 + self.assertFalse( + file_utils.validate_file(self.tmp_file.name, wrong_hash, "sha256") + ) + + def tearDown(self): + os.remove(self.tmp_file.name) + + +class ResolveHasherTest(test_case.TestCase): + def test_resolve_hasher_sha256(self): + """Test resolving hasher for sha256 algorithm.""" + hasher = 
file_utils.resolve_hasher("sha256") + self.assertIsInstance(hasher, type(hashlib.sha256())) + + def test_resolve_hasher_auto_sha256(self): + """Auto-detect and resolve hasher for sha256.""" + hasher = file_utils.resolve_hasher("auto", file_hash="a" * 64) + self.assertIsInstance(hasher, type(hashlib.sha256())) + + def test_resolve_hasher_auto_md5(self): + """Auto-detect and resolve hasher for md5.""" + hasher = file_utils.resolve_hasher("auto", file_hash="a" * 32) + self.assertIsInstance(hasher, type(hashlib.md5())) + + def test_resolve_hasher_default(self): + """Resolve hasher with a random algorithm value.""" + hasher = file_utils.resolve_hasher("random_value") + self.assertIsInstance(hasher, type(hashlib.md5())) + + +class IsRemotePath(test_case.TestCase): + def test_gcs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) + self.assertTrue(file_utils.is_remote_path("/gcs/another/directory/")) + + def test_cns_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/cns/some/path/to/file.txt")) + self.assertTrue(file_utils.is_remote_path("/cns/another/directory/")) + + def test_cfs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/cfs/some/path/to/file.txt")) + self.assertTrue(file_utils.is_remote_path("/cfs/another/directory/")) + + def test_http_remote_path(self): + self.assertTrue( + file_utils.is_remote_path("http://example.com/path/to/file.txt") + ) + self.assertTrue( + file_utils.is_remote_path("https://secure.example.com/directory/") + ) + self.assertTrue( + file_utils.is_remote_path("ftp://files.example.com/somefile.txt") + ) + + def test_non_remote_paths(self): + self.assertFalse(file_utils.is_remote_path("/local/path/to/file.txt")) + self.assertFalse( + file_utils.is_remote_path("C:\\local\\path\\on\\windows\\file.txt") + ) + self.assertFalse(file_utils.is_remote_path("~/relative/path/")) + self.assertFalse(file_utils.is_remote_path("./another/relative/path")) + + def test_edge_cases(self): + self.assertFalse(file_utils.is_remote_path("")) + self.assertFalse(file_utils.is_remote_path(None)) + self.assertFalse(file_utils.is_remote_path(12345)) From 1070aa7e715d8a5b8027bf492bbb3891fad109a0 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:17:23 +0000 Subject: [PATCH 25/39] fix `is_remote_path` --- keras_core/utils/file_utils.py | 29 +++++++++++-- keras_core/utils/file_utils_test.py | 67 ++++++++++++++++++++++++++--- 2 files changed, 87 insertions(+), 9 deletions(-) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index 320b3cbf0..62933fb41 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -386,13 +386,36 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535): def is_remote_path(filepath): - """Returns `True` for paths that represent a remote GCS location.""" - # TODO: improve generality. 
-    if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)):
+    """Returns `True` for paths that represent a remote location."""
+    SUPPORTED_PROTOCOLS = ["http", "https", "ftp", "gcs", "cns", "cfs"]
+    PROTOCOL_PATTERN = re.compile(r"^(?P<protocol>\w+)://", re.IGNORECASE)
+
+    # Convert to string in case the filepath is in a different format
+    filepath_str = str(filepath).strip()
+
+    # Check for "protocol://" pattern using case-insensitive matching
+    match = PROTOCOL_PATTERN.match(filepath_str)
+    if match and match.group("protocol").lower() in SUPPORTED_PROTOCOLS:
         return True
+
+    # Check for known remote path prefixes using case-insensitive matching
+    for protocol in SUPPORTED_PROTOCOLS:
+        if filepath_str.lower().startswith(
+            "/" + protocol + "/"
+        ) or filepath_str.startswith("/" + protocol + "/"):
+            return True
+
     return False
 
 
+# def is_remote_path(filepath):
+#     """Returns `True` for paths that represent a remote GCS location."""
+#     # TODO: improve generality.
+#     if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)):
+#         return True
+#     return False
+
+
 # Below are gfile-replacement utils.
 
 
diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py
index c1346f357..eeb736be3 100644
--- a/keras_core/utils/file_utils_test.py
+++ b/keras_core/utils/file_utils_test.py
@@ -407,12 +407,6 @@ def _test_file_extraction_and_validation(
         self.assertTrue(file_utils.validate_file(path, hashval_md5))
         os.remove(path)
 
-    def test_is_remote_path(self):
-        self.assertTrue(file_utils.is_remote_path("gs://bucket/path"))
-        self.assertTrue(file_utils.is_remote_path("http://example.com/path"))
-        self.assertFalse(file_utils.is_remote_path("/local/path"))
-        self.assertFalse(file_utils.is_remote_path("./relative/path"))
-
     def test_exists(self):
         temp_dir = self.get_temp_dir()
         file_path = os.path.join(temp_dir, "test_exists.txt")
@@ -638,8 +632,69 @@ def test_non_remote_paths(self):
         )
         self.assertFalse(file_utils.is_remote_path("~/relative/path/"))
         self.assertFalse(file_utils.is_remote_path("./another/relative/path"))
+        self.assertFalse(file_utils.is_remote_path("/local/path"))
+        self.assertFalse(file_utils.is_remote_path("./relative/path"))
+        self.assertFalse(file_utils.is_remote_path("~/relative/path"))
 
     def test_edge_cases(self):
         self.assertFalse(file_utils.is_remote_path(""))
         self.assertFalse(file_utils.is_remote_path(None))
         self.assertFalse(file_utils.is_remote_path(12345))
+
+    def test_special_characters(self):
+        self.assertTrue(file_utils.is_remote_path("/gcs/some/päth"))
+        self.assertTrue(file_utils.is_remote_path("/gcs/some/path#anchor"))
+        self.assertTrue(file_utils.is_remote_path("/gcs/some/path?query=value"))
+
+
+class IsRemotePathRefactoredTests(test_case.TestCase):
+    def test_additional_protocols(self):
+        # Ensure other protocols are not identified as remote
+        self.assertFalse(file_utils.is_remote_path("mailto:user@example.com"))
+        self.assertFalse(
+            file_utils.is_remote_path(
+                "data:text/plain;charset=utf-8,Hello%20World!"
+ ) + ) + + def test_case_sensitivity(self): + # Ensure function handles different casing + self.assertTrue(file_utils.is_remote_path("/GcS/sOme/Path")) + self.assertTrue(file_utils.is_remote_path("HTTP://eXample.Com")) + self.assertTrue(file_utils.is_remote_path("hTtP://exaMple.cOm")) + + def test_whitespace_paths(self): + # Ensure function handles paths with spaces correctly + self.assertTrue(file_utils.is_remote_path(" /gcs/some/path ")) + self.assertTrue(file_utils.is_remote_path("/gcs/ some /path")) + + def test_special_characters(self): + # Ensure function handles special characters correctly + self.assertTrue(file_utils.is_remote_path("/gcs/some/päth")) + self.assertTrue(file_utils.is_remote_path("/gcs/some/path#anchor")) + self.assertTrue(file_utils.is_remote_path("/gcs/some/path?query=value")) + self.assertTrue(file_utils.is_remote_path("/gcs/some/path with spaces")) + + # def test_http_protocol(self): + # self.assertTrue(file_utils.is_remote_path("http://example.com")) + # self.assertFalse(file_utils.is_remote_path("/http/some/local/path")) + + # def test_https_protocol(self): + # self.assertTrue(file_utils.is_remote_path("https://example.com")) + # self.assertFalse(file_utils.is_remote_path("/https/some/local/path")) + + # def test_ftp_protocol(self): + # self.assertTrue(file_utils.is_remote_path("ftp://files.example.com/somefile.txt")) + # self.assertFalse(file_utils.is_remote_path("/ftp/some/local/path")) + + # def test_gcs_protocol(self): + # self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) + # self.assertFalse(file_utils.is_remote_path("gcs://bucket/some/file.txt")) + + # def test_cns_protocol(self): + # self.assertTrue(file_utils.is_remote_path("/cns/some/path/to/file.txt")) + # self.assertFalse(file_utils.is_remote_path("cns://some/directory/")) + + # def test_cfs_protocol(self): + # self.assertTrue(file_utils.is_remote_path("/cfs/some/path/to/file.txt")) + # self.assertFalse(file_utils.is_remote_path("cfs://some/directory/")) From 9012d15e1b401cd636de5d8935d1d4d76f484244 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 13:10:06 +0000 Subject: [PATCH 26/39] improve `is_remote_path` --- keras_core/utils/file_utils.py | 26 ++------ keras_core/utils/file_utils_test.py | 99 ++++++++++++++--------------- 2 files changed, 52 insertions(+), 73 deletions(-) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index 62933fb41..89424941a 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -387,35 +387,21 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535): def is_remote_path(filepath): """Returns `True` for paths that represent a remote location.""" - SUPPORTED_PROTOCOLS = ["http", "https", "ftp", "gcs", "cns", "cfs"] - PROTOCOL_PATTERN = re.compile(r"^(?P\w+)://", re.IGNORECASE) - # Convert to string in case the filepath is in a different format filepath_str = str(filepath).strip() - # Check for "protocol://" pattern using case-insensitive matching - match = PROTOCOL_PATTERN.match(filepath_str) - if match and match.group("protocol").lower() in SUPPORTED_PROTOCOLS: - return True + # Specific patterns for supported remote paths + supported_patterns = [ + re.compile(r"^(gs|cns|cfs|http|https|ftp|s3)://.*$", re.IGNORECASE) + ] - # Check for known remote path prefixes using case-insensitive matching - for protocol in SUPPORTED_PROTOCOLS: - if filepath_str.lower().startswith( - "/" + protocol + "/" - ) or 
filepath_str.startswith("/" + protocol + "/"): + for pattern in supported_patterns: + if pattern.match(filepath_str): return True return False -# def is_remote_path(filepath): -# """Returns `True` for paths that represent a remote GCS location.""" -# # TODO: improve generality. -# if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)): -# return True -# return False - - # Below are gfile-replacement utils. diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index eeb736be3..4070aea46 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -176,6 +176,7 @@ def test_invalid_path_warning(self): "'../../invalid.txt'." ) mock_warn.assert_called_with(warning_msg, stacklevel=2) + os.remove(invalid_path) class ExtractArchiveTest(test_case.TestCase): @@ -602,28 +603,48 @@ def test_resolve_hasher_default(self): class IsRemotePath(test_case.TestCase): - def test_gcs_remote_path(self): - self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) - self.assertTrue(file_utils.is_remote_path("/gcs/another/directory/")) + def test_gs_remote_path(self): + self.assertFalse(file_utils.is_remote_path("/gs/some/path/to/file.txt")) + self.assertFalse(file_utils.is_remote_path("/gs/another/directory/")) + self.assertTrue(file_utils.is_remote_path("gs://bucket/some/file.txt")) def test_cns_remote_path(self): - self.assertTrue(file_utils.is_remote_path("/cns/some/path/to/file.txt")) - self.assertTrue(file_utils.is_remote_path("/cns/another/directory/")) + self.assertFalse( + file_utils.is_remote_path("/cns/some/path/to/file.txt") + ) + self.assertFalse(file_utils.is_remote_path("/cns/another/directory/")) + self.assertTrue(file_utils.is_remote_path("cns://some/directory/")) def test_cfs_remote_path(self): - self.assertTrue(file_utils.is_remote_path("/cfs/some/path/to/file.txt")) - self.assertTrue(file_utils.is_remote_path("/cfs/another/directory/")) + self.assertFalse( + file_utils.is_remote_path("/cfs/some/path/to/file.txt") + ) + self.assertFalse(file_utils.is_remote_path("/cfs/another/directory/")) + self.assertTrue(file_utils.is_remote_path("cfs://some/directory/")) + + def test_s3_remote_path(self): + self.assertTrue(file_utils.is_remote_path("s3://bucket/some/file.txt")) + self.assertTrue( + file_utils.is_remote_path("s3://bucket/another/directory/") + ) + self.assertFalse(file_utils.is_remote_path("/s3/some/path/to/file.txt")) + self.assertFalse(file_utils.is_remote_path("/s3/another/directory/")) - def test_http_remote_path(self): + def test_http_and_https_remote_path(self): self.assertTrue( file_utils.is_remote_path("http://example.com/path/to/file.txt") ) self.assertTrue( file_utils.is_remote_path("https://secure.example.com/directory/") ) + self.assertFalse(file_utils.is_remote_path("/http/some/local/path")) + self.assertFalse(file_utils.is_remote_path("/https/some/local/path")) + + def test_ftp_remote_path(self): self.assertTrue( file_utils.is_remote_path("ftp://files.example.com/somefile.txt") ) + self.assertFalse(file_utils.is_remote_path("/ftp/some/local/path")) def test_non_remote_paths(self): self.assertFalse(file_utils.is_remote_path("/local/path/to/file.txt")) @@ -641,60 +662,32 @@ def test_edge_cases(self): self.assertFalse(file_utils.is_remote_path(None)) self.assertFalse(file_utils.is_remote_path(12345)) - def test_special_characters(self): - self.assertTrue(file_utils.is_remote_path("/gcs/some/päth")) - self.assertTrue(file_utils.is_remote_path("/gcs/some/path#anchor")) - 
self.assertTrue(file_utils.is_remote_path("/gcs/some/path?query=value")) + def test_special_characters_in_path(self): + self.assertTrue(file_utils.is_remote_path("gs://some/päth")) + self.assertTrue(file_utils.is_remote_path("gs://some/path#anchor")) + self.assertTrue(file_utils.is_remote_path("gs://some/path?query=value")) + self.assertTrue(file_utils.is_remote_path("gs://some/path with spaces")) - -class IsRemotePathRefactoredTests(test_case.TestCase): - def test_additional_protocols(self): - # Ensure other protocols are not identified as remote + def test_unsupported_protocols(self): self.assertFalse(file_utils.is_remote_path("mailto:user@example.com")) self.assertFalse( file_utils.is_remote_path( "data:text/plain;charset=utf-8,Hello%20World!" ) ) + self.assertFalse(file_utils.is_remote_path("file://local/path")) def test_case_sensitivity(self): - # Ensure function handles different casing - self.assertTrue(file_utils.is_remote_path("/GcS/sOme/Path")) + self.assertTrue(file_utils.is_remote_path("gs://sOme/Path")) self.assertTrue(file_utils.is_remote_path("HTTP://eXample.Com")) self.assertTrue(file_utils.is_remote_path("hTtP://exaMple.cOm")) - def test_whitespace_paths(self): - # Ensure function handles paths with spaces correctly - self.assertTrue(file_utils.is_remote_path(" /gcs/some/path ")) - self.assertTrue(file_utils.is_remote_path("/gcs/ some /path")) - - def test_special_characters(self): - # Ensure function handles special characters correctly - self.assertTrue(file_utils.is_remote_path("/gcs/some/päth")) - self.assertTrue(file_utils.is_remote_path("/gcs/some/path#anchor")) - self.assertTrue(file_utils.is_remote_path("/gcs/some/path?query=value")) - self.assertTrue(file_utils.is_remote_path("/gcs/some/path with spaces")) - - # def test_http_protocol(self): - # self.assertTrue(file_utils.is_remote_path("http://example.com")) - # self.assertFalse(file_utils.is_remote_path("/http/some/local/path")) - - # def test_https_protocol(self): - # self.assertTrue(file_utils.is_remote_path("https://example.com")) - # self.assertFalse(file_utils.is_remote_path("/https/some/local/path")) - - # def test_ftp_protocol(self): - # self.assertTrue(file_utils.is_remote_path("ftp://files.example.com/somefile.txt")) - # self.assertFalse(file_utils.is_remote_path("/ftp/some/local/path")) - - # def test_gcs_protocol(self): - # self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) - # self.assertFalse(file_utils.is_remote_path("gcs://bucket/some/file.txt")) - - # def test_cns_protocol(self): - # self.assertTrue(file_utils.is_remote_path("/cns/some/path/to/file.txt")) - # self.assertFalse(file_utils.is_remote_path("cns://some/directory/")) - - # def test_cfs_protocol(self): - # self.assertTrue(file_utils.is_remote_path("/cfs/some/path/to/file.txt")) - # self.assertFalse(file_utils.is_remote_path("cfs://some/directory/")) + def test_whitespace_in_paths(self): + self.assertTrue(file_utils.is_remote_path(" gs://some/path ")) + self.assertTrue(file_utils.is_remote_path("gs:// some /path")) + + def test_false_positives(self): + self.assertFalse(file_utils.is_remote_path("/httpslocal/some/path")) + self.assertFalse(file_utils.is_remote_path("/gslocal/some/path")) + self.assertFalse(file_utils.is_remote_path("/cnslocal/some/path")) + self.assertFalse(file_utils.is_remote_path("/cfslocal/some/path")) From e1eb4e4194283c4bb16268c5031dc2d7df3cbd6d Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 13:39:17 +0000 Subject: [PATCH 
27/39] Add test for `raise_if_no_gfile_raises` --- keras_core/utils/file_utils_test.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 4070aea46..b9f817081 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -602,7 +602,7 @@ def test_resolve_hasher_default(self): self.assertIsInstance(hasher, type(hashlib.md5())) -class IsRemotePath(test_case.TestCase): +class IsRemotePathTest(test_case.TestCase): def test_gs_remote_path(self): self.assertFalse(file_utils.is_remote_path("/gs/some/path/to/file.txt")) self.assertFalse(file_utils.is_remote_path("/gs/another/directory/")) @@ -691,3 +691,14 @@ def test_false_positives(self): self.assertFalse(file_utils.is_remote_path("/gslocal/some/path")) self.assertFalse(file_utils.is_remote_path("/cnslocal/some/path")) self.assertFalse(file_utils.is_remote_path("/cfslocal/some/path")) + + +class TestRaiseIfNoGFile(test_case.TestCase): + def test_raise_if_no_gfile_raises_correct_message(self): + path = "gs://bucket/some/file.txt" + expected_error_msg = ( + "Handling remote paths requires installing TensorFlow " + f".*Received path: {path}" + ) + with self.assertRaisesRegex(ValueError, expected_error_msg): + file_utils._raise_if_no_gfile(path) From 51385120f03bf7db73a06a59947f69a65fcbee17 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 14:18:18 +0000 Subject: [PATCH 28/39] Add tests for file_utils.py --- keras_core/utils/file_utils_test.py | 50 +++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index b9f817081..c418b5a26 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -178,6 +178,32 @@ def test_invalid_path_warning(self): mock_warn.assert_called_with(warning_msg, stacklevel=2) os.remove(invalid_path) + def test_symbolic_link_in_base_dir(self): + """Test a symbolic link within the base directory is correctly processed.""" + symlink_path = os.path.join(self.base_dir, "symlink.txt") + target_path = os.path.join(self.base_dir, "target.txt") + + # Create a target file and then a symbolic link pointing to it. + with open(target_path, "w") as f: + f.write("target") + os.symlink(target_path, symlink_path) + + # Add the symbolic link to the tar archive. + with tarfile.open(self.tar_path, "w") as tar: + tar.add(symlink_path, arcname="symlink.txt") + + # Open the tar archive and check if the symbolic link is correctly processed. + with tarfile.open(self.tar_path, "r") as tar: + members = list(file_utils.filter_safe_paths(tar.getmembers())) + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "symlink.txt") + self.assertTrue( + members[0].issym() + ) # Explicitly assert it's a symbolic link. 
+ + os.remove(symlink_path) + os.remove(target_path) + class ExtractArchiveTest(test_case.TestCase): def setUp(self): @@ -257,6 +283,30 @@ def test_archive_format_none(self): result = file_utils.extract_archive(archive_path, extract_path, None) self.assertFalse(result) + def test_runtime_error_during_extraction(self): + tar_path = self.create_tar() + extract_path = os.path.join(self.temp_dir, "runtime_error_extraction") + + with patch.object( + tarfile.TarFile, "extractall", side_effect=RuntimeError + ): + with self.assertRaises(RuntimeError): + file_utils.extract_archive(tar_path, extract_path, "tar") + self.assertFalse(os.path.exists(extract_path)) + + def test_keyboard_interrupt_during_extraction(self): + tar_path = self.create_tar() + extract_path = os.path.join( + self.temp_dir, "keyboard_interrupt_extraction" + ) + + with patch.object( + tarfile.TarFile, "extractall", side_effect=KeyboardInterrupt + ): + with self.assertRaises(KeyboardInterrupt): + file_utils.extract_archive(tar_path, extract_path, "tar") + self.assertFalse(os.path.exists(extract_path)) + class GetFileTest(test_case.TestCase): def setUp(self): From a56723961a1ebb07344e60a584db1a793382cd56 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 15:32:29 +0000 Subject: [PATCH 29/39] Add tests in `saving_api_test.py` --- keras_core/saving/saving_api_test.py | 55 +++++++++++++++++++++++++--- keras_core/utils/file_utils_test.py | 3 +- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index ce7c468da..1916d6bac 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -1,14 +1,15 @@ import os -import unittest +from unittest import mock import numpy as np from keras_core import layers from keras_core.models import Sequential from keras_core.saving import saving_api +from keras_core.testing import test_case -class SaveModelTests(unittest.TestCase): +class SaveModelTests(test_case.TestCase): def setUp(self): self.model = Sequential( [ @@ -21,6 +22,10 @@ def setUp(self): os.remove(self.filepath) saving_api.save_model(self.model, self.filepath) + def test_h5_deprecation_warning(self): + with self.assertWarns(UserWarning): + saving_api.save_model(self.model, "test_model.h5") + def test_basic_saving(self): loaded_model = saving_api.load_model(self.filepath) x = np.random.uniform(size=(10, 3)) @@ -44,7 +49,7 @@ def test_save_h5_format(self): filepath_h5 = "test_model.h5" saving_api.save_model(self.model, filepath_h5) self.assertTrue(os.path.exists(filepath_h5)) - os.remove(filepath_h5) # Cleanup + os.remove(filepath_h5) def test_save_unsupported_extension(self): with self.assertRaisesRegex( @@ -56,8 +61,23 @@ def tearDown(self): if os.path.exists(self.filepath): os.remove(self.filepath) - -class LoadModelTests(unittest.TestCase): + def test_h5_deprecation_warning(self): + with self.assertLogs(level="WARNING") as log: + saving_api.save_model(self.model, "test_model.h5") + expected_warning_msg = ( + "You are saving your model as an HDF5 file via `model.save()`" + ) + matched_logs = [ + msg for msg in log.output if expected_warning_msg in msg + ] + self.assertEqual( + len(matched_logs), + 1, + f"Expected warning message not found in logs: {log.output}", + ) + + +class LoadModelTests(test_case.TestCase): def setUp(self): self.model = Sequential( [ @@ -97,8 +117,21 @@ def tearDown(self): if os.path.exists(self.filepath): os.remove(self.filepath) 
+ def test_load_model_with_custom_objects(self): + class CustomLayer(layers.Layer): + def call(self, inputs): + return inputs + + model = Sequential([CustomLayer(input_shape=(3,))]) + model.save("custom_model.keras") + loaded_model = saving_api.load_model( + "custom_model.keras", custom_objects={"CustomLayer": CustomLayer} + ) + self.assertIsInstance(loaded_model.layers[0], CustomLayer) + os.remove("custom_model.keras") + -class LoadWeightsTests(unittest.TestCase): +class LoadWeightsTests(test_case.TestCase): def setUp(self): self.model = Sequential( [ @@ -120,3 +153,13 @@ def tearDown(self): filepath = "test_weights.weights.h5" if os.path.exists(filepath): os.remove(filepath) + + def test_load_h5_weights_by_name(self): + filepath = "test_weights.weights.h5" + self.model.save_weights(filepath) + with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"): + self.model.load_weights(filepath, by_name=True) + + def test_load_weights_invalid_extension(self): + with self.assertRaisesRegex(ValueError, "File format not supported"): + self.model.load_weights("invalid_extension.pkl") diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index c418b5a26..1b8d516dd 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -179,7 +179,7 @@ def test_invalid_path_warning(self): os.remove(invalid_path) def test_symbolic_link_in_base_dir(self): - """Test a symbolic link within the base directory is correctly processed.""" + """symbolic link within the base directory is correctly processed.""" symlink_path = os.path.join(self.base_dir, "symlink.txt") target_path = os.path.join(self.base_dir, "target.txt") @@ -192,7 +192,6 @@ def test_symbolic_link_in_base_dir(self): with tarfile.open(self.tar_path, "w") as tar: tar.add(symlink_path, arcname="symlink.txt") - # Open the tar archive and check if the symbolic link is correctly processed. 
with tarfile.open(self.tar_path, "r") as tar: members = list(file_utils.filter_safe_paths(tar.getmembers())) self.assertEqual(len(members), 1) From 81079abef6eac1d380024a7c55884c23e001e4b7 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 15:43:57 +0000 Subject: [PATCH 30/39] Add tests `saving_api_test.py` --- keras_core/saving/saving_api_test.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index 1916d6bac..38c275bf5 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -1,5 +1,4 @@ import os -from unittest import mock import numpy as np @@ -61,21 +60,6 @@ def tearDown(self): if os.path.exists(self.filepath): os.remove(self.filepath) - def test_h5_deprecation_warning(self): - with self.assertLogs(level="WARNING") as log: - saving_api.save_model(self.model, "test_model.h5") - expected_warning_msg = ( - "You are saving your model as an HDF5 file via `model.save()`" - ) - matched_logs = [ - msg for msg in log.output if expected_warning_msg in msg - ] - self.assertEqual( - len(matched_logs), - 1, - f"Expected warning message not found in logs: {log.output}", - ) - class LoadModelTests(test_case.TestCase): def setUp(self): From 32ec488eef1ea095165d560799605301da831aff Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:11:43 +0000 Subject: [PATCH 31/39] Add tests saving_api_test.py --- keras_core/saving/saving_api_test.py | 34 ++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index 38c275bf5..8f56046c8 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -1,6 +1,8 @@ import os +import unittest.mock as mock import numpy as np +from absl import logging from keras_core import layers from keras_core.models import Sequential @@ -21,10 +23,6 @@ def setUp(self): os.remove(self.filepath) saving_api.save_model(self.model, self.filepath) - def test_h5_deprecation_warning(self): - with self.assertWarns(UserWarning): - saving_api.save_model(self.model, "test_model.h5") - def test_basic_saving(self): loaded_model = saving_api.load_model(self.filepath) x = np.random.uniform(size=(10, 3)) @@ -147,3 +145,31 @@ def test_load_h5_weights_by_name(self): def test_load_weights_invalid_extension(self): with self.assertRaisesRegex(ValueError, "File format not supported"): self.model.load_weights("invalid_extension.pkl") + + +class SaveModelTestsWarning(test_case.TestCase): + def setUp(self): + self.model = Sequential( + [ + layers.Dense(5, input_shape=(3,)), + layers.Softmax(), + ], + ) + self.filepath = "test_model.keras" + if os.path.exists(self.filepath): + os.remove(self.filepath) + saving_api.save_model(self.model, self.filepath) + + def test_h5_deprecation_warning(self): + with mock.patch.object(logging, "warning") as mock_warn: + saving_api.save_model(self.model, "test_model.h5") + mock_warn.assert_called_once_with( + "You are saving your model as an HDF5 file via `model.save()`. " + "This file format is considered legacy. " + "We recommend using instead the native Keras format, " + "e.g. `model.save('my_model.keras')`." 
+ ) + + def tearDown(self): + if os.path.exists(self.filepath): + os.remove(self.filepath) From 874761844656a07998774faf459e8abaf3d12011 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:40:28 +0000 Subject: [PATCH 32/39] Add tests in `saving_api_test.py` --- keras_core/saving/saving_api_test.py | 33 ++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index 8f56046c8..974bf7aad 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -24,6 +24,7 @@ def setUp(self): saving_api.save_model(self.model, self.filepath) def test_basic_saving(self): + """Test basic model saving and loading.""" loaded_model = saving_api.load_model(self.filepath) x = np.random.uniform(size=(10, 3)) self.assertTrue( @@ -31,24 +32,28 @@ def test_basic_saving(self): ) def test_invalid_save_format(self): + """Test deprecated save_format argument.""" with self.assertRaisesRegex( ValueError, "The `save_format` argument is deprecated" ): saving_api.save_model(self.model, "model.txt", save_format=True) def test_unsupported_arguments(self): + """Test unsupported argument during model save.""" with self.assertRaisesRegex( ValueError, r"The following argument\(s\) are not supported" ): saving_api.save_model(self.model, self.filepath, random_arg=True) def test_save_h5_format(self): + """Test saving model in h5 format.""" filepath_h5 = "test_model.h5" saving_api.save_model(self.model, filepath_h5) self.assertTrue(os.path.exists(filepath_h5)) os.remove(filepath_h5) def test_save_unsupported_extension(self): + """Test saving model with unsupported extension.""" with self.assertRaisesRegex( ValueError, "Invalid filepath extension for saving" ): @@ -71,6 +76,7 @@ def setUp(self): saving_api.save_model(self.model, self.filepath) def test_basic_load(self): + """Test basic model loading.""" loaded_model = saving_api.load_model(self.filepath) x = np.random.uniform(size=(10, 3)) self.assertTrue( @@ -78,14 +84,17 @@ def test_basic_load(self): ) def test_load_unsupported_format(self): + """Test loading model with unsupported format.""" with self.assertRaisesRegex(ValueError, "File format not supported"): saving_api.load_model("model.pkl") def test_load_keras_not_zip(self): + """Test loading keras file that's not a zip.""" with self.assertRaisesRegex(ValueError, "File not found"): saving_api.load_model("not_a_zip.keras") def test_load_h5_format(self): + """Test loading model in h5 format.""" filepath_h5 = "test_model.h5" saving_api.save_model(self.model, filepath_h5) loaded_model = saving_api.load_model(filepath_h5) @@ -95,11 +104,9 @@ def test_load_h5_format(self): ) os.remove(filepath_h5) - def tearDown(self): - if os.path.exists(self.filepath): - os.remove(self.filepath) - def test_load_model_with_custom_objects(self): + """Test loading model with custom objects.""" + class CustomLayer(layers.Layer): def call(self, inputs): return inputs @@ -112,6 +119,10 @@ def call(self, inputs): self.assertIsInstance(loaded_model.layers[0], CustomLayer) os.remove("custom_model.keras") + def tearDown(self): + if os.path.exists(self.filepath): + os.remove(self.filepath) + class LoadWeightsTests(test_case.TestCase): def setUp(self): @@ -123,6 +134,7 @@ def setUp(self): ) def test_load_keras_weights(self): + """Test loading keras weights.""" filepath = "test_weights.weights.h5" self.model.save_weights(filepath) original_weights = 
self.model.get_weights() @@ -131,21 +143,23 @@ def test_load_keras_weights(self): for orig, loaded in zip(original_weights, loaded_weights): self.assertTrue(np.array_equal(orig, loaded)) - def tearDown(self): - filepath = "test_weights.weights.h5" - if os.path.exists(filepath): - os.remove(filepath) - def test_load_h5_weights_by_name(self): + """Test loading h5 weights by name.""" filepath = "test_weights.weights.h5" self.model.save_weights(filepath) with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"): self.model.load_weights(filepath, by_name=True) def test_load_weights_invalid_extension(self): + """Test loading weights with unsupported extension.""" with self.assertRaisesRegex(ValueError, "File format not supported"): self.model.load_weights("invalid_extension.pkl") + def tearDown(self): + filepath = "test_weights.weights.h5" + if os.path.exists(filepath): + os.remove(filepath) + class SaveModelTestsWarning(test_case.TestCase): def setUp(self): @@ -161,6 +175,7 @@ def setUp(self): saving_api.save_model(self.model, self.filepath) def test_h5_deprecation_warning(self): + """Test deprecation warning for h5 format.""" with mock.patch.object(logging, "warning") as mock_warn: saving_api.save_model(self.model, "test_model.h5") mock_warn.assert_called_once_with( From f48f07d74a47dab0fd4fb03d90577cdd46a804cf Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 17:23:30 +0000 Subject: [PATCH 33/39] Add test `test_directory_creation_on_save` --- keras_core/legacy/saving/legacy_h5_format_test.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/keras_core/legacy/saving/legacy_h5_format_test.py b/keras_core/legacy/saving/legacy_h5_format_test.py index 6d8b63493..4333b0f35 100644 --- a/keras_core/legacy/saving/legacy_h5_format_test.py +++ b/keras_core/legacy/saving/legacy_h5_format_test.py @@ -481,3 +481,18 @@ def call(self, x): # Compare output self.assertAllClose(ref_output, output, atol=1e-5) + + +@pytest.mark.requires_trainable_backend +class DirectoryCreationTest(testing.TestCase): + def test_directory_creation_on_save(self): + model = get_sequential_model(keras_core) + nested_dirpath = os.path.join( + self.get_temp_dir(), "dir1", "dir2", "dir3" + ) + filepath = os.path.join(nested_dirpath, "model.h5") + self.assertFalse(os.path.exists(nested_dirpath)) + legacy_h5_format.save_model_to_hdf5(model, filepath) + self.assertTrue(os.path.exists(nested_dirpath)) + loaded_model = legacy_h5_format.load_model_from_hdf5(filepath) + self.assertEqual(model.to_json(), loaded_model.to_json()) From 93782de47c279aed8fcddd8ac771d7318707602e Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 18 Sep 2023 17:33:31 +0000 Subject: [PATCH 34/39] Add test `legacy_h5_format_test.py` --- keras_core/legacy/saving/legacy_h5_format_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_core/legacy/saving/legacy_h5_format_test.py b/keras_core/legacy/saving/legacy_h5_format_test.py index 4333b0f35..02b50be04 100644 --- a/keras_core/legacy/saving/legacy_h5_format_test.py +++ b/keras_core/legacy/saving/legacy_h5_format_test.py @@ -486,6 +486,7 @@ def call(self, x): @pytest.mark.requires_trainable_backend class DirectoryCreationTest(testing.TestCase): def test_directory_creation_on_save(self): + """Test if directory is created on model save.""" model = get_sequential_model(keras_core) nested_dirpath = os.path.join( self.get_temp_dir(), "dir1", "dir2", "dir3" From 
aef9b2aad5978b367b390d968981a7e06933cebf Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 19 Sep 2023 08:08:30 +0000 Subject: [PATCH 35/39] Flake8 for `LambdaCallbackTest` --- keras_core/callbacks/lambda_callback_test.py | 50 +++++++++----------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/keras_core/callbacks/lambda_callback_test.py b/keras_core/callbacks/lambda_callback_test.py index e043c765d..d6b26c7ec 100644 --- a/keras_core/callbacks/lambda_callback_test.py +++ b/keras_core/callbacks/lambda_callback_test.py @@ -12,11 +12,11 @@ class LambdaCallbackTest(testing.TestCase): @pytest.mark.requires_trainable_backend - def test_LambdaCallback(self): + def test_lambda_callback(self): """Test standard LambdaCallback functionalities with training.""" - BATCH_SIZE = 4 + batch_size = 4 model = Sequential( - [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] + [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)] ) model.compile( optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() @@ -35,27 +35,23 @@ def test_LambdaCallback(self): model.fit( x, y, - batch_size=BATCH_SIZE, + batch_size=batch_size, validation_split=0.2, callbacks=[lambda_log_callback], epochs=5, verbose=0, ) - self.assertTrue - (any("on_train_begin" in log for log in logs.output)) - self.assertTrue - (any("on_epoch_begin" in log for log in logs.output)) - self.assertTrue - (any("on_epoch_end" in log for log in logs.output)) - self.assertTrue - (any("on_train_end" in log for log in logs.output)) + self.assertTrue(any("on_train_begin" in log for log in logs.output)) + self.assertTrue(any("on_epoch_begin" in log for log in logs.output)) + self.assertTrue(any("on_epoch_end" in log for log in logs.output)) + self.assertTrue(any("on_train_end" in log for log in logs.output)) @pytest.mark.requires_trainable_backend - def test_LambdaCallback_with_batches(self): + def test_lambda_callback_with_batches(self): """Test LambdaCallback's behavior with batch-level callbacks.""" - BATCH_SIZE = 4 + batch_size = 4 model = Sequential( - [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] + [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)] ) model.compile( optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() @@ -74,7 +70,7 @@ def test_LambdaCallback_with_batches(self): model.fit( x, y, - batch_size=BATCH_SIZE, + batch_size=batch_size, validation_split=0.2, callbacks=[lambda_log_callback], epochs=5, @@ -88,11 +84,11 @@ def test_LambdaCallback_with_batches(self): ) @pytest.mark.requires_trainable_backend - def test_LambdaCallback_with_kwargs(self): + def test_lambda_callback_with_kwargs(self): """Test LambdaCallback's behavior with custom defined callback.""" - BATCH_SIZE = 4 + batch_size = 4 model = Sequential( - [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] + [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)] ) model.compile( optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() @@ -100,7 +96,7 @@ def test_LambdaCallback_with_kwargs(self): x = np.random.randn(16, 2) y = np.random.randn(16, 1) model.fit( - x, y, batch_size=BATCH_SIZE, epochs=1, verbose=0 + x, y, batch_size=batch_size, epochs=1, verbose=0 ) # Train briefly for evaluation to work. 
def custom_on_test_begin(logs): @@ -113,7 +109,7 @@ def custom_on_test_begin(logs): model.evaluate( x, y, - batch_size=BATCH_SIZE, + batch_size=batch_size, callbacks=[lambda_log_callback], verbose=0, ) @@ -125,13 +121,13 @@ def custom_on_test_begin(logs): ) @pytest.mark.requires_trainable_backend - def test_LambdaCallback_no_args(self): + def test_lambda_callback_no_args(self): """Test initializing LambdaCallback without any arguments.""" lambda_callback = callbacks.LambdaCallback() self.assertIsInstance(lambda_callback, callbacks.LambdaCallback) @pytest.mark.requires_trainable_backend - def test_LambdaCallback_with_additional_kwargs(self): + def test_lambda_callback_with_additional_kwargs(self): """Test initializing LambdaCallback with non-predefined kwargs.""" def custom_callback(logs): @@ -143,11 +139,11 @@ def custom_callback(logs): self.assertTrue(hasattr(lambda_callback, "custom_method")) @pytest.mark.requires_trainable_backend - def test_LambdaCallback_during_prediction(self): + def test_lambda_callback_during_prediction(self): """Test LambdaCallback's functionality during model prediction.""" - BATCH_SIZE = 4 + batch_size = 4 model = Sequential( - [layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)] + [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)] ) model.compile( optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() @@ -162,7 +158,7 @@ def custom_on_predict_begin(logs): ) with self.assertLogs(level="WARNING") as logs: model.predict( - x, batch_size=BATCH_SIZE, callbacks=[lambda_callback], verbose=0 + x, batch_size=batch_size, callbacks=[lambda_callback], verbose=0 ) self.assertTrue( any("on_predict_begin_executed" in log for log in logs.output) From fe8bd58fd1872217fbea1386de26e20faa9cb584 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Wed, 20 Sep 2023 05:37:09 +0000 Subject: [PATCH 36/39] use `get_model` and `self.get_temp_dir` --- keras_core/saving/saving_api_test.py | 128 ++++++++++++--------------- keras_core/utils/file_utils.py | 12 +-- 2 files changed, 64 insertions(+), 76 deletions(-) diff --git a/keras_core/saving/saving_api_test.py b/keras_core/saving/saving_api_test.py index 974bf7aad..d31981a1e 100644 --- a/keras_core/saving/saving_api_test.py +++ b/keras_core/saving/saving_api_test.py @@ -11,77 +11,76 @@ class SaveModelTests(test_case.TestCase): - def setUp(self): - self.model = Sequential( + def get_model(self): + return Sequential( [ layers.Dense(5, input_shape=(3,)), layers.Softmax(), - ], + ] ) - self.filepath = "test_model.keras" - if os.path.exists(self.filepath): - os.remove(self.filepath) - saving_api.save_model(self.model, self.filepath) def test_basic_saving(self): """Test basic model saving and loading.""" - loaded_model = saving_api.load_model(self.filepath) + model = self.get_model() + filepath = os.path.join(self.get_temp_dir(), "test_model.keras") + saving_api.save_model(model, filepath) + + loaded_model = saving_api.load_model(filepath) x = np.random.uniform(size=(10, 3)) - self.assertTrue( - np.allclose(self.model.predict(x), loaded_model.predict(x)) - ) + self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) def test_invalid_save_format(self): """Test deprecated save_format argument.""" + model = self.get_model() with self.assertRaisesRegex( ValueError, "The `save_format` argument is deprecated" ): - saving_api.save_model(self.model, "model.txt", save_format=True) + saving_api.save_model(model, "model.txt", save_format=True) def 
test_unsupported_arguments(self): """Test unsupported argument during model save.""" + model = self.get_model() + filepath = os.path.join(self.get_temp_dir(), "test_model.keras") with self.assertRaisesRegex( ValueError, r"The following argument\(s\) are not supported" ): - saving_api.save_model(self.model, self.filepath, random_arg=True) + saving_api.save_model(model, filepath, random_arg=True) def test_save_h5_format(self): """Test saving model in h5 format.""" - filepath_h5 = "test_model.h5" - saving_api.save_model(self.model, filepath_h5) + model = self.get_model() + filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5") + saving_api.save_model(model, filepath_h5) self.assertTrue(os.path.exists(filepath_h5)) os.remove(filepath_h5) def test_save_unsupported_extension(self): """Test saving model with unsupported extension.""" + model = self.get_model() with self.assertRaisesRegex( ValueError, "Invalid filepath extension for saving" ): - saving_api.save_model(self.model, "model.png") - - def tearDown(self): - if os.path.exists(self.filepath): - os.remove(self.filepath) + saving_api.save_model(model, "model.png") class LoadModelTests(test_case.TestCase): - def setUp(self): - self.model = Sequential( + def get_model(self): + return Sequential( [ layers.Dense(5, input_shape=(3,)), layers.Softmax(), - ], + ] ) - self.filepath = "test_model.keras" - saving_api.save_model(self.model, self.filepath) def test_basic_load(self): """Test basic model loading.""" - loaded_model = saving_api.load_model(self.filepath) + model = self.get_model() + filepath = os.path.join(self.get_temp_dir(), "test_model.keras") + saving_api.save_model(model, filepath) + + loaded_model = saving_api.load_model(filepath) x = np.random.uniform(size=(10, 3)) - self.assertTrue( - np.allclose(self.model.predict(x), loaded_model.predict(x)) - ) + self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) def test_load_unsupported_format(self): """Test loading model with unsupported format.""" @@ -95,13 +94,12 @@ def test_load_keras_not_zip(self): def test_load_h5_format(self): """Test loading model in h5 format.""" - filepath_h5 = "test_model.h5" - saving_api.save_model(self.model, filepath_h5) + model = self.get_model() + filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5") + saving_api.save_model(model, filepath_h5) loaded_model = saving_api.load_model(filepath_h5) x = np.random.uniform(size=(10, 3)) - self.assertTrue( - np.allclose(self.model.predict(x), loaded_model.predict(x)) - ) + self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) os.remove(filepath_h5) def test_load_model_with_custom_objects(self): @@ -112,79 +110,69 @@ def call(self, inputs): return inputs model = Sequential([CustomLayer(input_shape=(3,))]) - model.save("custom_model.keras") + filepath = os.path.join(self.get_temp_dir(), "custom_model.keras") + model.save(filepath) loaded_model = saving_api.load_model( - "custom_model.keras", custom_objects={"CustomLayer": CustomLayer} + filepath, custom_objects={"CustomLayer": CustomLayer} ) self.assertIsInstance(loaded_model.layers[0], CustomLayer) - os.remove("custom_model.keras") - - def tearDown(self): - if os.path.exists(self.filepath): - os.remove(self.filepath) + os.remove(filepath) class LoadWeightsTests(test_case.TestCase): - def setUp(self): - self.model = Sequential( + def get_model(self): + return Sequential( [ layers.Dense(5, input_shape=(3,)), layers.Softmax(), - ], + ] ) def test_load_keras_weights(self): """Test loading keras weights.""" - filepath 
= "test_weights.weights.h5" - self.model.save_weights(filepath) - original_weights = self.model.get_weights() - self.model.load_weights(filepath) - loaded_weights = self.model.get_weights() + model = self.get_model() + filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5") + model.save_weights(filepath) + original_weights = model.get_weights() + model.load_weights(filepath) + loaded_weights = model.get_weights() for orig, loaded in zip(original_weights, loaded_weights): self.assertTrue(np.array_equal(orig, loaded)) def test_load_h5_weights_by_name(self): """Test loading h5 weights by name.""" - filepath = "test_weights.weights.h5" - self.model.save_weights(filepath) + model = self.get_model() + filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5") + model.save_weights(filepath) with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"): - self.model.load_weights(filepath, by_name=True) + model.load_weights(filepath, by_name=True) def test_load_weights_invalid_extension(self): """Test loading weights with unsupported extension.""" + model = self.get_model() with self.assertRaisesRegex(ValueError, "File format not supported"): - self.model.load_weights("invalid_extension.pkl") - - def tearDown(self): - filepath = "test_weights.weights.h5" - if os.path.exists(filepath): - os.remove(filepath) + model.load_weights("invalid_extension.pkl") class SaveModelTestsWarning(test_case.TestCase): - def setUp(self): - self.model = Sequential( + def get_model(self): + return Sequential( [ layers.Dense(5, input_shape=(3,)), layers.Softmax(), - ], + ] ) - self.filepath = "test_model.keras" - if os.path.exists(self.filepath): - os.remove(self.filepath) - saving_api.save_model(self.model, self.filepath) def test_h5_deprecation_warning(self): """Test deprecation warning for h5 format.""" + model = self.get_model() + filepath = os.path.join(self.get_temp_dir(), "test_model.h5") + with mock.patch.object(logging, "warning") as mock_warn: - saving_api.save_model(self.model, "test_model.h5") + saving_api.save_model(model, filepath) mock_warn.assert_called_once_with( "You are saving your model as an HDF5 file via `model.save()`. " "This file format is considered legacy. " "We recommend using instead the native Keras format, " "e.g. `model.save('my_model.keras')`." ) - - def tearDown(self): - if os.path.exists(self.filepath): - os.remove(self.filepath) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index 89424941a..99b8c9bb9 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -367,12 +367,12 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535): """Validates a file against a sha256 or md5 hash. Args: - fpath: path to the file being validated - file_hash: The expected hash string of the file. - The sha256 and md5 hash algorithms are both supported. - algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. - The default `"auto"` detects the hash algorithm in use. - chunk_size: Bytes to read at a time, important for large files. + fpath: path to the file being validated + file_hash: The expected hash string of the file. + The sha256 and md5 hash algorithms are both supported. + algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. + The default `"auto"` detects the hash algorithm in use. + chunk_size: Bytes to read at a time, important for large files. Returns: Boolean, whether the file is valid. 
From 7d755fee39d4eb75dd4c8b6fbbebe62347e4132c Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Wed, 20 Sep 2023 05:45:38 +0000 Subject: [PATCH 37/39] Fix format --- keras_core/utils/file_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index 99b8c9bb9..9b616d6f2 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -343,10 +343,10 @@ def hash_file(fpath, algorithm="sha256", chunk_size=65535): 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' Args: - fpath: Path to the file being validated. - algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. - The default `"auto"` detects the hash algorithm in use. - chunk_size: Bytes to read at a time, important for large files. + fpath: Path to the file being validated. + algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. + The default `"auto"` detects the hash algorithm in use. + chunk_size: Bytes to read at a time, important for large files. Returns: The file hash. @@ -367,12 +367,12 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535): """Validates a file against a sha256 or md5 hash. Args: - fpath: path to the file being validated - file_hash: The expected hash string of the file. - The sha256 and md5 hash algorithms are both supported. - algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. - The default `"auto"` detects the hash algorithm in use. - chunk_size: Bytes to read at a time, important for large files. + fpath: path to the file being validated + file_hash: The expected hash string of the file. + The sha256 and md5 hash algorithms are both supported. + algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. + The default `"auto"` detects the hash algorithm in use. + chunk_size: Bytes to read at a time, important for large files. Returns: Boolean, whether the file is valid. From 9b3b36ad97b96a399f664930d1b4609476b89252 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Wed, 20 Sep 2023 06:30:56 +0000 Subject: [PATCH 38/39] Improve `is_remote_path` + Add tests --- keras_core/utils/file_utils.py | 16 +++++---- keras_core/utils/file_utils_test.py | 52 +++-------------------------- 2 files changed, 13 insertions(+), 55 deletions(-) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index 9b616d6f2..0ee275a94 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -386,19 +386,21 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535): def is_remote_path(filepath): - """Returns `True` for paths that represent a remote location.""" - # Convert to string in case the filepath is in a different format filepath_str = str(filepath).strip() - - # Specific patterns for supported remote paths - supported_patterns = [ - re.compile(r"^(gs|cns|cfs|http|https|ftp|s3)://.*$", re.IGNORECASE) - ] + supported_patterns = [re.compile(r"^(gs|hdfs)://.*$", re.IGNORECASE)] for pattern in supported_patterns: if pattern.match(filepath_str): return True + # Log or print the error message without raising an exception + warning_msg = ( + f"Warning: The path '{filepath_str}' is not recognized as a " + f"supported remote path by gfile. 
Supported paths are: " + f"{', '.join(['gs://', 'hdfs://'])}" + ) + print(warning_msg) + return False diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 1b8d516dd..06327830f 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -657,43 +657,9 @@ def test_gs_remote_path(self): self.assertFalse(file_utils.is_remote_path("/gs/another/directory/")) self.assertTrue(file_utils.is_remote_path("gs://bucket/some/file.txt")) - def test_cns_remote_path(self): - self.assertFalse( - file_utils.is_remote_path("/cns/some/path/to/file.txt") - ) - self.assertFalse(file_utils.is_remote_path("/cns/another/directory/")) - self.assertTrue(file_utils.is_remote_path("cns://some/directory/")) - - def test_cfs_remote_path(self): - self.assertFalse( - file_utils.is_remote_path("/cfs/some/path/to/file.txt") - ) - self.assertFalse(file_utils.is_remote_path("/cfs/another/directory/")) - self.assertTrue(file_utils.is_remote_path("cfs://some/directory/")) - - def test_s3_remote_path(self): - self.assertTrue(file_utils.is_remote_path("s3://bucket/some/file.txt")) - self.assertTrue( - file_utils.is_remote_path("s3://bucket/another/directory/") - ) - self.assertFalse(file_utils.is_remote_path("/s3/some/path/to/file.txt")) - self.assertFalse(file_utils.is_remote_path("/s3/another/directory/")) - - def test_http_and_https_remote_path(self): - self.assertTrue( - file_utils.is_remote_path("http://example.com/path/to/file.txt") - ) - self.assertTrue( - file_utils.is_remote_path("https://secure.example.com/directory/") - ) - self.assertFalse(file_utils.is_remote_path("/http/some/local/path")) - self.assertFalse(file_utils.is_remote_path("/https/some/local/path")) - - def test_ftp_remote_path(self): - self.assertTrue( - file_utils.is_remote_path("ftp://files.example.com/somefile.txt") - ) - self.assertFalse(file_utils.is_remote_path("/ftp/some/local/path")) + def test_hdfs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("hdfs://some/path/on/hdfs")) + self.assertFalse(file_utils.is_remote_path("/hdfs/some/local/path")) def test_non_remote_paths(self): self.assertFalse(file_utils.is_remote_path("/local/path/to/file.txt")) @@ -717,19 +683,9 @@ def test_special_characters_in_path(self): self.assertTrue(file_utils.is_remote_path("gs://some/path?query=value")) self.assertTrue(file_utils.is_remote_path("gs://some/path with spaces")) - def test_unsupported_protocols(self): - self.assertFalse(file_utils.is_remote_path("mailto:user@example.com")) - self.assertFalse( - file_utils.is_remote_path( - "data:text/plain;charset=utf-8,Hello%20World!" 
- ) - ) - self.assertFalse(file_utils.is_remote_path("file://local/path")) - def test_case_sensitivity(self): self.assertTrue(file_utils.is_remote_path("gs://sOme/Path")) - self.assertTrue(file_utils.is_remote_path("HTTP://eXample.Com")) - self.assertTrue(file_utils.is_remote_path("hTtP://exaMple.cOm")) + self.assertTrue(file_utils.is_remote_path("hdfS://some/Path")) def test_whitespace_in_paths(self): self.assertTrue(file_utils.is_remote_path(" gs://some/path ")) From 021068b90774078eb080c7727190d9cb1e3f0375 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Wed, 20 Sep 2023 20:01:08 +0000 Subject: [PATCH 39/39] Fix `is_remote_path` --- keras_core/utils/file_utils.py | 23 ++++++++-------- keras_core/utils/file_utils_test.py | 41 ++++++++--------------------- 2 files changed, 22 insertions(+), 42 deletions(-) diff --git a/keras_core/utils/file_utils.py b/keras_core/utils/file_utils.py index 0ee275a94..0c9be9eda 100644 --- a/keras_core/utils/file_utils.py +++ b/keras_core/utils/file_utils.py @@ -386,21 +386,20 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535): def is_remote_path(filepath): - filepath_str = str(filepath).strip() - supported_patterns = [re.compile(r"^(gs|hdfs)://.*$", re.IGNORECASE)] + """ + Determines if a given filepath indicates a remote location. - for pattern in supported_patterns: - if pattern.match(filepath_str): - return True + This function checks if the filepath represents a known remote pattern + such as GCS (`/gcs`), CNS (`/cns`), CFS (`/cfs`), HDFS (`/hdfs`) - # Log or print the error message without raising an exception - warning_msg = ( - f"Warning: The path '{filepath_str}' is not recognized as a " - f"supported remote path by gfile. Supported paths are: " - f"{', '.join(['gs://', 'hdfs://'])}" - ) - print(warning_msg) + Args: + filepath (str): The path to be checked. 
+ Returns: + bool: True if the filepath is a recognized remote path, otherwise False + """ + if re.match(r"^(/cns|/cfs|/gcs|/hdfs|.*://).*$", str(filepath)): + return True return False diff --git a/keras_core/utils/file_utils_test.py b/keras_core/utils/file_utils_test.py index 06327830f..3cb95da0a 100644 --- a/keras_core/utils/file_utils_test.py +++ b/keras_core/utils/file_utils_test.py @@ -652,14 +652,20 @@ def test_resolve_hasher_default(self): class IsRemotePathTest(test_case.TestCase): - def test_gs_remote_path(self): - self.assertFalse(file_utils.is_remote_path("/gs/some/path/to/file.txt")) - self.assertFalse(file_utils.is_remote_path("/gs/another/directory/")) - self.assertTrue(file_utils.is_remote_path("gs://bucket/some/file.txt")) + def test_gcs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) + self.assertTrue(file_utils.is_remote_path("/gcs/another/directory/")) + self.assertTrue(file_utils.is_remote_path("gcs://bucket/some/file.txt")) def test_hdfs_remote_path(self): self.assertTrue(file_utils.is_remote_path("hdfs://some/path/on/hdfs")) - self.assertFalse(file_utils.is_remote_path("/hdfs/some/local/path")) + self.assertTrue(file_utils.is_remote_path("/hdfs/some/local/path")) + + def test_cns_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/cns/some/path")) + + def test_cfs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/cfs/some/path")) def test_non_remote_paths(self): self.assertFalse(file_utils.is_remote_path("/local/path/to/file.txt")) @@ -672,31 +678,6 @@ def test_non_remote_paths(self): self.assertFalse(file_utils.is_remote_path("./relative/path")) self.assertFalse(file_utils.is_remote_path("~/relative/path")) - def test_edge_cases(self): - self.assertFalse(file_utils.is_remote_path("")) - self.assertFalse(file_utils.is_remote_path(None)) - self.assertFalse(file_utils.is_remote_path(12345)) - - def test_special_characters_in_path(self): - self.assertTrue(file_utils.is_remote_path("gs://some/päth")) - self.assertTrue(file_utils.is_remote_path("gs://some/path#anchor")) - self.assertTrue(file_utils.is_remote_path("gs://some/path?query=value")) - self.assertTrue(file_utils.is_remote_path("gs://some/path with spaces")) - - def test_case_sensitivity(self): - self.assertTrue(file_utils.is_remote_path("gs://sOme/Path")) - self.assertTrue(file_utils.is_remote_path("hdfS://some/Path")) - - def test_whitespace_in_paths(self): - self.assertTrue(file_utils.is_remote_path(" gs://some/path ")) - self.assertTrue(file_utils.is_remote_path("gs:// some /path")) - - def test_false_positives(self): - self.assertFalse(file_utils.is_remote_path("/httpslocal/some/path")) - self.assertFalse(file_utils.is_remote_path("/gslocal/some/path")) - self.assertFalse(file_utils.is_remote_path("/cnslocal/some/path")) - self.assertFalse(file_utils.is_remote_path("/cfslocal/some/path")) - class TestRaiseIfNoGFile(test_case.TestCase): def test_raise_if_no_gfile_raises_correct_message(self):
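For reference, a standalone sketch of the classification rule that the final `is_remote_path` patch above settles on: a path counts as remote if it starts with `/cns`, `/cfs`, `/gcs`, or `/hdfs`, or if it carries a URI scheme (`://`). The regex is copied from the diff; the tiny driver and the example paths are illustrative only.

import re


def is_remote_path(filepath):
    # Same pattern as the patched keras_core.utils.file_utils.is_remote_path.
    return bool(re.match(r"^(/cns|/cfs|/gcs|/hdfs|.*://).*$", str(filepath)))


if __name__ == "__main__":
    for path in [
        "/gcs/bucket/file.txt",   # remote: gfile-style mount prefix
        "hdfs://namenode/data",   # remote: URI scheme
        "gs://bucket/file.txt",   # remote: URI scheme
        "/local/path/file.txt",   # local: no prefix, no scheme
        "C:\\local\\file.txt",    # local: Windows path, no scheme
    ]:
        print(path, "->", is_remote_path(path))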