Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply backend.result_type to absolute, argmax, argmin, argsort, ceil, clip and dot #18548

Merged
merged 6 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion keras/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,12 @@ def broadcast_to(x, shape):


def ceil(x):
    """Element-wise ceiling of `x`, always returned with a float dtype.

    int64 inputs are cast to the configured default float dtype (`floatx`)
    instead of promoting to float64; all other dtypes are promoted via
    `dtypes.result_type(x.dtype, float)`.
    """
    # NOTE: the stale pre-change `return jnp.ceil(x)` line (a diff leftover)
    # made everything below unreachable; it is removed here.
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        # Keep the configured floatx rather than promoting int64 -> float64.
        dtype = config.floatx()
    else:
        dtype = dtypes.result_type(x.dtype, float)
    return cast(jnp.ceil(x), dtype)


def clip(x, x_min, x_max):
Expand Down
23 changes: 18 additions & 5 deletions keras/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,17 @@ def arctanh(x):

def argmax(x, axis=None):
    """Indices of the maximum values along `axis`, cast to int32.

    A list `axis` is converted to a tuple for NumPy compatibility; the
    int32 cast keeps index dtypes consistent across backends.
    """
    # The stale pre-change `return np.argmax(x, axis=axis)` (diff leftover)
    # returned before the int32 cast; removed.
    axis = tuple(axis) if isinstance(axis, list) else axis
    return np.argmax(x, axis=axis).astype("int32")


def argmin(x, axis=None):
    """Indices of the minimum values along `axis`, cast to int32.

    A list `axis` is converted to a tuple for NumPy compatibility; the
    int32 cast keeps index dtypes consistent across backends.
    """
    # The stale pre-change `return np.argmin(x, axis=axis)` (diff leftover)
    # returned before the int32 cast; removed.
    axis = tuple(axis) if isinstance(axis, list) else axis
    return np.argmin(x, axis=axis).astype("int32")


def argsort(x, axis=-1):
    """Indices that would sort `x` along `axis`, cast to int32.

    A list `axis` is converted to a tuple for NumPy compatibility; the
    int32 cast keeps index dtypes consistent across backends.
    """
    # The stale pre-change `return np.argsort(x, axis=axis)` (diff leftover)
    # returned before the int32 cast; removed.
    axis = tuple(axis) if isinstance(axis, list) else axis
    return np.argsort(x, axis=axis).astype("int32")


def array(x, dtype=None):
Expand Down Expand Up @@ -201,11 +201,19 @@ def broadcast_to(x, shape):


def ceil(x):
    """Element-wise ceiling of `x`, always returned with a float dtype.

    int64 inputs are cast to the configured default float dtype (`floatx`)
    instead of promoting to float64; all other dtypes are promoted via
    `dtypes.result_type(x.dtype, float)`.
    """
    # NOTE: the stale pre-change `return np.ceil(x)` line (a diff leftover)
    # made everything below unreachable; it is removed here.
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        # Keep the configured floatx rather than promoting int64 -> float64.
        dtype = config.floatx()
    else:
        dtype = dtypes.result_type(x.dtype, float)
    return np.ceil(x).astype(dtype)


def clip(x, x_min, x_max):
    """Clip `x` to the range [x_min, x_max], preserving the input dtype.

    Boolean inputs are promoted to int64, since clipping a bool with
    numeric bounds is not meaningful in the original dtype.
    """
    # The stale pre-change `return np.clip(x, x_min, x_max)` (diff leftover)
    # returned before the dtype handling; removed.
    dtype = standardize_dtype(x.dtype)
    if dtype == "bool":
        dtype = "int64"
    return np.clip(x, x_min, x_max).astype(dtype)


def concatenate(xs, axis=0):
Expand Down Expand Up @@ -280,6 +288,11 @@ def digitize(x, bins):


