Skip to content

Commit

Permalink
Allow Variable as loss weights
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Sep 29, 2024
1 parent e80fe5b commit b600c6d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
11 changes: 7 additions & 4 deletions keras/src/trainers/compile_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from keras.src import backend
from keras.src import losses as losses_module
from keras.src import metrics as metrics_module
from keras.src import ops
Expand Down Expand Up @@ -413,8 +414,8 @@ def __init__(
reduction="sum_over_batch_size",
output_names=None,
):
if loss_weights and not isinstance(
loss_weights, (list, tuple, dict, float)
if loss_weights is not None and not isinstance(
loss_weights, (list, tuple, dict, float, backend.Variable)
):
raise ValueError(
"Expected `loss_weights` argument to be a float "
Expand Down Expand Up @@ -517,12 +518,14 @@ def build(self, y_true, y_pred):
else:
flat_loss_weights = tree.flatten(loss_weights)
for loss_weight in flat_loss_weights:
if not isinstance(loss_weight, (int, float, type(None))):
if not isinstance(
loss_weight, (int, float, type(None), backend.Variable)
):
raise TypeError(
"When providing the `loss_weights` argument, each "
"element should be a Python int, float (the weighting "
"coefficient corresponding to the loss for that "
"output) or `None`."
"output), `None` or `Variable`."
f"Received: loss_weights={loss_weights}"
)
if len(flat_loss_weights) != len(flat_losses):
Expand Down
17 changes: 17 additions & 0 deletions keras/src/trainers/compile_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,20 @@ def test_list_loss_dict_data(self):
}
value = compile_loss(y_true, y_pred)
self.assertAllClose(value, 1.07666, atol=1e-5)

def test_variable_as_loss_weights(self):
loss_weights = backend.Variable(initializer="zeros", shape=())
compile_loss = CompileLoss(loss="mse", loss_weights=loss_weights)
y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])
compile_loss.build(y_true, y_pred)
# `loss_weights` is set to `0.0`.
value = compile_loss(y_true, y_pred)
self.assertAllClose(value, 0.0, atol=1e-5)
# Test changing the `loss_weights` value
loss_weights.assign(1.0)
value = compile_loss(y_true, y_pred)
self.assertAllClose(value, 0.068333, atol=1e-5)
loss_weights.assign(1.5)
value = compile_loss(y_true, y_pred)
self.assertAllClose(value, 0.068333 * 1.5, atol=1e-5)

0 comments on commit b600c6d

Please sign in to comment.