Description:
When calling keras.backend.argmax on an input array that contains -0.0, the result is incorrect: the function returns 1 (the index of -0.0) as the position of the maximum, even though the actual maximum, 1.401298464324817e-45, is at index 2.
The same result is reproducible with TensorFlow (tf.math.argmax) and JAX (jnp.argmax), possibly because they share similar backend logic for argmax. PyTorch (torch.argmax), by contrast, returns the expected index 2.
Expected Behavior:
keras.backend.argmax should return 2, because the value at index 2 (1.401298464324817e-45, the smallest positive subnormal float32, 2^-149) is strictly greater than both -1.0 and -0.0.
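As a reference point for the expected behavior, the sketch below checks the same input with NumPy, assuming NumPy runs with default IEEE 754 semantics (no flush-to-zero). It also prints the float32 bit patterns to confirm that index 2 holds the smallest subnormal. The last part is only a hypothesis, not something established in this report: if a backend flushed subnormals to zero, the array would behave like `[-1.0, -0.0, 0.0]`, where `-0.0 == 0.0` ties indices 1 and 2 and argmax picks the first occurrence, which matches the reported output of 1. The `flushed` array is a hypothetical illustration, not what any backend is confirmed to do.

```python
import struct

import numpy as np

# Same input as the reproduction script.
data = np.array([-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32)

# The third element is the smallest positive subnormal float32 (2**-149).
assert data[2] == np.nextafter(np.float32(0), np.float32(1))

# Raw float32 bit patterns: -1.0, -0.0 (sign bit only), and the subnormal (lowest bit only).
bits = [struct.unpack("<I", struct.pack("<f", float(x)))[0] for x in data]
print([hex(b) for b in bits])  # ['0xbf800000', '0x80000000', '0x1']

# Under strict IEEE 754 comparison, the subnormal is the unique maximum.
print(np.argmax(data))  # 2

# Hypothetical flush-to-zero: replace nonzero subnormals with +0.0, keep -0.0 as is.
smallest_normal = np.finfo(np.float32).smallest_normal
is_subnormal = (data != 0) & (np.abs(data) < smallest_normal)
flushed = np.where(is_subnormal, np.float32(0.0), data)

# -0.0 == 0.0, so indices 1 and 2 tie and argmax returns the first one.
print(np.argmax(flushed))  # 1
```

If the real backends behave like the `flushed` case, the discrepancy would come from subnormal handling rather than from signed-zero comparison itself.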
Reproduction:
```python
import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp
from tensorflow import keras


def test_argmax():
    # Input data: index 2 holds the smallest positive float32 subnormal
    input_data = np.array([-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32)

    # PyTorch argmax
    pytorch_result = torch.argmax(torch.tensor(input_data, dtype=torch.float32)).item()
    print(f"PyTorch argmax result: {pytorch_result}")

    # TensorFlow argmax
    tensorflow_result = tf.math.argmax(input_data).numpy()
    print(f"TensorFlow argmax result: {tensorflow_result}")

    # Keras argmax (tf.keras runs on TensorFlow ops, so this should match tf.math.argmax)
    keras_result = keras.backend.argmax(input_data).numpy()
    print(f"Keras argmax result: {keras_result}")

    # JAX argmax
    jax_result = jnp.argmax(input_data)
    print(f"JAX argmax result: {jax_result}")


if __name__ == "__main__":
    test_argmax()
```