def dot(x, y):
    """Dot product of `x` and `y`, computed in their promoted common dtype.

    Both operands are converted to tensors and cast to
    `dtypes.result_type(x.dtype, y.dtype)` before the product.
    """
    x, y = convert_to_tensor(x), convert_to_tensor(y)
    common_dtype = dtypes.result_type(x.dtype, y.dtype)
    return np.dot(x.astype(common_dtype), y.astype(common_dtype))


Expand Down
31 changes: 25 additions & 6 deletions keras/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def zeros(shape, dtype=None):


def absolute(x):
    """Element-wise absolute value of `x`.

    Unsigned-integer and boolean inputs are returned unchanged, since they
    are non-negative by construction.
    """
    dtype = standardize_dtype(x.dtype)
    already_non_negative = dtype == "bool" or "uint" in dtype
    if already_non_negative:
        return x
    return tfnp.absolute(x)


Expand Down Expand Up @@ -341,15 +345,15 @@ def arctanh(x):


def argmax(x, axis=None):
    """Indices of the maximum values along `axis`, cast to int32.

    The cast keeps index dtypes consistent across backends (TF defaults
    to int64 otherwise).
    """
    # The stale pre-change `return tfnp.argmax(...)` (diff leftover)
    # returned before the int32 cast; removed.
    return tf.cast(tfnp.argmax(x, axis=axis), dtype="int32")


def argmin(x, axis=None):
    """Indices of the minimum values along `axis`, cast to int32.

    The cast keeps index dtypes consistent across backends (TF defaults
    to int64 otherwise).
    """
    # The stale pre-change `return tfnp.argmin(...)` (diff leftover)
    # returned before the int32 cast; removed.
    return tf.cast(tfnp.argmin(x, axis=axis), dtype="int32")


def argsort(x, axis=-1):
    """Indices that would sort `x` along `axis`, cast to int32.

    The cast keeps index dtypes consistent across backends.
    """
    # The stale pre-change `return tfnp.argsort(...)` (diff leftover)
    # returned before the int32 cast; removed.
    return tf.cast(tfnp.argsort(x, axis=axis), dtype="int32")


def array(x, dtype=None):
Expand All @@ -370,11 +374,19 @@ def broadcast_to(x, shape):


def ceil(x):
    """Element-wise ceiling of `x`, always returned with a float dtype.

    int64 inputs are cast to the configured default float dtype (`floatx`)
    instead of promoting to float64; all other dtypes are promoted via
    `dtypes.result_type(x.dtype, float)`.
    """
    # NOTE: the stale pre-change `return tfnp.ceil(x)` line (a diff leftover)
    # made everything below unreachable; it is removed here.
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        # Keep the configured floatx rather than promoting int64 -> float64.
        dtype = config.floatx()
    else:
        dtype = dtypes.result_type(x.dtype, float)
    return tf.cast(tfnp.ceil(x), dtype=dtype)


def clip(x, x_min, x_max):
    """Clip `x` to the range [x_min, x_max], preserving the input dtype.

    Boolean inputs are promoted to int64, since clipping a bool with
    numeric bounds is not meaningful in the original dtype.
    """
    # The stale pre-change `return tfnp.clip(...)` (diff leftover)
    # returned before the dtype handling; removed.
    dtype = standardize_dtype(x.dtype)
    if dtype == "bool":
        dtype = "int64"
    return tf.cast(tfnp.clip(x, x_min, x_max), dtype=dtype)


def concatenate(xs, axis=0):
Expand Down Expand Up @@ -463,7 +475,14 @@ def digitize(x, bins):


def dot(x, y):
    """Dot product of `x` and `y`, returned in their promoted result dtype.

    The product is computed in a float dtype (GPU kernels only support
    float types) and then cast back to `result_type(x.dtype, y.dtype)`.
    """
    # The stale pre-change `return tfnp.dot(x, y)` (diff leftover)
    # returned before the dtype handling; removed.
    x = convert_to_tensor(x)
    y = convert_to_tensor(y)
    result_dtype = dtypes.result_type(x.dtype, y.dtype)
    # GPU only supports float types; compute in float, cast back afterwards.
    compute_dtype = dtypes.result_type(result_dtype, float)
    x = tf.cast(x, compute_dtype)
    y = tf.cast(y, compute_dtype)
    return tf.cast(tfnp.dot(x, y), dtype=result_dtype)


