argmax returns incorrect result for input containing -0.0 (Keras using TensorFlow backend) #20350

Open
LilyDong0127 opened this issue Oct 14, 2024 · 0 comments
Description:
keras.backend.argmax returns the wrong index for the float32 input [-1.0, -0.0, 1.401298464324817e-45]. It returns 1, the index of -0.0, even though the largest value is 1.401298464324817e-45 (2**-149, the smallest positive subnormal float32) at index 2.

TensorFlow (tf.math.argmax) and JAX (jnp.argmax) return the same wrong index, presumably because they share similar backend logic for argmax. PyTorch (torch.argmax) returns the expected index 2.
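One plausible explanation, which is an assumption and not confirmed in this issue, is that the TensorFlow and JAX CPU kernels run with denormal flushing (flush-to-zero / denormals-are-zero) enabled. The pure-Python sketch below emulates that effect. `flush_subnormals` and `argmax_first` are hypothetical helpers, and first-occurrence tie-breaking is NumPy's convention (which jnp.argmax follows), not something tf.math.argmax guarantees. Once 2**-149 is treated as 0.0, it ties with -0.0 (IEEE 754 treats the two zeros as equal), and the first of the tied maxima is index 1.

```python
import math

# Smallest positive *normal* float32; any value with 0 < |x| < this is subnormal.
FLT32_MIN_NORMAL = 2.0 ** -126


def argmax_first(xs):
    """Return the index of the maximum, breaking ties by first occurrence."""
    best = 0
    for i, x in enumerate(xs):
        if x > xs[best]:  # strict '>' keeps the earliest of equal values
            best = i
    return best


def flush_subnormals(xs):
    """Hypothetical FTZ/DAZ emulation: replace float32 subnormals with a same-signed zero."""
    return [math.copysign(0.0, x) if 0.0 < abs(x) < FLT32_MIN_NORMAL else x for x in xs]


# 2**-149 == 1.401298464324817e-45, the smallest positive float32 subnormal
data = [-1.0, -0.0, 2.0 ** -149]

print(argmax_first(data))                    # 2: exact comparison, matching PyTorch
print(argmax_first(flush_subnormals(data)))  # 1: -0.0 and 0.0 tie, first occurrence wins
```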

Expected Behavior:
keras.backend.argmax should return 2, as the value at index 2 (1.401298464324817e-45) is greater than both -1.0 and -0.0.
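A stdlib-only sketch that rounds the input to float32 with `struct` and inspects the bit patterns; the helper names `to_float32` and `float32_bits` are illustrative and not part of any library. It shows that -0.0 has only the sign bit set, that 1.401298464324817e-45 has the bit pattern 0x00000001 (2**-149), and that the three values are strictly ordered, so index 2 is the only correct answer.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def float32_bits(x: float) -> int:
    """Return the raw binary32 bit pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


values = [to_float32(v) for v in (-1.0, -0.0, 1.401298464324817e-45)]

print(hex(float32_bits(values[1])))       # 0x80000000: only the sign bit set (negative zero)
print(hex(float32_bits(values[2])))       # 0x1: exponent field 0, mantissa 1 (smallest subnormal)
print(values[2] == 2.0 ** -149)           # True
print(values[0] < values[1] < values[2])  # True: index 2 is the unique maximum
```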

import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp
from tensorflow import keras

def test_argmax():
    # Input data: 1.401298464324817e-45 is 2**-149, the smallest positive subnormal float32
    input_data = np.array([-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32)

    # PyTorch argmax
    pytorch_result = torch.argmax(torch.tensor(input_data, dtype=torch.float32)).item()
    print(f"PyTorch argmax result: {pytorch_result}")

    # TensorFlow argmax
    tensorflow_result = tf.math.argmax(input_data).numpy()
    print(f"TensorFlow argmax result: {tensorflow_result}")

    # Keras argmax (runs on the TensorFlow backend, so it should match TensorFlow)
    keras_result = keras.backend.argmax(input_data).numpy()
    print(f"Keras argmax result: {keras_result}")

    # JAX argmax
    jax_result = jnp.argmax(input_data)
    print(f"JAX argmax result: {jax_result}")

if __name__ == "__main__":
    test_argmax()

Output:

PyTorch argmax result: 2
TensorFlow argmax result: 1
Keras argmax result: 1
JAX argmax result: 1
