Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jul 6, 2023
1 parent 8f1b43e commit eb1bb72
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 10 deletions.
12 changes: 3 additions & 9 deletions keras_core/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,15 +609,9 @@ def maybe_convert(x):
# Get signature default value
training = call_spec.arguments_dict.get("training", None)
call_context.training = training
if self._call_has_training_arg():
if training is not None:
# Only populate arg if it has a concrete value
kwargs["training"] = training
elif "training" in kwargs:
# In some cases the value None may have been passed explicitly
# (e.g. by a parent Sequential).
# That's invalid, so don't propagate it.
kwargs.pop("training")
if self._call_has_training_arg() and training is not None:
# Only populate arg if it has a concrete value
kwargs["training"] = training

##############################
# 6. Populate mask argument(s)
Expand Down
1 change: 1 addition & 0 deletions keras_core/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ def call(*args, **kwargs):
if (
hasattr(operation, "_call_has_training_arg")
and operation._call_has_training_arg()
and training is not None
):
kwargs["training"] = training
return operation(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion keras_core/models/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def call(self, inputs, training=None, mask=None):
kwargs = {}
if layer._call_has_mask_arg():
kwargs["mask"] = mask
if layer._call_has_training_arg():
if layer._call_has_training_arg() and training is not None:
kwargs["training"] = training
outputs = layer(inputs, **kwargs)
inputs = outputs
Expand Down

0 comments on commit eb1bb72

Please sign in to comment.