def empty(shape, dtype=None):
Expand Down
29 changes: 22 additions & 7 deletions keras/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def absolute(x):

def abs(x):
    """Element-wise absolute value of `x`.

    Boolean tensors are returned unchanged, since they are non-negative
    by construction.
    """
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) != "bool":
        return torch.abs(x)
    # Booleans are already non-negative.
    return x


Expand Down Expand Up @@ -237,20 +240,20 @@ def arctanh(x):

def argmax(x, axis=None):
    """Indices of the maximum values along `axis`, cast to int32.

    The cast keeps index dtypes consistent across backends (torch defaults
    to int64 otherwise).
    """
    # The stale pre-change `return torch.argmax(x, dim=axis)` (diff leftover)
    # returned before the int32 cast; removed.
    x = convert_to_tensor(x)
    return cast(torch.argmax(x, dim=axis), dtype="int32")


def argmin(x, axis=None):
    """Indices of the minimum values along `axis`, cast to int32.

    The cast keeps index dtypes consistent across backends (torch defaults
    to int64 otherwise).
    """
    # The stale pre-change `return torch.argmin(x, dim=axis)` (diff leftover)
    # returned before the int32 cast; removed.
    x = convert_to_tensor(x)
    return cast(torch.argmin(x, dim=axis), dtype="int32")


def argsort(x, axis=-1):
    """Indices that would sort `x` along `axis` (stable sort), cast to int32.

    `axis=None` flattens `x` first, matching NumPy semantics.
    """
    # The stale pre-change `return torch.argsort(...)` (diff leftover)
    # returned before the int32 cast; removed.
    x = convert_to_tensor(x)
    if axis is None:
        # NumPy's axis=None sorts the flattened array.
        axis = -1
        x = x.reshape(-1)
    return cast(torch.argsort(x, dim=axis, stable=True), dtype="int32")


def array(x, dtype=None):
Expand Down Expand Up @@ -311,13 +314,21 @@ def broadcast_to(x, shape):

def ceil(x):
    """Element-wise ceiling of `x`, always returned with a float dtype.

    int64 inputs are cast to the configured default float dtype (`floatx`)
    instead of promoting to float64; all other dtypes are promoted via
    `dtypes.result_type(x.dtype, float)`.
    """
    # NOTE: the stale pre-change `return torch.ceil(x)` line (a diff
    # leftover) made everything below unreachable; it is removed here.
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) == "int64":
        # Keep the configured floatx rather than promoting int64 -> float64.
        dtype = config.floatx()
    else:
        dtype = dtypes.result_type(x.dtype, float)
    return cast(torch.ceil(x), dtype=dtype)


def clip(x, x_min, x_max):
    """Clip `x` to the range [x_min, x_max], preserving the input dtype.

    Boolean inputs are promoted to int64, since clipping a bool with
    numeric bounds is not meaningful in the original dtype.
    """
    # The stale pre-change `return torch.clip(x, min=x_min, max=x_max)`
    # (diff leftover) returned before the dtype handling; removed.
    x = convert_to_tensor(x)
    x_min = convert_to_tensor(x_min)
    x_max = convert_to_tensor(x_max)
    dtype = standardize_dtype(x.dtype)
    if dtype == "bool":
        dtype = "int64"
    return cast(torch.clip(x, min=x_min, max=x_max), dtype=dtype)


def concatenate(xs, axis=0):
Expand Down Expand Up @@ -409,7 +420,11 @@ def digitize(x, bins):


