Introduce dtype inference and improve dtype in ops.numpy.* #938

Closed
21 changes: 15 additions & 6 deletions keras_core/backend/jax/numpy.py
@@ -70,11 +70,13 @@
     return jnp.max(x, axis=axis, keepdims=keepdims, initial=initial)


-def ones(shape, dtype="float32"):
+def ones(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.ones(shape, dtype=dtype)


-def zeros(shape, dtype="float32"):
+def zeros(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.zeros(shape, dtype=dtype)

@@ -253,7 +255,8 @@
     return jnp.dot(x, y)


-def empty(shape, dtype="float32"):
+def empty(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.empty(shape, dtype=dtype)


@@ -284,6 +287,7 @@


 def full(shape, fill_value, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.full(shape, fill_value, dtype=dtype)


@@ -307,7 +311,8 @@
     return jnp.hstack(xs)


-def identity(n, dtype="float32"):
+def identity(n, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.identity(n, dtype=dtype)


@@ -348,6 +353,7 @@
 def linspace(
     start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
 ):
+    dtype = dtype or config.floatx()

Review comment (Member):

Things like this will deviate from the NumPy convention in the sense that NumPy tries to infer the dtype from argument dtypes. IMO defaulting to float32 is much better: simpler, more consistent. So I think we can go with it.

However, if we're going to make this deviation, we should make it consistently, across all ops that infer the output dtype from argument dtypes, such as arange.

The alternative is to stick to the NumPy dtype inference convention (but with float32 instead of float64).
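
For concreteness, a sketch of the contrast between the two conventions (assuming floatx() == "float32" for the fixed-default side):

import numpy as np

# NumPy's convention: the output dtype is inferred from the arguments.
print(np.arange(5).dtype)       # int64 on most platforms (from int args)
print(np.arange(5.0).dtype)     # float64 (from float args)
print(np.linspace(0, 1).dtype)  # float64 (NumPy's default float)

# Fixed-default convention: the dtype ignores the arguments entirely;
# ones(3) and linspace(0, 5) would both return float32 (the configured
# floatx), regardless of argument dtypes.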

Reply (Contributor, author):

I think we should stick to the JAX dtype inference convention rather than NumPy's, since it should be better suited for deep learning. What do you think?

We could consider reimplementing jnp.result_type for all backends:
https://github.com/google/jax/blob/2cba122bbe512f7927d165fdbb29108dcf0fe124/jax/_src/dtypes.py#L638

It may take some time if we decide to do so.
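
A rough sketch of such a helper (illustrative only: it leans on NumPy's promotion lattice and then demotes 64-bit results, roughly mirroring JAX's default x64-disabled behavior; the name result_dtype and the demotion table are assumptions, not the final API):

import numpy as np

def result_dtype(*dtypes_or_values, floatx="float32"):
    # Start from NumPy's type-promotion result...
    dtype = np.result_type(*dtypes_or_values)
    # ...then canonicalize 64-bit types down to their 32-bit
    # counterparts, as JAX does when x64 is disabled.
    if dtype == np.float64:
        return np.dtype(floatx)
    if dtype == np.int64:
        return np.dtype("int32")
    if dtype == np.complex128:
        return np.dtype("complex64")
    return dtype

# Examples under these assumptions:
# result_dtype(np.float32, np.float64) -> float32
# result_dtype(np.int32, np.int64)     -> int32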

     return jnp.linspace(
         start,
         stop,
@@ -398,6 +404,7 @@


 def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
+    dtype = dtype or config.floatx()
     return jnp.logspace(
         start,
         stop,
@@ -573,7 +580,8 @@
     return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)


-def tri(N, M=None, k=0, dtype="float32"):
+def tri(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.tri(N, M=M, k=k, dtype=dtype)


@@ -652,7 +660,8 @@
     return jnp.sum(x, axis=axis, keepdims=keepdims)


-def eye(N, M=None, k=0, dtype="float32"):
+def eye(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.eye(N, M=M, k=k, dtype=dtype)


21 changes: 14 additions & 7 deletions keras_core/backend/numpy/numpy.py
@@ -34,11 +34,13 @@
     return np.max(x, axis=axis, keepdims=keepdims, initial=initial)


-def ones(shape, dtype="float32"):
+def ones(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return np.ones(shape, dtype=dtype)


-def zeros(shape, dtype="float32"):
+def zeros(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return np.zeros(shape, dtype=dtype)


@@ -134,7 +136,6 @@


 def array(x, dtype=None):
-    dtype = dtype or config.floatx()
     return np.array(x, dtype=dtype)


@@ -251,7 +252,8 @@
     return np.dot(x, y)


-def empty(shape, dtype="float32"):
+def empty(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return np.empty(shape, dtype=dtype)


@@ -302,7 +304,8 @@
     return np.hstack(xs)


-def identity(n, dtype="float32"):
+def identity(n, dtype=None):
+    dtype = dtype or config.floatx()
     return np.identity(n, dtype=dtype)


@@ -338,6 +341,7 @@
     start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
 ):
     axis = tuple(axis) if isinstance(axis, list) else axis
+    dtype = dtype or config.floatx()
     return np.linspace(
         start,
         stop,
@@ -382,6 +386,7 @@


 def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
+    dtype = dtype or config.floatx()
     return np.logspace(
         start,
         stop,
@@ -556,7 +561,8 @@
     return np.trace(x, offset=offset, axis1=axis1, axis2=axis2)


-def tri(N, M=None, k=0, dtype="float32"):
+def tri(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return np.tri(N, M=M, k=k, dtype=dtype)


@@ -631,7 +637,8 @@
     return np.sum(x, axis=axis, keepdims=keepdims)


-def eye(N, M=None, k=0, dtype="float32"):
+def eye(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return np.eye(N, M=M, k=k, dtype=dtype)


21 changes: 15 additions & 6 deletions keras_core/backend/tensorflow/numpy.py
@@ -196,11 +196,13 @@
     return tfnp.max(x, axis=axis, keepdims=keepdims)


-def ones(shape, dtype="float32"):
+def ones(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return tf.ones(shape, dtype=dtype)


-def zeros(shape, dtype="float32"):
+def zeros(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return tf.zeros(shape, dtype=dtype)


@@ -403,7 +405,8 @@
     return tfnp.dot(x, y)


-def empty(shape, dtype="float32"):
+def empty(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.empty(shape, dtype=dtype)


@@ -434,6 +437,7 @@


 def full(shape, fill_value, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.full(shape, fill_value, dtype=dtype)


@@ -453,7 +457,8 @@
     return tfnp.hstack(xs)


-def identity(n, dtype="float32"):
+def identity(n, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.identity(n, dtype=dtype)


@@ -488,6 +493,7 @@
 def linspace(
     start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
 ):
+    dtype = dtype or config.floatx()
     return tfnp.linspace(
         start,
         stop,
@@ -532,6 +538,7 @@


 def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
+    dtype = dtype or config.floatx()
     return tfnp.logspace(
         start,
         stop,
@@ -776,7 +783,8 @@
     return tfnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)


-def tri(N, M=None, k=0, dtype="float32"):
+def tri(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.tri(N, M=M, k=k, dtype=dtype)


@@ -863,7 +871,8 @@
     return tfnp.sum(x, axis=axis, keepdims=keepdims)


-def eye(N, M=None, k=0, dtype="float32"):
+def eye(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.eye(N, M=M, k=k, dtype=dtype)


36 changes: 19 additions & 17 deletions keras_core/backend/torch/numpy.py
@@ -76,15 +76,15 @@
     return result


-def ones(shape, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def ones(shape, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     if isinstance(shape, int):
         shape = (shape,)
     return torch.ones(size=shape, dtype=dtype, device=get_device())


-def zeros(shape, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def zeros(shape, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     if isinstance(shape, int):
         shape = (shape,)
     return torch.zeros(size=shape, dtype=dtype, device=get_device())
@@ -230,7 +230,8 @@


 def array(x, dtype=None):
-    dtype = to_torch_dtype(dtype)
+    if dtype is not None:
+        dtype = to_torch_dtype(dtype)
     if isinstance(x, torch.Tensor):
         return x
     return torch.tensor(x, dtype=dtype, device=get_device())
@@ -386,8 +387,8 @@
     return torch.matmul(x, y)


-def empty(shape, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def empty(shape, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     return torch.empty(size=shape, dtype=dtype, device=get_device())


@@ -426,7 +427,7 @@


 def full(shape, fill_value, dtype=None):
-    dtype = to_torch_dtype(dtype)
+    dtype = to_torch_dtype(dtype or config.floatx())
     fill_value = convert_to_tensor(fill_value, dtype=dtype)
     if len(fill_value.shape) > 0:
         # `torch.full` only supports scalar `fill_value`.
@@ -457,8 +458,8 @@
     return torch.hstack(xs)


-def identity(n, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def identity(n, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     return torch.eye(n, dtype=dtype)


@@ -512,7 +513,7 @@
             "torch.linspace does not support an `axis` argument. "
             f"Received axis={axis}"
         )
-    dtype = to_torch_dtype(dtype)
+    dtype = to_torch_dtype(dtype or config.floatx())
     if endpoint is False:
         stop = stop - ((stop - start) / num)
     if hasattr(start, "__len__") and hasattr(stop, "__len__"):
@@ -586,7 +587,7 @@
             "torch.logspace does not support an `axis` argument. "
             f"Received axis={axis}"
         )
-    dtype = to_torch_dtype(dtype)
+    dtype = to_torch_dtype(dtype or config.floatx())
     if endpoint is False:
         stop = stop - ((stop - start) / num)
     if hasattr(start, "__len__") and hasattr(stop, "__len__"):
@@ -738,7 +739,8 @@

 def prod(x, axis=None, keepdims=False, dtype=None):
     x = convert_to_tensor(x)
-    dtype = to_torch_dtype(dtype)
+    if dtype is not None:
+        dtype = to_torch_dtype(dtype)
     if axis is None:
         return torch.prod(x, dtype=dtype)
     if not isinstance(axis, (list, tuple)):
@@ -933,8 +935,8 @@
     return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1)


-def tri(N, M=None, k=0, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def tri(N, M=None, k=0, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     M = M or N
     x = torch.ones((N, M), dtype=dtype, device=get_device())
     return torch.tril(x, diagonal=k)
@@ -1037,8 +1039,8 @@
     return torch.sum(x)


-def eye(N, M=None, k=None, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def eye(N, M=None, k=None, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     M = N if M is None else M
     k = 0 if k is None else k
     if k == 0:
27 changes: 27 additions & 0 deletions keras_core/layers/rnn/dropout_rnn_cell_test.py
@@ -64,4 +64,31 @@ def test_basics(self):
             expected_num_non_trainable_weights=0,
             expected_num_non_trainable_variables=1,
             supports_masking=True,
+            run_mixed_precision_check=False,
         )
+
+        # Custom mixed_float16 check.
+        # Never test mixed precision on torch CPU. Torch lacks support.
+        run_mixed_precision_check = True
+        if backend.backend() == "torch":
+            import torch
+
+            run_mixed_precision_check = torch.cuda.is_available()
+        if run_mixed_precision_check:
+            self.run_layer_test(
+                layers.RNN,
+                init_kwargs={
+                    "cell": RNNCellWithDropout(
+                        5, seed=1337, dtype="mixed_float16"
+                    ),
+                    "dtype": "mixed_float16",
+                },
+                input_shape=(3, 2, 4),
+                call_kwargs={"training": True},
+                expected_output_shape=(3, 5),
+                expected_num_trainable_weights=2,
+                expected_num_non_trainable_weights=0,
+                expected_num_non_trainable_variables=1,
+                supports_masking=True,
+                run_mixed_precision_check=False,
+            )