diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index 0a6ad7967..164a0440b 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -609,15 +609,9 @@ def maybe_convert(x): # Get signature default value training = call_spec.arguments_dict.get("training", None) call_context.training = training - if self._call_has_training_arg(): - if training is not None: - # Only populate arg if it has a concrete value - kwargs["training"] = training - elif "training" in kwargs: - # In some cases the value None may have been passed explicitly - # (e.g. by a parent Sequential). - # That's invalid, so don't propagate it. - kwargs.pop("training") + if self._call_has_training_arg() and training is not None: + # Only populate arg if it has a concrete value + kwargs["training"] = training ############################## # 6. Populate mask argument(s) diff --git a/keras_core/models/functional.py b/keras_core/models/functional.py index 54de5c78c..ea7949647 100644 --- a/keras_core/models/functional.py +++ b/keras_core/models/functional.py @@ -543,6 +543,7 @@ def call(*args, **kwargs): if ( hasattr(operation, "_call_has_training_arg") and operation._call_has_training_arg() + and training is not None ): kwargs["training"] = training return operation(*args, **kwargs) diff --git a/keras_core/models/sequential.py b/keras_core/models/sequential.py index b6e01cd08..580e3c479 100644 --- a/keras_core/models/sequential.py +++ b/keras_core/models/sequential.py @@ -179,7 +179,7 @@ def call(self, inputs, training=None, mask=None): kwargs = {} if layer._call_has_mask_arg(): kwargs["mask"] = mask - if layer._call_has_training_arg(): + if layer._call_has_training_arg() and training is not None: kwargs["training"] = training outputs = layer(inputs, **kwargs) inputs = outputs