Apply backend.result_type to bincount, subtract, matmul, multiply, mean and max #18534

Merged
merged 12 commits into keras-team:master from apply-result-type on Oct 3, 2023

Conversation

@james77777778 (Contributor)

Follow-up to #18482.

This PR has applied result_type to the following ops:

  • bincount
  • subtract
  • matmul
  • multiply
  • mean
  • max

The corresponding unit tests have also been added.
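For context, `result_type` computes the promoted dtype an op's output should carry given its input dtypes. A minimal illustration using NumPy's `np.result_type` as a stand-in (Keras's `dtypes.result_type` follows JAX-style promotion, which can differ from NumPy's for mixed int/float inputs):

```python
import numpy as np

# np.result_type returns the dtype that results from combining the inputs
# under the library's promotion rules.
print(np.result_type("int8", "int32"))       # int32
print(np.result_type("float16", "float32"))  # float32
print(np.result_type("float16", "float16"))  # float16
```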

@james77777778 (Contributor, Author)

All tests passed except the failure in Codecov.

@fchollet (Member) left a comment:

Thank you for the PR!

Review threads on keras/backend/jax/numpy.py and keras/backend/tensorflow/numpy.py (resolved).
@codecov-commenter commented Oct 2, 2023:

Codecov Report

Attention: 1 line in your changes is missing coverage. Please review.

Comparison is base (43be5fc) 77.58% compared to head (683a976) 77.67%.
Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18534      +/-   ##
==========================================
+ Coverage   77.58%   77.67%   +0.09%     
==========================================
  Files         334      334              
  Lines       32211    32302      +91     
  Branches     6286     6297      +11     
==========================================
+ Hits        24990    25092     +102     
+ Misses       5636     5631       -5     
+ Partials     1585     1579       -6     
| Flag | Coverage | Δ |
| --- | --- | --- |
| keras | 77.58% <99.22%> | +0.09% ⬆️ |
| keras-jax | 63.23% <28.68%> | -0.07% ⬇️ |
| keras-numpy | 57.28% <42.63%> | +<0.01% ⬆️ |
| keras-tensorflow | 63.23% <49.61%> | +0.01% ⬆️ |
| keras-torch | 64.00% <41.86%> | -0.04% ⬇️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

| Files | Coverage | Δ |
| --- | --- | --- |
| keras/backend/jax/numpy.py | 98.89% <100.00%> | +0.01% ⬆️ |
| keras/backend/numpy/numpy.py | 98.61% <100.00%> | +0.08% ⬆️ |
| keras/backend/torch/numpy.py | 94.86% <100.00%> | -0.02% ⬇️ |
| keras/ops/numpy.py | 95.34% <100.00%> | +0.48% ⬆️ |
| keras/backend/tensorflow/numpy.py | 95.38% <97.77%> | +0.35% ⬆️ |

... and 1 file with indirect coverage changes


@fchollet (Member) left a comment:

👍

    dtype = getattr(x, "dtype", None)
    if hasattr(dtype, "name") and "float" in dtype.name:
        return cast(outputs, dtype)
    compute_dtype = dtypes.result_type(x.dtype, "float32")
@fchollet (Member):
Shouldn't this be result_type(x.dtype, config.floatx()) rather than hardcoding float32?

@james77777778 (Contributor, Author) commented Oct 3, 2023:

There is a note:

    # `jnp.mean` does not handle low precision (e.g., float16) overflow
    # correctly, so we compute with float32 and cast back to the original type.

It comes from this PR: keras-team/keras-core#410.

I have added a test to verify the overflow behavior:

    # test overflow
    x = np.array([65504, 65504, 65504], dtype="float16")
    self.assertAllClose(knp.mean(x), np.mean(x))

| np.mean(x) | jnp.mean(x) | jnp.mean(x, dtype="float32") | tfnp.mean(x) | tfnp.mean(x, dtype="float32") | torch.mean(x) | torch.mean(x, dtype=torch.float32) |
| --- | --- | --- | --- | --- | --- | --- |
| 65504 | inf | 65504 | inf | 65504 | inf | 65504 |

As a result, we should use float32 to compute mean for JAX, TensorFlow and torch, even if backend.floatx() == "float16".
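The overflow behavior above can be reproduced with plain NumPy (a sketch; no backend imports needed):

```python
import numpy as np

x = np.array([65504, 65504, 65504], dtype="float16")  # 65504 is float16's max finite value

# Reducing in float16 overflows: 65504 + 65504 already exceeds the float16 range.
print(np.add.reduce(x, dtype="float16"))  # inf

# Upcasting to float32, computing, then casting back stays finite.
mean16 = x.astype("float32").mean().astype("float16")
print(mean16)  # 65504.0
```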

    return np.mean(x, axis=axis, keepdims=keepdims)
    x = convert_to_tensor(x)
    ori_dtype = standardize_dtype(x.dtype)
    compute_dtype = dtypes.result_type(x.dtype, "float32")
@fchollet (Member):
Likewise here

    return tfnp.mean(x, axis=axis, keepdims=keepdims)
    x = convert_to_tensor(x)
    ori_dtype = standardize_dtype(x.dtype)
    compute_dtype = dtypes.result_type(x.dtype, "float32")
@fchollet (Member):
Likewise here

    x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
    return torch.mean(x, axis=axis, keepdims=keepdims)
    ori_dtype = standardize_dtype(x.dtype)
    compute_dtype = dtypes.result_type(x.dtype, "float32")
@fchollet (Member):
And here
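The pattern discussed in these threads — upcast to at least float32, reduce, cast back — can be sketched with a hypothetical NumPy helper (`stable_mean` is illustrative, not Keras API; note that NumPy's promotion for integer inputs differs from Keras's JAX-style rules):

```python
import numpy as np

def stable_mean(x, axis=None, keepdims=False):
    """Illustrative upcast-compute-downcast mean, mirroring the PR's pattern."""
    ori_dtype = x.dtype
    # Promote to at least float32; float64 inputs stay float64.
    compute_dtype = np.result_type(ori_dtype, "float32")
    out = np.mean(x.astype(compute_dtype), axis=axis, keepdims=keepdims)
    if np.issubdtype(ori_dtype, np.floating):
        # Cast back so callers see their original floating dtype.
        return out.astype(ori_dtype)
    return out  # integer inputs keep the promoted float result

x = np.array([65504, 65504, 65504], dtype="float16")
print(stable_mean(x))  # 65504.0, not inf
```

The cast back to `ori_dtype` is what keeps the public dtype contract intact while avoiding the low-precision accumulation overflow.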

@fchollet (Member) left a comment:
Thank you for the great contribution! LGTM.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Oct 3, 2023
@fchollet fchollet merged commit 3d3a378 into keras-team:master Oct 3, 2023
6 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels Oct 3, 2023
@james77777778 james77777778 deleted the apply-result-type branch October 3, 2023 13:29