Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce backend.result_type #18482

Merged
merged 29 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
26ad3bc
Add `result_dtype` and some refactor of `ops.numpy`
james77777778 Sep 23, 2023
c206043
Fix keras_export
james77777778 Sep 23, 2023
b5a7944
Merge branch 'keras-team:master' into introduce-dtype-inference
james77777778 Sep 25, 2023
fc34bb7
Refactor `result_dtype`
james77777778 Sep 25, 2023
1df20fc
Merge branch 'keras-team:master' into introduce-dtype-inference
james77777778 Sep 26, 2023
1f8e623
Update `result_type`
james77777778 Sep 26, 2023
3890ab0
Revert `ops.numpy.*` changes
james77777778 Sep 26, 2023
19c88a9
Merge branch 'keras-team:master' into introduce-dtype-inference
james77777778 Sep 26, 2023
eb02511
ensure consistent dtype inference
james77777778 Sep 26, 2023
2e86e59
add dtype test
james77777778 Sep 26, 2023
d9a049f
fix dropout rnn test
james77777778 Sep 26, 2023
e2df127
Fix symbolic test
james77777778 Sep 26, 2023
dde0cc0
Fix torch test
james77777778 Sep 26, 2023
25b94a2
keep `"int64"` when using tensorflow
james77777778 Sep 26, 2023
4b0eaba
Fix test
james77777778 Sep 26, 2023
61c958c
Simplify `result_type` for tensorflow
james77777778 Sep 26, 2023
14d8d5e
Merge branch 'keras-team:master' into introduce-dtype-inference
james77777778 Sep 27, 2023
ad76c46
Add `pre_canonicalize` option, rename to `result_type`
james77777778 Sep 27, 2023
53b84dd
Align the behavior of `ops.add`
james77777778 Sep 27, 2023
35ecec4
Fix test
james77777778 Sep 27, 2023
072cc55
Merge branch 'keras-team:master' into introduce-dtype-inference
james77777778 Sep 28, 2023
8ebff53
Match `backend.result_type` to JAX with `JAX_ENABLE_X64=true` and `JA…
james77777778 Sep 28, 2023
2d014f0
Use `dtype or config.floatx()`
james77777778 Sep 28, 2023
4e5f67a
Fix symbolic ops
james77777778 Sep 28, 2023
903dbec
Remove `result_type` in jax and torch
james77777778 Sep 28, 2023
e2883c8
Address comments
james77777778 Sep 28, 2023
3b20fec
Apply `result_type` to `ops.numpy.arange`
james77777778 Sep 28, 2023
42be5d2
Apply `backend.result_type` to `ops.numpy.sqrt`
james77777778 Sep 28, 2023
6bb815f
Skip float16 test for torch's sqrt
james77777778 Sep 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# upon import.
import torch

from keras.backend.common.dtypes import result_dtype
from keras.backend.common.keras_tensor import KerasTensor
from keras.backend.common.keras_tensor import any_symbolic_tensors
from keras.backend.common.keras_tensor import is_keras_tensor
Expand Down
1 change: 1 addition & 0 deletions keras/backend/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from keras.backend.common import backend_utils
from keras.backend.common.dtypes import result_dtype
from keras.backend.common.variables import AutocastScope
from keras.backend.common.variables import KerasVariable
from keras.backend.common.variables import get_autocast_scope
Expand Down
297 changes: 297 additions & 0 deletions keras/backend/common/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
import functools

from keras import backend
from keras.api_export import keras_export
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend.common.variables import standardize_dtype

"""
We adapted the type promotion lattice from JAX. Ref:
https://github.com/google/jax/blob/main/jax/_src/dtypes.py
"""

BOOL_TYPES = ["bool"]
INT_TYPES = [
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
]
FLOAT_TYPES = ["bfloat16", "float16", "float32", "float64"]
WEAK_TYPES = ["int", "float"]


def _type_promotion_lattice():
"""
Return the type promotion lattice in the form of a DAG.
This DAG maps each type to its immediately higher type on the lattice.
"""
(b1,) = BOOL_TYPES
(u1, u2, u4, u8, i1, i2, i4, i8) = INT_TYPES
bf, f2, f4, f8 = FLOAT_TYPES
i_, f_ = WEAK_TYPES
out = {
b1: [i_],
u1: [i2, u2],
u2: [i4, u4],
u4: [i8, u8],
u8: [f_],
i_: [u1, i1],
i1: [i2],
i2: [i4],
i4: [i8],
i8: [f_],
f_: [bf, f2],
bf: [f4],
f2: [f4],
f4: [f8],
f8: [],
}
return out


def _make_lattice_upper_bounds():
lattice = _type_promotion_lattice()
upper_bounds = {node: {node} for node in lattice}
for n in lattice:
while True:
new_upper_bounds = set().union(
*(lattice[b] for b in upper_bounds[n])
)
if n in new_upper_bounds:
raise ValueError(
f"cycle detected in type promotion lattice for node {n}"
)
if new_upper_bounds.issubset(upper_bounds[n]):
break
upper_bounds[n] |= new_upper_bounds
return upper_bounds


LATTICE_UPPER_BOUNDS = _make_lattice_upper_bounds()


@functools.lru_cache(512)
def _least_upper_bound(*nodes):
"""Compute the least upper bound of a set of nodes.

Args:
nodes: sequence of entries from dtypes + weak_types

Returns:
The type representing the least upper bound of the input nodes on the
promotion lattice.
"""
# This function computes the least upper bound of a set of nodes N within a
# partially ordered set defined by the lattice generated above.
# Given a partially ordered set S, let the set of upper bounds of n ∈ S be
# UB(n) ≡ {m ∈ S | n ≤ m}
# Further, for a set of nodes N ⊆ S, let the set of common upper bounds be
# given by
# CUB(N) ≡ {a ∈ S | ∀ b ∈ N: a ∈ UB(b)}
# Then the least upper bound of N is defined as
# LUB(N) ≡ {c ∈ CUB(N) | ∀ d ∈ CUB(N), c ≤ d}
# The definition of an upper bound implies that
# c ≤ d if and only if d ∈ UB(c),
# so the LUB can be expressed:
# LUB(N) = {c ∈ CUB(N) | ∀ d ∈ CUB(N): d ∈ UB(c)}
# or, equivalently:
# LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)}
# By definition, LUB(N) has a cardinality of 1 for a partially ordered set.
# Note a potential algorithmic shortcut: from the definition of CUB(N),
# we have
# ∀ c ∈ N: CUB(N) ⊆ UB(c)
# So if N ∩ CUB(N) is nonempty, if follows that LUB(N) = N ∩ CUB(N).
N = set(nodes)
UB = LATTICE_UPPER_BOUNDS
try:
bounds = [UB[n] for n in N]
except KeyError:
dtype = next(n for n in N if n not in UB)
raise ValueError(
f"{dtype=} is not a valid dtype for Keras type promotion."
)
CUB = set.intersection(*bounds)
LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
if len(LUB) == 1:
return LUB.pop()
elif len(LUB) == 0:
msg = (
f"Input dtypes {tuple(str(n) for n in nodes)} have no available "
"implicit dtype promotion path. Try explicitly casting inputs to "
"the desired output type."
)
raise ValueError(msg)
else:
# If we get here, it means the lattice is ill-formed.
raise ValueError(
f"Internal Type Promotion error: {nodes} do not have a unique "
f"least upper bound on the specified lattice; options are {LUB}. "
"This is an unexpected error in Keras Core's internal logic; "
james77777778 marked this conversation as resolved.
Show resolved Hide resolved
"please report it to the maintainers."
)


def _dtype_and_weaktype(value):
"""Return a (dtype, weak_type) tuple for the given input."""
is_weak_type = False
if value is int or value is float:
# Note that we can't use `value in [int, float]` because the dtype
# might be equal to python scalar types.
# e.g, tf.float32 == float is True
is_weak_type = True
return standardize_dtype(value), is_weak_type


@functools.lru_cache(maxsize=None)
def _respect_weak_type(dtype, weak_type):
"""Return the weak dtype of `dtype` if `weak_type==True`."""
if weak_type:
if dtype == "bool":
return dtype
elif "float" in dtype:
return "float"
elif "int" in dtype:
return "int"
else:
raise ValueError(
"Invalid value for argument `dtype`. Expected one of "
f"{ALLOWED_DTYPES}. Received: dtype={dtype}"
)
return dtype


@functools.lru_cache(maxsize=None)
def _resolve_weak_type(dtype, precision="32"):
"""Resolve weak type by the precision of `backend.floatx()`."""
extended_allowed_dtypes = ALLOWED_DTYPES.union(WEAK_TYPES)
if dtype not in extended_allowed_dtypes:
raise ValueError(
"Invalid value for argument `dtype`. Expected one of "
f"{extended_allowed_dtypes}. Received: dtype={dtype}"
)
if precision not in ["16", "32", "64"]:
raise ValueError(
f"Invalid value for argument `precision`. Expected one of "
f"('16', '32', '64'). Received: precision={precision}"
)
if dtype == "bfloat16": # special case for bfloat16
dtype_indicator = "f"
else:
dtype_indicator = dtype[:1]

if dtype_indicator == "b":
return "bool"
elif dtype_indicator == "i":
return "int" + precision
elif dtype_indicator == "u":
return "uint" + precision
else:
return "float" + precision


BIT64_TO_BIT16_DTYPE = {
"int32": "int16",
"int64": "int16",
"uint32": "uint16",
"uint64": "uint16",
"float32": "float16",
"float64": "float16",
}
BIT64_TO_BIT32_DTYPE = {
"int64": "int32",
"uint64": "uint32",
"float64": "float32",
}


@functools.lru_cache(maxsize=None)
def _canonicalize_dtype_by_precision(dtype, precision="32"):
"""Canonicalize dtype by the precision of `backend.floatx()`."""
if precision == "16":
return BIT64_TO_BIT16_DTYPE.get(dtype, dtype)
elif precision == "32":
return BIT64_TO_BIT32_DTYPE.get(dtype, dtype)
elif precision == "64":
return dtype
else:
raise ValueError(
f"Invalid value for argument `precision`. Expected one of "
f"('16', '32', '64'). Received: precision={precision}"
)


def _lattice_result_type(*args):
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
if len(dtypes) == 1:
out_dtype = dtypes[0]
out_weak_type = weak_types[0]
elif len(set(dtypes)) == 1 and not all(weak_types):
# Trivial promotion case. This allows extended dtypes through.
out_dtype = dtypes[0]
out_weak_type = False
elif all(weak_types):
# If all inputs are weakly typed, we compute the bound of the
# strongly-typed counterparts and apply the weak type at the end. This
# avoids returning the incorrect result with non-canonical weak types
# (e.g. weak int16).
out_dtype = _least_upper_bound(
*{_respect_weak_type(d, False) for d in dtypes}
)
out_weak_type = True
else:
nodes = []
for d, w in zip(dtypes, weak_types):
nodes.append(_respect_weak_type(d, w))
out_dtype = _least_upper_bound(*nodes)
out_weak_type = any(out_dtype is t for t in WEAK_TYPES)

out_weak_type = (out_dtype != "bool") and out_weak_type
if out_weak_type:
out_dtype = _resolve_weak_type(out_dtype)
precision = backend.floatx()[-2:]
out_dtype = _canonicalize_dtype_by_precision(out_dtype, precision)
return out_dtype


@keras_export("keras.backend.result_dtype")
def result_dtype(*tensors_and_dtypes):
james77777778 marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the type from applying the Keras type promotion rules.

In general, each argument is first parsed by `backend.standardize_dtype`,
and the resulting dtype is determined by the least upper bound of the type
promotion lattice.

Note: This function attempts to match the result of `jnp.result_dtype`.

Args:
tensors_and_dtypes: Input arguments.

Returns:
The result dtype.

Examples:

>>> x = keras.ops.ones((1,), dtype="bfloat16")
>>> keras.backend.result_dtype(x.dtype, int)
"bfloat16"

>>> x = keras.ops.ones((1,), dtype="int32")
>>> y = keras.ops.ones((1,), dtype="float32")
james77777778 marked this conversation as resolved.
Show resolved Hide resolved
"float32"
"""
if len(tensors_and_dtypes) == 0:
raise ValueError(
"Invalid `tensors_and_dtypes`. At least one tensor or dtype is "
f"required. Received: tensors_and_dtypes={tensors_and_dtypes}"
)
return _lattice_result_type(
*(
backend.floatx() if arg is None else arg
for arg in tensors_and_dtypes
)
)
44 changes: 44 additions & 0 deletions keras/backend/common/dtypes_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from absl.testing import parameterized

from keras import backend
from keras import ops
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend.torch.core import to_torch_dtype
from keras.testing import test_case


class DtypesTest(test_case.TestCase, parameterized.TestCase):
"""Dtypes test to verify that the result type matches `jnp.result_type`."""

@parameterized.product(
dtype1=[d for d in ALLOWED_DTYPES if d != "string"],
dtype2=[bool, int, float],
)
def test_result_dtype_with_python_scalar_types(self, dtype1, dtype2):
import jax.numpy as jnp

out = backend.result_dtype(dtype1, dtype2)
expected = jnp.result_type(dtype1, dtype2).name
self.assertEqual(out, expected)

@parameterized.product(
# TODO: uint64, int64 and float64 are not supported by JAX by default
dtype1=[d for d in ALLOWED_DTYPES if d != "string" and "64" not in d],
dtype2=[d for d in ALLOWED_DTYPES if d != "string" and "64" not in d],
)
def test_result_dtype_with_tensor(self, dtype1, dtype2):
# TODO: torch doesn't have `uint16` and `uint32` dtypes
if backend.backend() == "torch":
dtype1 = str(to_torch_dtype(dtype1)).split(".")[-1]
dtype2 = str(to_torch_dtype(dtype2)).split(".")[-1]

import jax.numpy as jnp

x1 = ops.ones((1,), dtype=dtype1)
x2 = ops.ones((1,), dtype=dtype2)
x1_jax = jnp.ones((1,), dtype=dtype1)
x2_jax = jnp.ones((1,), dtype=dtype2)

out = backend.result_dtype(x1.dtype, x2.dtype)
expected = jnp.result_type(x1_jax, x2_jax).name
self.assertEqual(out, expected)
Loading
Loading