
Commit

fix sequential serialization error
haohuanw committed Jul 28, 2024
1 parent d0a17fa commit 68c2565
Showing 4 changed files with 11 additions and 8 deletions.
11 changes: 5 additions & 6 deletions keras/src/backend/torch/layer.py
@@ -1,4 +1,3 @@
-import warnings
from typing import Iterator
from typing import Tuple

@@ -25,11 +24,6 @@ class TorchLayer(torch.nn.Module):
1. Populate all sublayers' torch params by calling _track_torch_params()
2. Create a single torch.nn.ParameterList() parameter with trainable,
non-trainable and seed generator states belonging to the current layer.
-Since keras also allows untracking / tracking objects post build, e.g.
-Dense.enable_lora(), Dense.quantization(), _untrack_torch_params() is added
-to refresh the parameters exposed to the torch module. A re-population
-is triggered every time Layer._track_variable() and
-Layer._untrack_variable() are called.
A few additional points that users should be aware of:
1. When torch backend is enabled, KerasVariable.value is torch.nn.Parameter,
@@ -56,6 +50,11 @@ class TorchLayer(torch.nn.Module):
corresponding parameter in torch_params from a keras variable:
parameters = [(pname, p) for pname, p in layer.named_parameters() \
if id(p) == id(variable.value)]
+7. Non-trainable variables like mean and var in BatchNormalization are
+registered in torch_params as parameters instead of buffers. This is not
+torch best practice, but there is no way to track them as buffers, since
+keras doesn't distinguish a variable that holds running statistics from
+one that merely has its gradient skipped.
"""

def _track_torch_params(self):
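Points 6 and 7 of this docstring can be seen in action with a short sketch. This is a minimal illustration, assuming the torch backend is active (KERAS_BACKEND=torch); the layer choice and shapes are arbitrary.

import numpy as np
import keras

# Sketch of the lookup described in point 6: find the entry in
# torch_params that backs a given keras variable.
layer = keras.layers.BatchNormalization()
x = np.ones((2, 4), dtype="float32")
layer(x)  # builds the layer so its torch params are populated

variable = layer.moving_mean  # a non-trainable stats variable
parameters = [
    (pname, p)
    for pname, p in layer.named_parameters()
    if id(p) == id(variable.value)
]
print(parameters)  # the matching torch parameter and its qualified name

# Per point 7, moving statistics are registered as parameters, not buffers.
print(list(layer.named_buffers()))  # expected: []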
1 change: 1 addition & 0 deletions keras/src/backend/torch/trainer.py
@@ -427,6 +427,7 @@ def append_to_outputs(batch_outputs, outputs):
# should be refactored to not require that _compile_metrics and optimizer
# are defined.
self._compile_metrics = None
+self._compile_loss = None
self.optimizer = None
self._symbolic_build(iterator=epoch_iterator)

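For context, the guarded code path is predict() on a never-compiled model: _symbolic_build() is reached with no compiled loss, so _compile_loss now gets the same None default as _compile_metrics. A minimal repro sketch, assuming the torch backend; the model and data are arbitrary:

import numpy as np
import keras

# Sketch of the path this change guards: predicting with a model that was
# never compiled, so no loss or metrics exist when _symbolic_build() runs.
model = keras.Sequential([keras.layers.Dense(2)])
x = np.random.rand(4, 3).astype("float32")
preds = model.predict(x)  # no compile() call beforehand
print(preds.shape)  # (4, 2)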
3 changes: 2 additions & 1 deletion keras/src/models/sequential.py
@@ -115,7 +115,8 @@ def add(self, layer, rebuild=True):
f"add a different Input layer to it."
)

-self._layers.append(layer)
+# append() does not trigger __setattr__, which is needed for tracking.
+self._layers = self._layers + [layer]
if rebuild:
self._maybe_rebuild()
else:
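The reason the reassignment matters: Keras tracks sublayers through attribute assignment, and an in-place list.append never reaches __setattr__. A toy illustration (not Keras code) of that Python behavior:

# Attribute tracking hooks live in __setattr__, which in-place list
# mutation never calls; reassigning the attribute re-triggers them.
class Tracker:
    def __setattr__(self, name, value):
        print(f"tracking {name!r}")
        super().__setattr__(name, value)

t = Tracker()
t.items = []               # prints: tracking 'items'
t.items.append("a")        # silent: __setattr__ is not invoked
t.items = t.items + ["b"]  # prints: tracking 'items'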
4 changes: 3 additions & 1 deletion keras/src/saving/saving_lib_test.py
@@ -554,7 +554,9 @@ def test_partial_load(self):
]
)
new_layer_kernel_value = np.array(new_model.layers[1].kernel)
-with self.assertRaisesRegex(ValueError, "must match"):
+with self.assertRaisesRegex(
+    ValueError, "A total of 1 objects could not be loaded"
+):
# Doesn't work by default
new_model.load_weights(temp_filepath)
# Now it works
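The updated assertion matches the consolidated error Keras raises when some saved objects cannot be mapped onto the new model. A hedged sketch of the scenario the test exercises; the layer sizes and file path here are illustrative, not taken from the test:

import keras

# Save weights from one model, then load into a model whose second layer
# no longer matches the saved weights.
model = keras.Sequential([keras.layers.Dense(4), keras.layers.Dense(4)])
model.build((None, 3))
model.save_weights("model.weights.h5")

new_model = keras.Sequential([keras.layers.Dense(4), keras.layers.Dense(5)])
new_model.build((None, 3))

try:
    new_model.load_weights("model.weights.h5")  # doesn't work by default
except ValueError as e:
    print(e)  # e.g. "A total of N objects could not be loaded"

# Skipping mismatched objects loads the weights that do line up.
new_model.load_weights("model.weights.h5", skip_mismatch=True)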
