diff --git a/conftest.py b/conftest.py
index d139e887530..2fcc51ec006 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,3 +1,9 @@
+import os
+
+# When using jax.experimental.enable_x64 in unit tests, we want to keep the
+# default dtype at 32 bits, aligning it with Keras's default.
+os.environ["JAX_DEFAULT_DTYPE_BITS"] = "32"
+
 try:
     # When using torch and tensorflow, torch needs to be imported first,
     # otherwise it will segfault upon import. This should force the torch
@@ -6,9 +12,9 @@
 except ImportError:
     pass
 
-import pytest
+import pytest  # noqa: E402
 
-from keras.backend import backend
+from keras.backend import backend  # noqa: E402
 
 
 def pytest_configure(config):
diff --git a/keras/backend/__init__.py b/keras/backend/__init__.py
index c2a155baa09..80f3fee09e6 100644
--- a/keras/backend/__init__.py
+++ b/keras/backend/__init__.py
@@ -6,6 +6,7 @@
 # upon import.
 import torch
 
+from keras.backend.common.dtypes import result_type
 from keras.backend.common.keras_tensor import KerasTensor
 from keras.backend.common.keras_tensor import any_symbolic_tensors
 from keras.backend.common.keras_tensor import is_keras_tensor
diff --git a/keras/backend/common/__init__.py b/keras/backend/common/__init__.py
index 858e41cc445..a29e86ce15a 100644
--- a/keras/backend/common/__init__.py
+++ b/keras/backend/common/__init__.py
@@ -1,4 +1,5 @@
 from keras.backend.common import backend_utils
+from keras.backend.common.dtypes import result_type
 from keras.backend.common.variables import AutocastScope
 from keras.backend.common.variables import KerasVariable
 from keras.backend.common.variables import get_autocast_scope
diff --git a/keras/backend/common/dtypes.py b/keras/backend/common/dtypes.py
new file mode 100644
index 00000000000..b5ee2d37b1d
--- /dev/null
+++ b/keras/backend/common/dtypes.py
@@ -0,0 +1,277 @@
+import functools
+
+from keras import backend
+from keras.api_export import keras_export
+from keras.backend.common.variables import ALLOWED_DTYPES
+from keras.backend.common.variables import standardize_dtype
+
+"""
+We adapted the type promotion lattice from JAX. Ref:
+https://github.com/google/jax/blob/main/jax/_src/dtypes.py
+"""
+
+BOOL_TYPES = ["bool"]
+INT_TYPES = [
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+]
+FLOAT_TYPES = ["bfloat16", "float16", "float32", "float64"]
+WEAK_TYPES = ["int", "float"]
+
+
+def _type_promotion_lattice():
+    """
+    Return the type promotion lattice in the form of a DAG.
+    This DAG maps each type to its immediately higher type on the lattice.
+    """
+    (b1,) = BOOL_TYPES
+    (u1, u2, u4, u8, i1, i2, i4, i8) = INT_TYPES
+    bf, f2, f4, f8 = FLOAT_TYPES
+    i_, f_ = WEAK_TYPES
+    out = {
+        b1: [i_],
+        u1: [i2, u2],
+        u2: [i4, u4],
+        u4: [i8, u8],
+        u8: [f_],
+        i_: [u1, i1],
+        i1: [i2],
+        i2: [i4],
+        i4: [i8],
+        i8: [f_],
+        f_: [bf, f2],
+        bf: [f4],
+        f2: [f4],
+        f4: [f8],
+        f8: [],
+    }
+    return out
+
+
+def _make_lattice_upper_bounds():
+    lattice = _type_promotion_lattice()
+    upper_bounds = {node: {node} for node in lattice}
+    for n in lattice:
+        while True:
+            new_upper_bounds = set().union(
+                *(lattice[b] for b in upper_bounds[n])
+            )
+            if n in new_upper_bounds:
+                raise ValueError(
+                    f"cycle detected in type promotion lattice for node {n}"
+                )
+            if new_upper_bounds.issubset(upper_bounds[n]):
+                break
+            upper_bounds[n] |= new_upper_bounds
+    return upper_bounds
+
+
+LATTICE_UPPER_BOUNDS = _make_lattice_upper_bounds()
+
+
+@functools.lru_cache(512)
+def _least_upper_bound(*nodes):
+    """Compute the least upper bound of a set of nodes.
+
+    Args:
+        nodes: sequence of entries from dtypes + weak_types
+
+    Returns:
+        The type representing the least upper bound of the input nodes on the
+        promotion lattice.
+    """
+    # This function computes the least upper bound of a set of nodes N within
+    # a partially ordered set defined by the lattice generated above.
+    # Given a partially ordered set S, let the set of upper bounds of n ∈ S be
+    #   UB(n) ≡ {m ∈ S | n ≤ m}
+    # Further, for a set of nodes N ⊆ S, let the set of common upper bounds be
+    # given by
+    #   CUB(N) ≡ {a ∈ S | ∀ b ∈ N: a ∈ UB(b)}
+    # Then the least upper bound of N is defined as
+    #   LUB(N) ≡ {c ∈ CUB(N) | ∀ d ∈ CUB(N), c ≤ d}
+    # The definition of an upper bound implies that
+    #   c ≤ d if and only if d ∈ UB(c),
+    # so the LUB can be expressed:
+    #   LUB(N) = {c ∈ CUB(N) | ∀ d ∈ CUB(N): d ∈ UB(c)}
+    # or, equivalently:
+    #   LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)}
+    # By definition, LUB(N) has a cardinality of 1 for a partially ordered
+    # set.
+    # Note a potential algorithmic shortcut: from the definition of CUB(N),
+    # we have
+    #   ∀ c ∈ N: CUB(N) ⊆ UB(c)
+    # So if N ∩ CUB(N) is nonempty, it follows that LUB(N) = N ∩ CUB(N).
+    N = set(nodes)
+    UB = LATTICE_UPPER_BOUNDS
+    try:
+        bounds = [UB[n] for n in N]
+    except KeyError:
+        dtype = next(n for n in N if n not in UB)
+        raise ValueError(
+            f"{dtype=} is not a valid dtype for Keras type promotion."
+        )
+    CUB = set.intersection(*bounds)
+    LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
+    if len(LUB) == 1:
+        return LUB.pop()
+    elif len(LUB) == 0:
+        msg = (
+            f"Input dtypes {tuple(str(n) for n in nodes)} have no available "
+            "implicit dtype promotion path. Try explicitly casting inputs to "
+            "the desired output type."
+        )
+        raise ValueError(msg)
+    else:
+        # If we get here, it means the lattice is ill-formed.
+        raise ValueError(
+            f"Internal Type Promotion error: {nodes} do not have a unique "
+            f"least upper bound on the specified lattice; options are {LUB}. "
+            "This is an unexpected error in Keras's internal logic; "
+            "please report it to the maintainers."
+        )
+
+
+def _dtype_and_weaktype(value):
+    """Return a (dtype, weak_type) tuple for the given input."""
+    is_weak_type = False
+    if value is int or value is float:
+        # Note that we can't use `value in [int, float]` because the dtype
+        # might be equal to python scalar types.
+        # e.g., tf.float32 == float is True
+        is_weak_type = True
+    return standardize_dtype(value), is_weak_type
+
+
+@functools.lru_cache(maxsize=None)
+def _respect_weak_type(dtype, weak_type):
+    """Return the weak dtype of `dtype` if `weak_type==True`."""
+    if weak_type:
+        if dtype == "bool":
+            return dtype
+        elif "float" in dtype:
+            return "float"
+        elif "int" in dtype:
+            return "int"
+        else:
+            raise ValueError(
+                "Invalid value for argument `dtype`. Expected one of "
+                f"{ALLOWED_DTYPES}. Received: dtype={dtype}"
+            )
+    return dtype
+
+
+@functools.lru_cache(maxsize=None)
+def _resolve_weak_type(dtype, precision="32"):
+    """Resolve weak type by the precision of `backend.floatx()`."""
+    extended_allowed_dtypes = ALLOWED_DTYPES.union(WEAK_TYPES)
+    if dtype not in extended_allowed_dtypes:
+        raise ValueError(
+            "Invalid value for argument `dtype`. Expected one of "
+            f"{extended_allowed_dtypes}. Received: dtype={dtype}"
+        )
+    if precision not in ["16", "32", "64"]:
+        raise ValueError(
+            f"Invalid value for argument `precision`. Expected one of "
+            f"('16', '32', '64'). Received: precision={precision}"
+        )
+    if dtype == "bfloat16":  # special case for bfloat16
+        dtype_indicator = "f"
+    else:
+        dtype_indicator = dtype[:1]
+
+    if dtype_indicator == "b":
+        return "bool"
+    elif dtype_indicator == "i":
+        return "int" + precision
+    elif dtype_indicator == "u":
+        return "uint" + precision
+    else:
+        return "float" + precision
+
+
+BIT64_TO_BIT16_DTYPE = {
+    "int32": "int16",
+    "int64": "int16",
+    "uint32": "uint16",
+    "uint64": "uint16",
+    "float32": "float16",
+    "float64": "float16",
+}
+BIT64_TO_BIT32_DTYPE = {
+    "int64": "int32",
+    "uint64": "uint32",
+    "float64": "float32",
+}
+
+
+def _lattice_result_type(*args):
+    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
+    if len(dtypes) == 1:
+        out_dtype = dtypes[0]
+        out_weak_type = weak_types[0]
+    elif len(set(dtypes)) == 1 and not all(weak_types):
+        # Trivial promotion case. This allows extended dtypes through.
+        out_dtype = dtypes[0]
+        out_weak_type = False
+    elif all(weak_types):
+        # If all inputs are weakly typed, we compute the bound of the
+        # strongly-typed counterparts and apply the weak type at the end. This
+        # avoids returning the incorrect result with non-canonical weak types
+        # (e.g. weak int16).
+        out_dtype = _least_upper_bound(
+            *{_respect_weak_type(d, False) for d in dtypes}
+        )
+        out_weak_type = True
+    else:
+        out_dtype = _least_upper_bound(
+            *{_respect_weak_type(d, w) for d, w in zip(dtypes, weak_types)}
+        )
+        out_weak_type = any(out_dtype is t for t in WEAK_TYPES)
+
+    out_weak_type = (out_dtype != "bool") and out_weak_type
+    precision = backend.floatx()[-2:]
+    if out_weak_type:
+        out_dtype = _resolve_weak_type(out_dtype, precision=precision)
+    return out_dtype
+
+
+@keras_export("keras.backend.result_type")
+def result_type(*dtypes):
+    """Returns the type from applying the Keras type promotion rules.
+
+    In general, each argument is first parsed by `backend.standardize_dtype`,
+    and the resulting dtype is determined by the least upper bound of the type
+    promotion lattice.
+
+    Note: This function attempts to match the result of `jnp.result_type`.
+
+    Args:
+        dtypes: Input dtypes.
+
+    Returns:
+        The result dtype.
+
+    Examples:
+
+    >>> x = keras.ops.ones((1,), dtype="bfloat16")
+    >>> keras.backend.result_type(x.dtype, int)
+    "bfloat16"
+
+    >>> x = keras.ops.ones((1,), dtype="int32")
+    >>> y = keras.ops.ones((1,), dtype="float32")
+    >>> keras.backend.result_type(x.dtype, y.dtype)
+    "float32"
+    """
+    if len(dtypes) == 0:
+        raise ValueError(
+            "Invalid `dtypes`. At least one dtype is required. "
+            f"Received: dtypes={dtypes}"
+        )
+    return _lattice_result_type(
+        *(backend.floatx() if arg is None else arg for arg in dtypes),
+    )
diff --git a/keras/backend/common/dtypes_test.py b/keras/backend/common/dtypes_test.py
new file mode 100644
index 00000000000..f3b791ccde8
--- /dev/null
+++ b/keras/backend/common/dtypes_test.py
@@ -0,0 +1,64 @@
+from absl.testing import parameterized
+
+from keras import backend
+from keras import ops
+from keras.backend.common.variables import ALLOWED_DTYPES
+from keras.backend.torch.core import to_torch_dtype
+from keras.testing import test_case
+
+
+class DtypesTest(test_case.TestCase, parameterized.TestCase):
+    """Test the dtype to verify that the behavior matches JAX."""
+
+    if backend.backend() == "torch":
+        # TODO: torch doesn't support uint64.
+        ALL_DTYPES = [
+            str(to_torch_dtype(x)).split(".")[-1]
+            for x in ALLOWED_DTYPES
+            if x not in ["string", "uint64"]
+        ] + [None]
+    else:
+        ALL_DTYPES = [x for x in ALLOWED_DTYPES if x != "string"] + [None]
+
+    def setUp(self):
+        from jax.experimental import enable_x64
+
+        self.jax_enable_x64 = enable_x64()
+        self.jax_enable_x64.__enter__()
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        self.jax_enable_x64.__exit__(None, None, None)
+        return super().tearDown()
+
+    @parameterized.product(dtype1=ALL_DTYPES, dtype2=[bool, int, float])
+    def test_result_type_with_python_scalar_types(self, dtype1, dtype2):
+        import jax.numpy as jnp
+
+        out = backend.result_type(dtype1, dtype2)
+        expected = jnp.result_type(dtype1, dtype2).name
+        self.assertEqual(out, expected)
+
+    @parameterized.product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES)
+    def test_result_type_with_tensor(self, dtype1, dtype2):
+        import jax.numpy as jnp
+
+        x1 = ops.ones((1,), dtype=dtype1)
+        x2 = ops.ones((1,), dtype=dtype2)
+        x1_jax = jnp.ones((1,), dtype=dtype1)
+        x2_jax = jnp.ones((1,), dtype=dtype2)
+
+        out = backend.result_type(x1.dtype, x2.dtype)
+        expected = jnp.result_type(x1_jax, x2_jax).name
+        self.assertEqual(out, expected)
+
+    def test_result_type_with_none(self):
+        import jax.numpy as jnp
+
+        self.assertEqual(backend.result_type(None), jnp.result_type(None).name)
+
+    def test_result_type_invalid_dtypes(self):
+        with self.assertRaisesRegexp(
+            ValueError, "Invalid `dtypes`. At least one dtype is required."
+        ):
+            backend.result_type()
diff --git a/keras/backend/jax/numpy.py b/keras/backend/jax/numpy.py
index 75741a70cf4..715213db2c7 100644
--- a/keras/backend/jax/numpy.py
+++ b/keras/backend/jax/numpy.py
@@ -1,6 +1,7 @@
 import jax.numpy as jnp
 
 from keras.backend import config
+from keras.backend.common import dtypes
 from keras.backend.jax.core import cast
 from keras.backend.jax.core import convert_to_tensor
 
@@ -70,11 +71,13 @@ def max(x, axis=None, keepdims=False, initial=None):
     return jnp.max(x, axis=axis, keepdims=keepdims, initial=initial)
 
 
-def ones(shape, dtype="float32"):
+def ones(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.ones(shape, dtype=dtype)
 
 
-def zeros(shape, dtype="float32"):
+def zeros(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.zeros(shape, dtype=dtype)
 
 
@@ -114,12 +117,13 @@ def append(
 
 def arange(start, stop=None, step=1, dtype=None):
     if dtype is None:
-        if hasattr(start, "dtype"):
-            dtype = start.dtype
-        elif isinstance(start, int):
-            dtype = "int32"
-        else:
-            dtype = config.floatx()
+        dtypes_to_resolve = [
+            getattr(start, "dtype", type(start)),
+            getattr(step, "dtype", type(step)),
+        ]
+        if stop is not None:
+            dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        dtype = dtypes.result_type(*dtypes_to_resolve)
     return jnp.arange(start, stop, step=step, dtype=dtype)
 
 
@@ -253,7 +257,8 @@ def dot(x, y):
     return jnp.dot(x, y)
 
 
-def empty(shape, dtype="float32"):
+def empty(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.empty(shape, dtype=dtype)
 
 
@@ -307,7 +312,8 @@ def hstack(xs):
     return jnp.hstack(xs)
 
 
-def identity(n, dtype="float32"):
+def identity(n, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.identity(n, dtype=dtype)
 
 
@@ -573,7 +579,8 @@ def trace(x, offset=0, axis1=0, axis2=1):
     return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
 
 
-def tri(N, M=None, k=0, dtype="float32"):
+def tri(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.tri(N, M=M, k=k, dtype=dtype)
 
 
@@ -652,7 +659,8 @@ def sum(x, axis=None, keepdims=False):
     return jnp.sum(x, axis=axis, keepdims=keepdims)
 
 
-def eye(N, M=None, k=0, dtype="float32"):
+def eye(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return jnp.eye(N, M=M, k=k, dtype=dtype)
 
 
diff --git a/keras/backend/numpy/numpy.py b/keras/backend/numpy/numpy.py
index 32ae80f0166..c34accebc0d 100644
--- a/keras/backend/numpy/numpy.py
+++ b/keras/backend/numpy/numpy.py
@@ -2,9 +2,16 @@
 
 from keras.backend import config
 from keras.backend import standardize_dtype
+from keras.backend.common import dtypes
+from keras.backend.numpy.core import convert_to_tensor
 
 
 def add(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    x1 = x1.astype(dtype)
+    x2 = x2.astype(dtype)
     return np.add(x1, x2)
 
 
@@ -34,11 +41,13 @@ def max(x, axis=None, keepdims=False, initial=None):
     return np.max(x, axis=axis, keepdims=keepdims, initial=initial)
 
 
-def ones(shape, dtype="float32"):
+def ones(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return np.ones(shape, dtype=dtype)
 
 
-def zeros(shape, dtype="float32"):
+def zeros(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return np.zeros(shape, dtype=dtype)
 
 
@@ -81,12 +90,13 @@ def append(
 
 def arange(start, stop=None, step=None, dtype=None):
     if dtype is None:
-        if hasattr(start, "dtype"):
-            dtype = start.dtype
-        elif isinstance(start, int):
-            dtype = "int32"
-        else:
-            dtype = config.floatx()
+        dtypes_to_resolve = [
+            getattr(start, "dtype", type(start)),
+            getattr(step, "dtype", type(step)),
+        ]
+        if stop is not None:
+            dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        dtype = dtypes.result_type(*dtypes_to_resolve)
     return np.arange(start, stop, step=step, dtype=dtype)
 
 
@@ -134,7 +144,6 @@ def argsort(x, axis=-1):
 
 
 def array(x, dtype=None):
-    dtype = dtype or config.floatx()
     return np.array(x, dtype=dtype)
 
 
@@ -251,7 +260,8 @@ def dot(x, y):
     return np.dot(x, y)
 
 
-def empty(shape, dtype="float32"):
+def empty(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return np.empty(shape, dtype=dtype)
 
 
@@ -302,7 +312,8 @@ def hstack(xs):
     return np.hstack(xs)
 
 
-def identity(n, dtype="float32"):
+def identity(n, dtype=None):
+    dtype = dtype or config.floatx()
     return np.identity(n, dtype=dtype)
 
 
@@ -556,7 +567,8 @@ def trace(x, offset=0, axis1=0, axis2=1):
     return np.trace(x, offset=offset, axis1=axis1, axis2=axis2)
 
 
-def tri(N, M=None, k=0, dtype="float32"):
+def tri(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return np.tri(N, M=M, k=k, dtype=dtype)
 
 
@@ -604,10 +616,13 @@ def square(x):
 
 
 def sqrt(x):
-    dtype = None
-    if hasattr(x, "dtype"):
-        if standardize_dtype(x.dtype).startswith("int"):
-            dtype = config.floatx()
+    x = convert_to_tensor(x)
+    # upcast to float64 for int64 which matches JAX's behavior
+    dtype = (
+        "float64"
+        if standardize_dtype(x.dtype) == "int64"
+        else dtypes.result_type(x.dtype, float)
+    )
     return np.sqrt(x, dtype=dtype)
 
 
@@ -631,7 +646,8 @@ def sum(x, axis=None, keepdims=False):
     return np.sum(x, axis=axis, keepdims=keepdims)
 
 
-def eye(N, M=None, k=0, dtype="float32"):
+def eye(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return np.eye(N, M=M, k=k, dtype=dtype)
 
 
diff --git a/keras/backend/tensorflow/numpy.py b/keras/backend/tensorflow/numpy.py
index e493f15036c..7ed2914773f 100644
--- a/keras/backend/tensorflow/numpy.py
+++ b/keras/backend/tensorflow/numpy.py
@@ -8,10 +8,17 @@
 from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
 
 from keras.backend import config
+from keras.backend import standardize_dtype
+from keras.backend.common import dtypes
 from keras.backend.tensorflow.core import convert_to_tensor
 
 
 def add(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, dtype)
     if isinstance(x1, tf.SparseTensor) or isinstance(x2, tf.SparseTensor):
         return tf.sparse.add(x1, x2)
     return tfnp.add(x1, x2)
 
 
@@ -196,11 +203,13 @@ def max(x, axis=None, keepdims=False, initial=None):
     return tfnp.max(x, axis=axis, keepdims=keepdims)
 
 
-def ones(shape, dtype="float32"):
+def ones(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return tf.ones(shape, dtype=dtype)
 
 
-def zeros(shape, dtype="float32"):
+def zeros(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return tf.zeros(shape, dtype=dtype)
 
 
@@ -240,12 +249,13 @@ def arange(start, stop=None, step=1, dtype=None):
     # tfnp.arange has trouble with dynamic Tensors in compiled function.
     # tf.range does not.
     if dtype is None:
-        if hasattr(start, "dtype"):
-            dtype = start.dtype
-        elif isinstance(start, int):
-            dtype = "int32"
-        else:
-            dtype = config.floatx()
+        dtypes_to_resolve = [
+            getattr(start, "dtype", type(start)),
+            getattr(step, "dtype", type(step)),
+        ]
+        if stop is not None:
+            dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        dtype = dtypes.result_type(*dtypes_to_resolve)
     return tf.range(start, stop, delta=step, dtype=dtype)
 
 
@@ -403,7 +413,8 @@ def dot(x, y):
     return tfnp.dot(x, y)
 
 
-def empty(shape, dtype="float32"):
+def empty(shape, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.empty(shape, dtype=dtype)
 
 
@@ -453,7 +464,8 @@ def hstack(xs):
     return tfnp.hstack(xs)
 
 
-def identity(n, dtype="float32"):
+def identity(n, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.identity(n, dtype=dtype)
 
 
@@ -776,7 +788,8 @@ def trace(x, offset=0, axis1=0, axis2=1):
     return tfnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
 
 
-def tri(N, M=None, k=0, dtype="float32"):
+def tri(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.tri(N, M=M, k=k, dtype=dtype)
 
 
@@ -822,9 +835,16 @@ def square(x):
 
 def sqrt(x):
     x = convert_to_tensor(x)
-    if tf.as_dtype(x.dtype).is_integer:
-        x = tf.cast(x, dtype=config.floatx())
-    return tfnp.sqrt(x)
+    # upcast to float64 for int64 which matches JAX's behavior
+    dtype = (
+        "float64"
+        if standardize_dtype(x.dtype) == "int64"
+        else dtypes.result_type(x.dtype, float)
+    )
+    x = tf.cast(x, dtype)
+    # TODO: Use tfnp.sqrt. Currently, tfnp.sqrt will aggressively upcast to
+    # float64 if the input is bfloat16. This behavior does not match JAX.
+    return tf.sqrt(x)
 
 
 def squeeze(x, axis=None):
@@ -863,7 +883,8 @@ def sum(x, axis=None, keepdims=False):
     return tfnp.sum(x, axis=axis, keepdims=keepdims)
 
 
-def eye(N, M=None, k=0, dtype="float32"):
+def eye(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
     return tfnp.eye(N, M=M, k=k, dtype=dtype)
 
 
diff --git a/keras/backend/torch/numpy.py b/keras/backend/torch/numpy.py
index cf2322b60c8..149303d5d54 100644
--- a/keras/backend/torch/numpy.py
+++ b/keras/backend/torch/numpy.py
@@ -2,6 +2,7 @@
 import torch
 
 from keras.backend import config
+from keras.backend.common import dtypes
 from keras.backend.torch.core import cast
 from keras.backend.torch.core import convert_to_tensor
 from keras.backend.torch.core import get_device
@@ -17,7 +18,8 @@
 
 
 def add(x1, x2):
-    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
     return torch.add(x1, x2)
 
 
@@ -76,15 +78,15 @@ def max(x, axis=None, keepdims=False, initial=None):
     return result
 
 
-def ones(shape, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def ones(shape, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     if isinstance(shape, int):
         shape = (shape,)
     return torch.ones(size=shape, dtype=dtype, device=get_device())
 
 
-def zeros(shape, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def zeros(shape, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     if isinstance(shape, int):
         shape = (shape,)
     return torch.zeros(size=shape, dtype=dtype, device=get_device())
@@ -162,12 +164,13 @@ def append(
 
 def arange(start, stop=None, step=1, dtype=None):
     if dtype is None:
-        if hasattr(start, "dtype"):
-            dtype = start.dtype
-        elif isinstance(start, int):
-            dtype = "int32"
-        else:
-            dtype = config.floatx()
+        dtypes_to_resolve = [
+            getattr(start, "dtype", type(start)),
+            getattr(step, "dtype", type(step)),
+        ]
+        if stop is not None:
+            dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = to_torch_dtype(dtype)
     if stop is None:
         return torch.arange(end=start, dtype=dtype, device=get_device())
@@ -386,8 +389,8 @@ def dot(x, y):
     return torch.matmul(x, y)
 
 
-def empty(shape, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def empty(shape, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     return torch.empty(size=shape, dtype=dtype, device=get_device())
 
 
@@ -457,9 +460,9 @@ def hstack(xs):
     return torch.hstack(xs)
 
 
-def identity(n, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
-    return torch.eye(n, dtype=dtype)
+def identity(n, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
+    return torch.eye(n, dtype=dtype, device=get_device())
 
 
 def imag(x):
@@ -933,8 +936,8 @@ def trace(x, offset=None, axis1=None, axis2=None):
     return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1)
 
 
-def tri(N, M=None, k=0, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def tri(N, M=None, k=0, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     M = M or N
     x = torch.ones((N, M), dtype=dtype, device=get_device())
     return torch.tril(x, diagonal=k)
@@ -997,6 +1000,9 @@ def square(x):
 
 def sqrt(x):
     x = convert_to_tensor(x)
+    # upcast to float64 for int64 which matches JAX's behavior
+    if x.dtype == torch.int64:
+        x = cast(x, "float64")
     return torch.sqrt(x)
 
 
@@ -1037,8 +1043,8 @@ def sum(x, axis=None, keepdims=False):
     return torch.sum(x)
 
 
-def eye(N, M=None, k=None, dtype="float32"):
-    dtype = to_torch_dtype(dtype)
+def eye(N, M=None, k=None, dtype=None):
+    dtype = to_torch_dtype(dtype or config.floatx())
     M = N if M is None else M
     k = 0 if k is None else k
     if k == 0:
diff --git a/keras/layers/rnn/dropout_rnn_cell_test.py b/keras/layers/rnn/dropout_rnn_cell_test.py
index 927fd627c62..ab467f77e04 100644
--- a/keras/layers/rnn/dropout_rnn_cell_test.py
+++ b/keras/layers/rnn/dropout_rnn_cell_test.py
@@ -64,4 +64,30 @@ def test_basics(self):
             expected_num_non_trainable_weights=0,
             expected_num_non_trainable_variables=1,
             supports_masking=True,
+            run_mixed_precision_check=False,
         )
+
+        # manually set dtype to mixed_float16 to run mixed precision check
+        run_mixed_precision_check = True
+        if backend.backend() == "torch":
+            import torch
+
+            run_mixed_precision_check = torch.cuda.is_available()
+        if run_mixed_precision_check:
+            self.run_layer_test(
+                layers.RNN,
+                init_kwargs={
+                    "cell": RNNCellWithDropout(
+                        5, seed=1337, dtype="mixed_float16"
+                    ),
+                    "dtype": "mixed_float16",
+                },
+                input_shape=(3, 2, 4),
+                call_kwargs={"training": True},
+                expected_output_shape=(3, 5),
+                expected_num_trainable_weights=2,
+                expected_num_non_trainable_weights=0,
+                expected_num_non_trainable_variables=1,
+                supports_masking=True,
+                run_mixed_precision_check=False,
+            )
diff --git a/keras/ops/numpy.py b/keras/ops/numpy.py
index d1c0e994c08..88afdce912e 100644
--- a/keras/ops/numpy.py
+++ b/keras/ops/numpy.py
@@ -147,6 +147,7 @@
 from keras.api_export import keras_export
 from keras.backend import KerasTensor
 from keras.backend import any_symbolic_tensors
+from keras.backend.common import dtypes
 from keras.ops import operation_utils
 from keras.ops.operation import Operation
 from keras.ops.operation_utils import reduce_shape
@@ -303,10 +304,16 @@ def compute_output_spec(self, x1, x2):
         x1_shape = getattr(x1, "shape", [])
         x2_shape = getattr(x2, "shape", [])
         output_shape = broadcast_shapes(x1_shape, x2_shape)
+        output_dtype = dtypes.result_type(
+            getattr(x1, "dtype", type(x1)),
+            getattr(x2, "dtype", type(x2)),
+        )
         x1_sparse = getattr(x1, "sparse", True)
         x2_sparse = getattr(x2, "sparse", True)
         output_sparse = x1_sparse and x2_sparse
-        return KerasTensor(output_shape, dtype=x1.dtype, sparse=output_sparse)
+        return KerasTensor(
+            output_shape, dtype=output_dtype, sparse=output_sparse
+        )
 
 
 @keras_export(["keras.ops.add", "keras.ops.numpy.add"])
@@ -662,7 +669,15 @@ def call(self, start, stop=None, step=1, dtype=None):
     def compute_output_spec(self, start, stop=None, step=1, dtype=None):
         if stop is None:
             start, stop = 0, start
-        output_shape = [np.ceil((stop - start) / step).astype(int)]
+        output_shape = [int(np.ceil((stop - start) / step))]
+        if dtype is None:
+            dtypes_to_resolve = [
+                getattr(start, "dtype", type(start)),
+                getattr(step, "dtype", type(step)),
+            ]
+            if stop is not None:
+                dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+            dtype = dtypes.result_type(*dtypes_to_resolve)
         return KerasTensor(output_shape, dtype=dtype)
 
 
@@ -2302,15 +2317,16 @@ def einsum(subscripts, *operands):
 
 
 class Empty(Operation):
-    def call(self, shape, dtype="float32"):
+    def call(self, shape, dtype=None):
         return backend.numpy.empty(shape, dtype=dtype)
 
-    def compute_output_spec(self, shape, dtype="float32"):
+    def compute_output_spec(self, shape, dtype=None):
+        dtype = dtype or backend.floatx()
         return KerasTensor(shape, dtype=dtype)
 
 
 @keras_export(["keras.ops.empty", "keras.ops.numpy.empty"])
-def empty(shape, dtype="float32"):
+def empty(shape, dtype=None):
     """Return a tensor of given shape and type filled with uninitialized data.
 
     Args:
@@ -2728,15 +2744,16 @@ def hstack(xs):
 
 
 class Identity(Operation):
-    def call(self, n, dtype="float32"):
+    def call(self, n, dtype=None):
         return backend.numpy.identity(n, dtype=dtype)
 
-    def compute_output_spec(self, n, dtype="float32"):
+    def compute_output_spec(self, n, dtype=None):
+        dtype = dtype or backend.floatx()
         return KerasTensor([n, n], dtype=dtype)
 
 
 @keras_export(["keras.ops.identity", "keras.ops.numpy.identity"])
-def identity(n, dtype="float32"):
+def identity(n, dtype=None):
     """Return the identity tensor.
 
     The identity tensor is a square tensor with ones on the main diagonal and
@@ -4879,17 +4896,18 @@ def trace(x, offset=0, axis1=0, axis2=1):
 
 
 class Tri(Operation):
-    def call(self, N, M=None, k=0, dtype="float32"):
+    def call(self, N, M=None, k=0, dtype=None):
         return backend.numpy.tri(N, M=M, k=k, dtype=dtype)
 
-    def compute_output_spec(self, N, M=None, k=0, dtype="float32"):
+    def compute_output_spec(self, N, M=None, k=0, dtype=None):
         if M is None:
             M = N
+        dtype = dtype or backend.floatx()
         return KerasTensor((N, M), dtype=dtype)
 
 
 @keras_export(["keras.ops.tri", "keras.ops.numpy.tri"])
-def tri(N, M=None, k=0, dtype="float32"):
+def tri(N, M=None, k=0, dtype=None):
     """Return a tensor with ones at and below a diagonal and zeros elsewhere.
 
     Args:
@@ -5269,7 +5287,12 @@ def call(self, x):
         return backend.numpy.sqrt(x)
 
     def compute_output_spec(self, x):
-        return KerasTensor(x.shape, dtype=x.dtype)
+        dtype = (
+            "float64"
+            if backend.standardize_dtype(x.dtype) == "int64"
+            else dtypes.result_type(x.dtype, float)
+        )
+        return KerasTensor(x.shape, dtype=dtype)
 
 
 @keras_export(["keras.ops.sqrt", "keras.ops.numpy.sqrt"])
@@ -5481,15 +5504,16 @@ def sum(x, axis=None, keepdims=False):
 
 
 class Zeros(Operation):
-    def call(self, shape, dtype="float32"):
+    def call(self, shape, dtype=None):
         return backend.numpy.zeros(shape, dtype=dtype)
 
-    def compute_output_spec(self, shape, dtype="float32"):
+    def compute_output_spec(self, shape, dtype=None):
+        dtype = dtype or backend.floatx()
         return KerasTensor(shape, dtype=dtype)
 
 
 @keras_export(["keras.ops.zeros", "keras.ops.numpy.zeros"])
-def zeros(shape, dtype="float32"):
+def zeros(shape, dtype=None):
     """Return a new tensor of given shape and type, filled with zeros.
 
     Args:
@@ -5503,15 +5527,16 @@ def zeros(shape, dtype="float32"):
 
 
 class Ones(Operation):
-    def call(self, shape, dtype="float32"):
+    def call(self, shape, dtype=None):
         return backend.numpy.ones(shape, dtype=dtype)
 
-    def compute_output_spec(self, shape, dtype="float32"):
+    def compute_output_spec(self, shape, dtype=None):
+        dtype = dtype or backend.floatx()
        return KerasTensor(shape, dtype=dtype)
 
 
 @keras_export(["keras.ops.ones", "keras.ops.numpy.ones"])
-def ones(shape, dtype="float32"):
+def ones(shape, dtype=None):
     """Return a new tensor of given shape and type, filled with ones.
 
     Args:
@@ -5525,17 +5550,18 @@ def ones(shape, dtype="float32"):
 
 
 class Eye(Operation):
-    def call(self, N, M=None, k=0, dtype="float32"):
+    def call(self, N, M=None, k=0, dtype=None):
         return backend.numpy.eye(N, M=M, k=k, dtype=dtype)
 
-    def compute_output_spec(self, N, M=None, k=0, dtype="float32"):
+    def compute_output_spec(self, N, M=None, k=0, dtype=None):
         if M is None:
             M = N
+        dtype = dtype or backend.floatx()
         return KerasTensor((N, M), dtype=dtype)
 
 
 @keras_export(["keras.ops.eye", "keras.ops.numpy.eye"])
-def eye(N, M=None, k=0, dtype="float32"):
+def eye(N, M=None, k=0, dtype=None):
     """Return a 2-D tensor with ones on the diagonal and zeros elsewhere.
 
     Args:
diff --git a/keras/ops/numpy_test.py b/keras/ops/numpy_test.py
index c917a9dedc0..9c9b59f9798 100644
--- a/keras/ops/numpy_test.py
+++ b/keras/ops/numpy_test.py
@@ -7,6 +7,8 @@
 from keras import testing
 from keras.backend.common import standardize_dtype
 from keras.backend.common.keras_tensor import KerasTensor
+from keras.backend.common.variables import ALLOWED_DTYPES
+from keras.backend.torch.core import to_torch_dtype
 from keras.ops import numpy as knp
 
 # TODO: remove reliance on this (or alternatively, turn it on by default).
@@ -3872,3 +3874,211 @@ def test_tri(self):
         self.assertAllClose(knp.Tri()(3), np.tri(3))
         self.assertAllClose(knp.Tri()(3, 4), np.tri(3, 4))
         self.assertAllClose(knp.Tri()(3, 4, 1), np.tri(3, 4, 1))
+
+
+class NumpyDtypeTest(testing.TestCase, parameterized.TestCase):
+    """Test the dtype to verify that the behavior matches JAX."""
+
+    if backend.backend() == "torch":
+        # TODO: torch doesn't support uint64.
+        ALL_DTYPES = [
+            str(to_torch_dtype(x)).split(".")[-1]
+            for x in ALLOWED_DTYPES
+            if x not in ["string", "uint64"]
+        ] + [None]
+    else:
+        # TODO: Using uint64 will lead to weak type promotion (`float`),
+        # resulting in different behavior between JAX and Keras. Currently, we
+        # are skipping the test for uint64.
+        ALL_DTYPES = [
+            x for x in ALLOWED_DTYPES if x not in ["string", "uint64"]
+        ] + [None]
+
+    def setUp(self):
+        from jax.experimental import enable_x64
+
+        self.jax_enable_x64 = enable_x64()
+        self.jax_enable_x64.__enter__()
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        self.jax_enable_x64.__exit__(None, None, None)
+        return super().tearDown()
+
+    @parameterized.product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES)
+    def test_add(self, dtype1, dtype2):
+        import jax.numpy as jnp
+
+        x1 = knp.ones((1,), dtype=dtype1)
+        x2 = knp.ones((1,), dtype=dtype2)
+        x1_jax = jnp.ones((1,), dtype=dtype1)
+        x2_jax = jnp.ones((1,), dtype=dtype2)
+        self.assertEqual(
+            standardize_dtype(knp.add(x1, x2).dtype),
+            standardize_dtype(jnp.add(x1_jax, x2_jax).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(knp.Add().symbolic_call(x1, x2).dtype),
+            standardize_dtype(jnp.add(x1_jax, x2_jax).dtype),
+        )
+
+    @parameterized.parameters(ALL_DTYPES)
+    def test_ones(self, dtype):
+        import jax.numpy as jnp
+
+        self.assertEqual(
+            standardize_dtype(knp.ones([2, 3], dtype=dtype).dtype),
+            standardize_dtype(jnp.ones([2, 3], dtype=dtype).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(
+                knp.Ones().symbolic_call([2, 3], dtype=dtype).dtype
+            ),
+            standardize_dtype(jnp.ones([2, 3], dtype=dtype).dtype),
+        )
+
+    @parameterized.parameters(ALL_DTYPES)
+    def test_zeros(self, dtype):
+        import jax.numpy as jnp
+
+        self.assertEqual(
+            standardize_dtype(knp.zeros([2, 3], dtype=dtype).dtype),
+            standardize_dtype(jnp.zeros([2, 3], dtype=dtype).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(
+                knp.Zeros().symbolic_call([2, 3], dtype=dtype).dtype
+            ),
+            standardize_dtype(jnp.zeros([2, 3], dtype=dtype).dtype),
+        )
+
+    @parameterized.parameters(
+        (10, None, 1, None),
+        (0, 10, 1, None),
+        (0, 10, 0.5, None),
+        (10.0, None, 1, None),
+        (0, 10.0, 1, None),
+        (0.0, 10, 1, None),
+        (10, None, 1, "float32"),
+        (10, None, 1, "int32"),
+    )
+    def test_arange(self, start, stop, step, dtype):
+        import jax.numpy as jnp
+
+        self.assertEqual(
+            standardize_dtype(knp.arange(start, stop, step, dtype).dtype),
+            standardize_dtype(jnp.arange(start, stop, step, dtype).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(
+                knp.Arange().symbolic_call(start, stop, step, dtype).dtype
+            ),
+            standardize_dtype(jnp.arange(start, stop, step, dtype).dtype),
+        )
+
+    @parameterized.parameters(ALL_DTYPES)
+    def test_empty(self, dtype):
+        import jax.numpy as jnp
+
+        self.assertEqual(
+            standardize_dtype(knp.empty([2, 3], dtype=dtype).dtype),
+            standardize_dtype(jnp.empty([2, 3], dtype=dtype).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(
+                knp.Empty().symbolic_call([2, 3], dtype=dtype).dtype
+            ),
+            standardize_dtype(jnp.empty([2, 3], dtype=dtype).dtype),
+        )
+
+    @parameterized.parameters(ALL_DTYPES)
+    def test_identity(self, dtype):
+        import jax.numpy as jnp
+
+        if backend.backend() == "torch":
+            if dtype == "bfloat16":
+                self.skipTest(
+                    "identity with dtype=bfloat16 is not supported for torch"
+                )
+
+        self.assertEqual(
+            standardize_dtype(knp.identity(3, dtype=dtype).dtype),
+            standardize_dtype(jnp.identity(3, dtype=dtype).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(
+                knp.Identity().symbolic_call(3, dtype=dtype).dtype
+            ),
+            standardize_dtype(jnp.identity(3, dtype=dtype).dtype),
+        )
+
+    @parameterized.parameters(ALL_DTYPES)
+    def test_tri(self, dtype):
+        import jax.numpy as jnp
+
+        if backend.backend() == "torch":
+            if dtype == "bfloat16":
+                self.skipTest(
+                    "tri with dtype=bfloat16 is not supported for torch"
+                )
+
+        self.assertEqual(
+            standardize_dtype(knp.tri(3, dtype=dtype).dtype),
+            standardize_dtype(jnp.tri(3, dtype=dtype).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(knp.Tri().symbolic_call(3, dtype=dtype).dtype),
+            standardize_dtype(jnp.tri(3, dtype=dtype).dtype),
+        )
+
+    @parameterized.parameters(ALL_DTYPES)
+    def test_sqrt(self, dtype):
+        import jax.numpy as jnp
+
+        if backend.backend() == "torch":
+            if dtype == "float16":
+                self.skipTest(
+                    "sqrt with dtype=float16 is not supported for torch"
+                )
+
+        x1 = knp.ones((1,), dtype=dtype)
+        x1_jax = jnp.ones((1,), dtype=dtype)
+
+        self.assertEqual(
+            standardize_dtype(knp.sqrt(x1).dtype),
+            standardize_dtype(jnp.sqrt(x1_jax).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(knp.Sqrt().symbolic_call(x1).dtype),
+            standardize_dtype(jnp.sqrt(x1_jax).dtype),
+        )
+
+    @parameterized.parameters(ALL_DTYPES)
+    def test_eye(self, dtype):
+        import jax.numpy as jnp
+
+        if backend.backend() == "torch":
+            if dtype == "bfloat16":
+                self.skipTest(
+                    "eye with dtype=bfloat16 is not supported for torch"
+                )
+
+        self.assertEqual(
+            standardize_dtype(knp.eye(3, dtype=dtype).dtype),
+            standardize_dtype(jnp.eye(3, dtype=dtype).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(knp.Eye().symbolic_call(3, dtype=dtype).dtype),
+            standardize_dtype(jnp.eye(3, dtype=dtype).dtype),
+        )
+
+        self.assertEqual(
+            standardize_dtype(knp.eye(3, 4, 1, dtype=dtype).dtype),
+            standardize_dtype(jnp.eye(3, 4, 1, dtype=dtype).dtype),
+        )
+        self.assertEqual(
+            standardize_dtype(
+                knp.Eye().symbolic_call(3, 4, 1, dtype=dtype).dtype
+            ),
+            standardize_dtype(jnp.eye(3, 4, 1, dtype=dtype).dtype),
+        )
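
For reference, a minimal sketch of how the promotion rules introduced above behave, based only on the `result_type` docstring and the lattice in keras/backend/common/dtypes.py (assuming the default floatx is "float32"; expected values shown as comments):

    from keras import backend

    # A weak Python scalar type (int/float) does not widen a strongly typed
    # dtype.
    backend.result_type("bfloat16", int)     # -> "bfloat16"

    # Mixing strongly typed int and float dtypes promotes along the lattice.
    backend.result_type("int32", "float32")  # -> "float32"

    # `None` entries fall back to the configured floatx.
    backend.result_type(None)                # -> "float32"

    # Calling with no arguments raises a ValueError.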