
Allow keras.Variable for loss weights #20306

Conversation

james77777778
Contributor

Fix #20294

With this PR, we can use keras.Variable for loss weights:

import numpy as np

import keras
from keras import layers
from keras import models

x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))

inputs = layers.Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(inputs)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(inputs)
model = models.Model(inputs, [output_a, output_b])

loss_weights = [
    keras.Variable(initializer="ones", shape=()),
    keras.Variable(initializer="ones", shape=()),
]
model.compile(
    optimizer="sgd",
    loss=["mean_squared_error", "binary_crossentropy"],
    loss_weights=loss_weights,
)

# Down-weight output_a and keep output_b at full weight
loss_weights[0].assign(0.00001)
loss_weights[1].assign(1.0)
model.fit(x, (y1, y2), batch_size=2, epochs=1)

# Swap the weighting: output_a at full weight, output_b down-weighted
loss_weights[0].assign(1.0)
loss_weights[1].assign(0.00001)
model.fit(x, (y1, y2), batch_size=2, epochs=1)
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.7180 - output_a_loss: 6.9707e-06 - output_b_loss: 0.7180  
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.6000 - output_a_loss: 0.5999 - output_b_loss: 6.9904e-06

We can see that the reported per-output losses track the loss weights: whichever output is weighted by 1e-5 contributes a loss of about 7e-06, while the output weighted by 1.0 dominates the total loss.

@codecov-commenter

codecov-commenter commented Sep 29, 2024

Codecov Report

Attention: Patch coverage is 33.33333% with 2 lines in your changes missing coverage. Please review.

Project coverage is 78.83%. Comparing base (e80fe5b) to head (b600c6d).

Files with missing lines              Patch %   Lines
keras/src/trainers/compile_utils.py   33.33%    0 Missing and 2 partials ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #20306   +/-   ##
=======================================
  Coverage   78.83%   78.83%           
=======================================
  Files         511      511           
  Lines       48989    48990    +1     
  Branches     9022     9022           
=======================================
+ Hits        38621    38622    +1     
  Misses       8505     8505           
  Partials     1863     1863           
Flag              Coverage Δ
keras             78.69% <33.33%> (+<0.01%) ⬆️
keras-jax         62.29% <33.33%> (+<0.01%) ⬆️
keras-numpy       57.44% <33.33%> (+0.01%) ⬆️
keras-tensorflow  63.57% <33.33%> (+<0.01%) ⬆️
keras-torch       62.28% <33.33%> (+<0.01%) ⬆️


@fchollet
Member

Thanks for the PR! Can you add a unit test similar to the code snippet you posted? One thing I'd like to guard against is making sure that the arrays attached to those variables can cross the compilation boundary and can still be modified afterwards.

@james77777778
Contributor Author

Thanks for pointing that out. It seems the following snippet works fine with TensorFlow and Torch but fails with JAX:

import numpy as np

import keras
from keras import callbacks
from keras import layers
from keras import models

x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))

inputs = layers.Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(inputs)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(inputs)
model = models.Model(inputs, [output_a, output_b])


class LossWeightsAdjuster(callbacks.Callback):
    def __init__(self, loss_weights):
        super().__init__()
        self.a = loss_weights[0]
        self.b = loss_weights[1]

    def on_epoch_begin(self, epoch, logs=None):
        # Alternate the weights: epoch 0 -> (0, 1), epoch 1 -> (1, 0)
        self.a.assign(float(epoch % 2))
        self.b.assign(float((epoch + 1) % 2))


loss_weights = [
    keras.Variable(initializer="ones", shape=()),
    keras.Variable(initializer="ones", shape=()),
]
loss_weights_adjuster = LossWeightsAdjuster(loss_weights)
model.compile(
    optimizer="sgd",
    loss=["mean_squared_error", "binary_crossentropy"],
    loss_weights=loss_weights,
)
hist = model.fit(
    x, (y1, y2), batch_size=2, epochs=2, callbacks=[loss_weights_adjuster]
)
output_a_loss = hist.history["output_a_loss"]
output_b_loss = hist.history["output_b_loss"]
np.testing.assert_allclose(output_a_loss[0], 0.0)
np.testing.assert_allclose(output_b_loss[1], 0.0)

The root cause might be that state mutated in callbacks doesn't propagate to the *_step functions in JAXTrainer.
Should we support this feature? I may need to add some code to JAXTrainer.
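
As a minimal sketch of the underlying issue (plain JAX, not Keras internals; the names weight and weighted_mse are purely illustrative), a value read from Python-side state inside a jitted function is baked in at trace time, so later mutation from a callback is not visible:

import jax
import jax.numpy as jnp

# Python-side state, mutated outside of JAX (e.g. from a callback)
weight = {"value": jnp.asarray(1.0)}

@jax.jit
def weighted_mse(y_true, y_pred):
    # weight["value"] is read once when the function is traced and baked
    # into the compiled computation as a constant; it is not a traced argument.
    return weight["value"] * jnp.mean((y_true - y_pred) ** 2)

y_true = jnp.ones((4, 1))
y_pred = jnp.zeros((4, 1))
print(weighted_mse(y_true, y_pred))  # 1.0

weight["value"] = jnp.asarray(0.0)   # mutate after the first (tracing) call
print(weighted_mse(y_true, y_pred))  # still 1.0: the cached step doesn't see it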

@fchollet
Member

Should we support this feature? I may need to add some codes in JAXTrainer.

I think this is too difficult and messy to support in JAX.

What I would suggest for this workflow is to update the weights in Python (as floats) and recompile the model when you want the new weight values to be taken into account.
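
For example, a minimal sketch of that workflow, reusing model, x, y1, and y2 from the snippets above (the alternating per-epoch schedule is just illustrative):

# Plain Python floats as loss weights; recompile to pick up new values.
for epoch in range(2):
    w_a = float(epoch % 2)
    w_b = float((epoch + 1) % 2)
    model.compile(
        optimizer="sgd",
        loss=["mean_squared_error", "binary_crossentropy"],
        loss_weights=[w_a, w_b],
    )
    model.fit(x, (y1, y2), batch_size=2, epochs=1)

Because the weights are plain floats captured at compile time, this works the same way on every backend, including JAX. Note that passing the optimizer as a string re-creates it on each compile; passing an optimizer instance instead should keep its state across recompiles.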

@james77777778
Contributor Author

I think this is too difficult and messy to support in JAX.

What I would suggest for this workflow is to update the weights in Python (as floats) and recompile the model when you want the new weight values to be taken into account.

Got it. This PR should be closed now.

Status: Closed/Rejected
Linked issue: loss_weights depending on epoch number
4 participants