
Multi-output model training broken due to unexpected tree.flatten behavior #20346

Open
gustavoeb opened this issue Oct 13, 2024 · 0 comments
gustavoeb commented Oct 13, 2024

TL;DR: as of Keras 3.5, training a functional model with multiple outputs is broken when `loss` and `y_true` are passed as dicts. It appears that `tree.flatten` re-orders the entries of the `y_true` dict.

Keras version: 3.5+
Backend: all

Repro code:

```python
from numpy.random import randn
from keras import layers, Model

if __name__ == "__main__":
    input_a = layers.Input((10,), name="a")
    input_b = layers.Input((5,), name="b")
    output_c = layers.Reshape([5, 4], name="c")(layers.Dense(20)(input_a))
    output_d = layers.Dense(1, name="d")(input_b)

    # Outputs are declared in the order [d, c]
    model = Model([input_a, input_b], [output_d, output_c])
    model.compile(loss={"d": "mse", "c": "mae"}, optimizer="adam")

    model.fit(
        {"a": randn(32, 10), "b": randn(32, 5)},
        {"d": randn(32, 1), "c": randn(32, 5, 4)},
    )
```

To the best of my understanding this is valid code, and it worked until 3.4. Now this line appears to be re-ordering `y_true`:

```python
y_true = tree.flatten(y_true)
```

For this example, `y_true` becomes `[c, d]` while `y_pred` is `[d, c]`. In my few attempts the flattened order appeared to be alphabetical.
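The re-ordering can be illustrated without Keras: flatten utilities in the dm-tree/optree family typically traverse dict entries in sorted key order rather than insertion order, so a `y_true` dict keyed `d` then `c` flattens to `[c, d]` while the model's outputs keep their declaration order `[d, c]`. A minimal pure-Python sketch of that behavior (the `flatten` helper below is a hypothetical stand-in, not the actual Keras implementation):

```python
def flatten(structure):
    """Mimic tree.flatten: dicts are traversed in sorted key order."""
    if isinstance(structure, dict):
        leaves = []
        for key in sorted(structure):  # sorted order, not insertion order
            leaves.extend(flatten(structure[key]))
        return leaves
    if isinstance(structure, (list, tuple)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]

# Targets as passed to fit(), keyed "d" first:
y_true = {"d": "d_targets", "c": "c_targets"}
# Model outputs are flattened in declaration order: [d, c]
y_pred = ["d_preds", "c_preds"]

flat_true = flatten(y_true)
print(flat_true)  # ['c_targets', 'd_targets']

# Positional pairing now matches d's predictions with c's targets:
print(list(zip(y_pred, flat_true)))
# [('d_preds', 'c_targets'), ('c_preds', 'd_targets')]
```

Under this assumption, each per-output loss ends up computed against the wrong target tensor (or fails outright on a shape mismatch, as with `d`'s `(32, 1)` predictions versus `c`'s `(32, 5, 4)` targets in the repro above).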
