From b600c6d9d531482637f90ec802403252fa6252fa Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Sun, 29 Sep 2024 22:08:38 +0800
Subject: [PATCH] Allow Variable as loss weights

---
 keras/src/trainers/compile_utils.py      | 11 +++++++----
 keras/src/trainers/compile_utils_test.py | 17 +++++++++++++++++
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py
index 410a782dbcd..3b83209369a 100644
--- a/keras/src/trainers/compile_utils.py
+++ b/keras/src/trainers/compile_utils.py
@@ -1,3 +1,4 @@
+from keras.src import backend
 from keras.src import losses as losses_module
 from keras.src import metrics as metrics_module
 from keras.src import ops
@@ -413,8 +414,8 @@ def __init__(
         reduction="sum_over_batch_size",
         output_names=None,
     ):
-        if loss_weights and not isinstance(
-            loss_weights, (list, tuple, dict, float)
+        if loss_weights is not None and not isinstance(
+            loss_weights, (list, tuple, dict, float, backend.Variable)
         ):
             raise ValueError(
                 "Expected `loss_weights` argument to be a float "
@@ -517,12 +518,14 @@ def build(self, y_true, y_pred):
         else:
             flat_loss_weights = tree.flatten(loss_weights)
             for loss_weight in flat_loss_weights:
-                if not isinstance(loss_weight, (int, float, type(None))):
+                if not isinstance(
+                    loss_weight, (int, float, type(None), backend.Variable)
+                ):
                     raise TypeError(
                         "When providing the `loss_weights` argument, each "
                         "element should be a Python int, float (the weighting "
                         "coefficient corresponding to the loss for that "
-                        "output) or `None`."
+                        "output), `None` or `Variable`."
                         f"Received: loss_weights={loss_weights}"
                     )
             if len(flat_loss_weights) != len(flat_losses):
diff --git a/keras/src/trainers/compile_utils_test.py b/keras/src/trainers/compile_utils_test.py
index cf0dd8aeab6..53ed3f8e0dd 100644
--- a/keras/src/trainers/compile_utils_test.py
+++ b/keras/src/trainers/compile_utils_test.py
@@ -347,3 +347,20 @@ def test_list_loss_dict_data(self):
         }
         value = compile_loss(y_true, y_pred)
         self.assertAllClose(value, 1.07666, atol=1e-5)
+
+    def test_variable_as_loss_weights(self):
+        loss_weights = backend.Variable(initializer="zeros", shape=())
+        compile_loss = CompileLoss(loss="mse", loss_weights=loss_weights)
+        y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
+        y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])
+        compile_loss.build(y_true, y_pred)
+        # `loss_weights` is set to `0.0`.
+        value = compile_loss(y_true, y_pred)
+        self.assertAllClose(value, 0.0, atol=1e-5)
+        # Test changing the `loss_weights` value
+        loss_weights.assign(1.0)
+        value = compile_loss(y_true, y_pred)
+        self.assertAllClose(value, 0.068333, atol=1e-5)
+        loss_weights.assign(1.5)
+        value = compile_loss(y_true, y_pred)
+        self.assertAllClose(value, 0.068333 * 1.5, atol=1e-5)
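
With this change, a scalar `keras.Variable` can be passed as a loss weight and reassigned while training runs, instead of being baked in as a constant at compile time. Below is a minimal usage sketch from the public API, assuming a single-output model and the TensorFlow backend; the model, the callback name `LossWeightScheduler`, and the ramp schedule are illustrative only and are not part of this patch.

```python
import numpy as np
import keras

# Scalar loss weight held in a Variable (not trainable, not owned by the model).
loss_weight_var = keras.Variable(0.0, trainable=False)

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", loss_weights=loss_weight_var)


class LossWeightScheduler(keras.callbacks.Callback):
    """Hypothetical callback that ramps the loss weight from 0 to 1."""

    def on_epoch_begin(self, epoch, logs=None):
        # `assign` updates the weight in place; the compiled loss reads the
        # current value on every call rather than a frozen float.
        loss_weight_var.assign(min(1.0, epoch / 5))


x = np.random.random((32, 4)).astype("float32")
y = np.random.random((32, 1)).astype("float32")
model.fit(x, y, epochs=10, callbacks=[LossWeightScheduler()], verbose=0)
```

Whether an assignment made from a callback is visible inside a jit-compiled train step may vary by backend; the sketch above assumes eager-friendly variable reads as with `tf.Variable`.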