Allow keras.Variable for loss weights #20306
Conversation
cb11e26 to b600c6d
Codecov Report

```
@@            Coverage Diff            @@
##           master   #20306     +/-   ##
=========================================
  Coverage   78.83%   78.83%
=========================================
  Files         511      511
  Lines       48989    48990       +1
  Branches     9022     9022
=========================================
+ Hits        38621    38622       +1
  Misses       8505     8505
  Partials     1863     1863
```
Thanks for the PR! Can you add as a unit test something similar to the code snippet you posted? One thing I'd like to guard against is to make sure that the arrays attached to those variables can pass the compilation boundary and can be modified afterwards.
Thanks for pointing that out. It seems the following snippet works fine with tf and torch but fails with jax:

```python
import numpy as np

import keras
from keras import callbacks
from keras import layers
from keras import models

x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))

inputs = layers.Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(inputs)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(inputs)
model = models.Model(inputs, [output_a, output_b])


class LossWeightsAdjuster(callbacks.Callback):
    def __init__(self, loss_weights):
        super().__init__()
        self.a = loss_weights[0]
        self.b = loss_weights[1]

    def on_epoch_begin(self, epoch, logs=None):
        # Alternate the two loss weights between 0 and 1 each epoch.
        self.a.assign(float(epoch % 2))
        self.b.assign(float((epoch + 1) % 2))


loss_weights = [
    keras.Variable(initializer="ones", shape=()),
    keras.Variable(initializer="ones", shape=()),
]
loss_weights_adjuster = LossWeightsAdjuster(loss_weights)

model.compile(
    optimizer="sgd",
    loss=["mean_squared_error", "binary_crossentropy"],
    loss_weights=loss_weights,
)
hist = model.fit(
    x, (y1, y2), batch_size=2, epochs=2, callbacks=[loss_weights_adjuster]
)

output_a_loss = hist.history["output_a_loss"]
output_b_loss = hist.history["output_b_loss"]
# Epoch 0: weight a == 0, so output_a's weighted loss should be 0.
# Epoch 1: weight b == 0, so output_b's weighted loss should be 0.
np.testing.assert_allclose(output_a_loss[0], 0.0)
np.testing.assert_allclose(output_b_loss[1], 0.0)
```

The root cause might be that the state updates made in callbacks don't propagate to the jitted train function on the jax backend.
I think this is too difficult and messy to support in JAX. What I would suggest for this workflow is to update the weights in Python (as floats) and recompile the model when you want the new weight values to be taken into account.
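For illustration, a minimal sketch of this recompile-based workflow, assuming the same toy model as in the snippet above (the alternating weight schedule is only an example):

```python
import numpy as np
from keras import layers, models

x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))

inputs = layers.Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(inputs)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(inputs)
model = models.Model(inputs, [output_a, output_b])

# Update the loss weights as plain Python floats and recompile so the
# new values are baked into the retraced train function. This works on
# all backends, including jax, at the cost of a recompilation per change.
for epoch in range(2):
    model.compile(
        optimizer="sgd",
        loss=["mean_squared_error", "binary_crossentropy"],
        loss_weights=[float(epoch % 2), float((epoch + 1) % 2)],
    )
    model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
```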
Got it. This PR should be closed now.
Fix #20294

With this PR, we can use keras.Variable for loss weights, as in the sketch below. We can see that the individual losses correspond to the loss weights.
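A minimal sketch of the usage this PR enables (the toy model and data mirror the snippet in the conversation above; the zero weight on output_b is illustrative):

```python
import numpy as np

import keras
from keras import layers, models

x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))

inputs = layers.Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(inputs)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(inputs)
model = models.Model(inputs, [output_a, output_b])

# With this PR, loss weights may be keras.Variable instances rather
# than plain floats.
loss_weights = [
    keras.Variable(initializer="ones", shape=()),
    keras.Variable(initializer="zeros", shape=()),
]
model.compile(
    optimizer="sgd",
    loss=["mean_squared_error", "binary_crossentropy"],
    loss_weights=loss_weights,
)
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)

# output_b has loss weight 0, so its weighted loss contributes nothing
# to the total: the individual losses correspond to the loss weights.
print(hist.history)
```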