[BUG] ModelCheckpoint Callback results in an error due to unsupported overwrite=True parameter #1242
Comments
The same problem happens to me on the latest Merlin TensorFlow container (23.08). I believe the solution is to create a new class MerlinCheckPoint that inherits from ModelCheckpoint and overrides the …
I implemented a custom callback to handle this issue. However, I just discovered that the …
What is the status of this issue? Are there any updates to the code?
@CarloNicolini I implemented a custom callback in which I call the …
Agreed, there are many issues that have had no response for weeks. I also never managed to make the … By the way, could you please share your custom callback?

```python
import logging
import os

import tensorflow as tf
from keras.callbacks import ModelCheckpoint
from keras.utils import io_utils, tf_utils

logger = logging.getLogger(__name__)


class MerlinCheckpoint(ModelCheckpoint):
    """Custom override of ModelCheckpoint._save_model for a Merlin BaseModel.

    Merlin's BaseModel overrides `save()` and does not accept the
    `overwrite=True` argument that the stock Keras callback passes, so the
    save calls here omit it. This version also omits `options`.
    """

    def _save_without_overwrite(self, filepath):
        # Same as the stock callback, minus `overwrite=True` and `options`.
        if self.save_weights_only:
            self.model.save_weights(filepath)
        else:
            self.model.save(filepath)

    def _save_model(self, epoch, batch, logs):
        """Saves the model.

        Args:
            epoch: the epoch this iteration is in.
            batch: the batch this iteration is in. `None` if the `save_freq`
                is set to `epoch`.
            logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
        """
        logs = logs or {}
        if (
            isinstance(self.save_freq, int)
            or self.epochs_since_last_save >= self.period
        ):
            # Block only when the saving interval is reached.
            logs = tf_utils.sync_to_numpy_or_python_type(logs)
            self.epochs_since_last_save = 0
            filepath = self._get_file_path(epoch, batch, logs)

            # Create the host directory if it doesn't exist.
            dirname = os.path.dirname(filepath)
            if dirname and not tf.io.gfile.exists(dirname):
                tf.io.gfile.makedirs(dirname)

            try:
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current is None:
                        logger.warning(
                            "Can save best model only with %s available, skipping.",
                            self.monitor,
                        )
                    elif self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            io_utils.print_msg(
                                f"\nEpoch {epoch + 1}: {self.monitor} improved "
                                f"from {self.best:.5f} to {current:.5f}, "
                                f"saving model to {filepath}"
                            )
                        self.best = current
                        self._save_without_overwrite(filepath)
                    elif self.verbose > 0:
                        io_utils.print_msg(
                            f"\nEpoch {epoch + 1}: {self.monitor} did not improve "
                            f"from {self.best:.5f}"
                        )
                else:
                    if self.verbose > 0:
                        io_utils.print_msg(
                            f"\nEpoch {epoch + 1}: saving model to {filepath}"
                        )
                    self._save_without_overwrite(filepath)

                self._maybe_remove_file()
            except IsADirectoryError as e:  # h5py 3.x
                msg = (
                    "Please specify a non-directory filepath for "
                    "ModelCheckpoint. Filepath used is an existing "
                    f"directory: {filepath}"
                )
                raise OSError(msg) from e
            except OSError as e:  # h5py 2.x raises IOError, an alias of OSError
                # `e.errno` appears to be `None`, so check the content of `e.args[0]`.
                if e.args and "is a directory" in str(e.args[0]).lower():
                    msg = (
                        "Please specify a non-directory filepath for "
                        "ModelCheckpoint. Filepath used is an existing "
                        f"directory: {filepath}"
                    )
                    raise OSError(msg) from e
                # Re-raise the error for any other cause.
                raise
```
@CarloNicolini Sorry, I can't share the code because I developed it at work and it is not my property. However, I asked ChatGPT 3.5 for a basic TensorFlow model-saving template, which I modified slightly for my own use case. Compared to the official …
@CarloNicolini @hkristof03 Can you please try our nightly container? Beyond that, we always need a small, minimal reproducible example to reproduce errors. You can generate some small fake data and share your entire code so we can reproduce the issue. Thanks.
Bug description
The `ModelCheckpoint` callback calls the model's `save` method, which Merlin's `BaseModel` overrides. The callback passes `overwrite=True` to `save`, but `BaseModel.save` does not support that parameter, so an error is raised.
Steps/Code to reproduce bug
Pass a `ModelCheckpoint` callback to the `model.fit()` method.
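The failure in this report comes down to a signature mismatch, which can be illustrated without Merlin or TensorFlow. The sketch below uses hypothetical stand-ins (`FakeMerlinModel`, `KerasStyleCheckpoint`, `PatchedCheckpoint`, `run_one_epoch`) that only mimic the relevant call signatures; they are not the real Keras or Merlin classes, and this is not a substitute for the Merlin-based reproducer the maintainers asked for.

```python
# Framework-free sketch of the failure mode. All classes here are
# hypothetical stand-ins that only mimic the relevant call signatures.


class FakeMerlinModel:
    """Stand-in for a model whose save() does not accept `overwrite`."""

    def __init__(self):
        self.saved_paths = []

    def save(self, export_path):
        # No `overwrite` parameter, so passing it raises TypeError.
        self.saved_paths.append(export_path)


class KerasStyleCheckpoint:
    """Stand-in for a callback that always forwards overwrite=True."""

    def __init__(self, filepath):
        self.filepath = filepath
        self.model = None

    def set_model(self, model):
        self.model = model

    def _save(self):
        self.model.save(self.filepath, overwrite=True)

    def on_epoch_end(self, epoch, logs=None):
        self._save()


class PatchedCheckpoint(KerasStyleCheckpoint):
    """Overrides only the save step so `overwrite` is no longer forwarded."""

    def _save(self):
        self.model.save(self.filepath)


def run_one_epoch(callback_cls):
    model = FakeMerlinModel()
    cb = callback_cls("checkpoints/epoch_1")
    cb.set_model(model)
    cb.on_epoch_end(epoch=0)
    return model.saved_paths


try:
    run_one_epoch(KerasStyleCheckpoint)
except TypeError as err:
    print(f"KerasStyleCheckpoint failed: {err}")

print(run_one_epoch(PatchedCheckpoint))  # ['checkpoints/epoch_1']
```

Overriding just the save step in a subclass is the same idea as the MerlinCheckpoint workaround discussed in this thread, reduced to the single call that matters.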