
Keras argmin function returns incorrect index when handling subnormal float values. #20355

Open · LilyDong0127 opened this issue on Oct 15, 2024 · 1 comment
Labels: keras-team-review-pending (Pending review by a Keras team member.)

@LilyDong0127

When the Keras backend argmin function is applied to a float32 array that contains a subnormal value, Keras consistently returns the index of 0.0 as the minimum, even though the array also contains a smaller subnormal value (-1.401298464324817e-45). PyTorch and Chainer correctly return the index of the subnormal value, whereas Keras and TensorFlow (and, in the run below, JAX) return the index of 0.0.
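For reference, the two tiny values in the input sit on either side of the float32 normal/subnormal boundary: 1.1754943508222875e-38 is the smallest positive normal float32 (2**-126), and -1.401298464324817e-45 is the negative of the smallest positive subnormal float32 (2**-149). A quick check, assuming only that NumPy is installed, confirms both and shows that the subnormal is strictly below 0.0:

```python
import numpy as np

info = np.finfo(np.float32)

# Value at index 1 of the input: the smallest positive *normal* float32.
smallest_normal = np.float32(1.1754943508222875e-38)
# Value at index 2: the negative of the smallest positive *subnormal* float32.
neg_smallest_subnormal = np.float32(-1.401298464324817e-45)

print(smallest_normal == info.tiny)                       # True (2**-126)
print(-neg_smallest_subnormal == np.float32(2.0**-149))   # True (2**-149)
print(neg_smallest_subnormal < 0.0)                       # True: strictly below 0.0
```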

Expected Behavior:

The expected behavior is for Keras's argmin function to return the index of the smallest value, which is the subnormal value -1.401298464324817e-45 at index 2. Instead, Keras returns 0, the index of the first 0.0.
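A framework-independent reference for the expected result can be computed with plain NumPy, which is also what the Chainer test below feeds in. This sketch assumes a normal Python process in which no flush-to-zero mode has been enabled:

```python
import numpy as np

input_data = [0.0, 1.1754943508222875e-38, -1.401298464324817e-45, 0.0, 459367.0]
x = np.asarray(input_data, dtype=np.float32)

# np.argmin compares the stored float32 values directly; with default FPU
# settings the subnormal at index 2 is not treated as zero.
expected = int(np.argmin(x))
print(expected)  # 2
```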

Reproduction Code:

```python
import torch
import tensorflow as tf
import numpy as np
from chainer import functions as F
import jax.numpy as jnp
import tensorflow.keras.backend as K

# Input data
input_data = [
    0.0,
    1.1754943508222875e-38,
    -1.401298464324817e-45,
    0.0,
    459367.0
]

# Test PyTorch
def test_pytorch_argmin(input_data):
    tensor = torch.tensor(input_data, dtype=torch.float32)
    result = torch.argmin(tensor).item()
    print(f"PyTorch argmin result: {result}")
    return result

# Test TensorFlow
def test_tensorflow_argmin(input_data):
    tensor = tf.constant(input_data, dtype=tf.float32)
    result = tf.argmin(tensor).numpy()
    print(f"TensorFlow argmin result: {result}")
    return result

# Test Keras using backend
def test_keras_argmin(input_data):
    tensor = K.constant(input_data, dtype=tf.float32)
    result = K.argmin(tensor, axis=-1).numpy()
    print(f"Keras argmin result: {result}")
    return result

# Test Chainer
def test_chainer_argmin(input_data):
    tensor = np.array(input_data, dtype=np.float32)
    result = F.argmin(tensor).data
    print(f"Chainer argmin result: {result}")
    return result

# Test JAX
def test_jax_argmin(input_data):
    tensor = jnp.array(input_data, dtype=jnp.float32)
    result = jnp.argmin(tensor).item()
    print(f"JAX argmin result: {result}")
    return result

if __name__ == "__main__":
    pytorch_result = test_pytorch_argmin(input_data)
    tensorflow_result = test_tensorflow_argmin(input_data)
    keras_result = test_keras_argmin(input_data)
    chainer_result = test_chainer_argmin(input_data)
    jax_result = test_jax_argmin(input_data)

    print("\nSummary of results:")
    print(f"PyTorch argmin: {pytorch_result}")
    print(f"TensorFlow argmin: {tensorflow_result}")
    print(f"Keras argmin: {keras_result}")
    print(f"Chainer argmin: {chainer_result}")
    print(f"JAX argmin: {jax_result}")
```

Output:

```
Summary of results:
PyTorch argmin: 2
TensorFlow argmin: 0
Keras argmin: 0
Chainer argmin: 2
JAX argmin: 0
```
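One plausible explanation for the 0 results, not confirmed anywhere in this thread, is that the TensorFlow and JAX CPU kernels run with a denormals-are-zero / flush-to-zero mode enabled. In that mode -1.401298464324817e-45 compares equal to 0.0, so indices 0, 2 and 3 tie for the minimum and the first of them is reported. The sketch below only models that hypothesis in NumPy; `flush_subnormals_to_zero` is a hypothetical helper, not TensorFlow or JAX code:

```python
import numpy as np

def flush_subnormals_to_zero(x):
    """Hypothetical helper: replace every subnormal float32 with 0.0, the way a
    denormals-are-zero (DAZ) hardware mode treats subnormal inputs."""
    x = np.asarray(x, dtype=np.float32)
    tiny = np.finfo(np.float32).tiny  # smallest positive normal float32, 2**-126
    return np.where(np.abs(x) < tiny, np.float32(0.0), x)

input_data = [0.0, 1.1754943508222875e-38, -1.401298464324817e-45, 0.0, 459367.0]

flushed = flush_subnormals_to_zero(input_data)
# After flushing, indices 0, 2 and 3 all hold 0.0; np.argmin returns the first
# occurrence of the minimum, which matches the TensorFlow/Keras/JAX result.
print(int(np.argmin(flushed)))  # 0
```

Note that the smallest normal value at index 1 survives the flush; only the value at index 2 is below the normal range.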
sachinprasadhs added the keras-team-review-pending label on Oct 16, 2024.
@sachinprasadhs (Collaborator)

I checked the result with different backends using Keras 3: the torch backend returns 2, whereas the TensorFlow backend returns 0. The results are not consistent across backends.

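```python
import os
# Must be set before importing keras; change to "tensorflow" or "jax" to compare backends.
os.environ["KERAS_BACKEND"] = "torch"
import numpy as np
import keras

# Input data
input_data = [
    0.0,
    1.1754943508222875e-38,
    -1.401298464324817e-45,
    0.0,
    459367.0
]

def test_keras_argmin(input_data):
    # convert_to_numpy works for every backend, unlike calling .numpy() on the result.
    result = keras.ops.convert_to_numpy(keras.ops.argmin(input_data, axis=-1))
    print(f"Keras argmin result: {result}")
    return result

test_keras_argmin(input_data)
```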
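If subnormal flushing is indeed the cause, one possible workaround (an untested assumption, not a Keras API) is to take the argmin over an order-preserving integer key built from the float32 bit pattern, since integer comparisons are unaffected by flush-to-zero modes. The sketch uses NumPy; the same idea could presumably be expressed with `tf.bitcast` in TensorFlow or `jax.lax.bitcast_convert_type` in JAX, but that has not been verified here. `float32_order_key` and `robust_argmin` are hypothetical names. Unlike IEEE comparison, the key ranks -0.0 below +0.0, and NaNs are not handled:

```python
import numpy as np

def float32_order_key(x):
    """Hypothetical helper: map float32 values to int32 keys whose integer
    order matches the float order (NaNs not handled; -0.0 sorts below +0.0)."""
    bits = np.asarray(x, dtype=np.float32).view(np.int32)
    # Non-negative floats: the raw bit pattern already increases with the value.
    # Negative floats (sign bit set, so bits < 0 as int32): flip the 31
    # magnitude bits so that larger magnitudes map to smaller keys.
    return np.where(bits < 0, bits ^ np.int32(0x7FFFFFFF), bits)

def robust_argmin(x):
    """Hypothetical helper: argmin computed on integer keys, so subnormal
    float32 values are never compared as floats."""
    return int(np.argmin(float32_order_key(x)))

input_data = [0.0, 1.1754943508222875e-38, -1.401298464324817e-45, 0.0, 459367.0]
print(robust_argmin(input_data))  # 2
```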