def dot(x, y):
    """Dot product of `x` and `y`, computed in their promoted result dtype.

    Scalar (0-d) operands fall back to element-wise multiply, since
    `torch.matmul` rejects 0-d tensors.
    """
    # The stale pre-change `x, y = convert_to_tensor(x), convert_to_tensor(y)`
    # (diff leftover) duplicated the conversions below; removed.
    x = convert_to_tensor(x)
    y = convert_to_tensor(y)
    result_dtype = dtypes.result_type(x.dtype, y.dtype)
    x = cast(x, result_dtype)
    y = cast(y, result_dtype)
    if x.ndim == 0 or y.ndim == 0:
        # torch.matmul does not support 0-d operands.
        return torch.multiply(x, y)
    return torch.matmul(x, y)
Expand Down
23 changes: 17 additions & 6 deletions keras/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,11 @@ def call(self, x):
return backend.numpy.ceil(x)

def compute_output_spec(self, x):
    """Output spec for Ceil: same shape, dtype promoted to float.

    Mirrors the backend rule: int64 maps to the configured floatx,
    everything else to `result_type(x.dtype, float)`.
    """
    # The stale pre-change `return KerasTensor(x.shape, dtype=x.dtype)`
    # (diff leftover) returned before the dtype handling; removed.
    if backend.standardize_dtype(x.dtype) == "int64":
        dtype = backend.floatx()
    else:
        dtype = dtypes.result_type(x.dtype, float)
    return KerasTensor(x.shape, dtype=dtype)


@keras_export(["keras.ops.ceil", "keras.ops.numpy.ceil"])
Expand All @@ -1357,7 +1361,7 @@ def ceil(x):
x: Input tensor.

Returns:
The ceiling of each element in `x`.
The ceiling of each element in `x`, with float dtype.
"""
if any_symbolic_tensors((x,)):
return Ceil().symbolic_call(x)
Expand All @@ -1374,7 +1378,10 @@ def call(self, x):
return backend.numpy.clip(x, self.x_min, self.x_max)

def compute_output_spec(self, x):
    """Output spec for Clip: same shape; bool inputs promote to int64.

    Mirrors the backend rule that clipping a boolean tensor yields int64.
    """
    # The stale pre-change `return KerasTensor(x.shape, dtype=x.dtype)`
    # (diff leftover) returned before the dtype handling; removed.
    dtype = backend.standardize_dtype(x.dtype)
    if dtype == "bool":
        dtype = "int64"
    return KerasTensor(x.shape, dtype=dtype)


@keras_export(["keras.ops.clip", "keras.ops.numpy.clip"])
Expand Down Expand Up @@ -2001,18 +2008,22 @@ def call(self, x1, x2):
def compute_output_spec(self, x1, x2):
x1_shape = list(getattr(x1, "shape", []))
x2_shape = list(getattr(x2, "shape", []))
dtype = dtypes.result_type(
getattr(x1, "dtype", type(x1)),
getattr(x2, "dtype", type(x2)),
)
if x1_shape == [] or x2_shape == []:
return multiply(x1, x2)
if len(x1_shape) == 1 and len(x2_shape) == 1:
return KerasTensor([], dtype=x1.dtype)
return KerasTensor([], dtype=dtype)
if len(x2_shape) == 1:
if x1_shape[-1] != x2_shape[0]:
raise ValueError(
"Shape must match on the last axis of `x1` and `x2` when "
"`x1` is N-d array while `x2` is 1-D, but receive shape "
f"`x1.shape={x1.shape}` and x2.shape=`{x2.shape}`."
)
return KerasTensor(x1_shape[:-1], dtype=x1.dtype)
return KerasTensor(x1_shape[:-1], dtype=dtype)

if (
x1_shape[-1] is None
Expand All @@ -2021,7 +2032,7 @@ def compute_output_spec(self, x1, x2):
):
del x1_shape[-1]
del x2_shape[-2]
return KerasTensor(x1_shape + x2_shape, dtype=x1.dtype)
return KerasTensor(x1_shape + x2_shape, dtype=dtype)

raise ValueError(
"Shape must match on the last axis of `x1` and second last "
Expand Down
Loading
Loading