Introduce backend.result_type (#18482)
* Add `result_dtype` and some refactoring of `ops.numpy`

* Fix keras_export

* Refactor `result_dtype`

* Update `result_type`

* Revert `ops.numpy.*` changes

* Ensure consistent dtype inference

* Add dtype test

* Fix dropout RNN test

* Fix symbolic test

* Fix torch test

* Keep `"int64"` when using TensorFlow

* Fix test

* Simplify `result_type` for TensorFlow

* Add `pre_canonicalize` option, rename to `result_type`

* Align the behavior of `ops.add`

* Fix test

* Match `backend.result_type` to JAX with `JAX_ENABLE_X64=true` and `JAX_DEFAULT_DTYPE_BITS=32`

* Use `dtype or config.floatx()`

* Fix symbolic ops

* Remove `result_type` in jax and torch

* Address comments

* Apply `result_type` to `ops.numpy.arange`

* Apply `backend.result_type` to `ops.numpy.sqrt`

* Skip float16 test for torch's sqrt
james77777778 authored Sep 28, 2023
1 parent fab88ec commit f73df5a
Showing 12 changed files with 749 additions and 87 deletions.
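The headline addition is the new public `keras.backend.result_type` API, implemented in `keras/backend/common/dtypes.py` below. A minimal usage sketch, based on the docstring examples in that file (assuming a build that includes this commit):

    from keras import backend

    backend.result_type("int32", "float32")  # -> "float32"
    backend.result_type("bfloat16", int)     # -> "bfloat16" (Python int is weakly typed)
    backend.result_type(None)                # -> backend.floatx(), e.g. "float32"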
10 changes: 8 additions & 2 deletions conftest.py
@@ -1,3 +1,9 @@
+import os
+
+# When using jax.experimental.enable_x64 in unit tests, we want to keep the
+# default dtype at 32 bits, aligning it with Keras's default.
+os.environ["JAX_DEFAULT_DTYPE_BITS"] = "32"
+
 try:
     # When using torch and tensorflow, torch needs to be imported first,
     # otherwise it will segfault upon import. This should force the torch
@@ -6,9 +12,9 @@
 except ImportError:
     pass

-import pytest
+import pytest  # noqa: E402

-from keras.backend import backend
+from keras.backend import backend  # noqa: E402


 def pytest_configure(config):
1 change: 1 addition & 0 deletions keras/backend/__init__.py
@@ -6,6 +6,7 @@
 # upon import.
 import torch

+from keras.backend.common.dtypes import result_type
 from keras.backend.common.keras_tensor import KerasTensor
 from keras.backend.common.keras_tensor import any_symbolic_tensors
 from keras.backend.common.keras_tensor import is_keras_tensor
1 change: 1 addition & 0 deletions keras/backend/common/__init__.py
@@ -1,4 +1,5 @@
 from keras.backend.common import backend_utils
+from keras.backend.common.dtypes import result_type
 from keras.backend.common.variables import AutocastScope
 from keras.backend.common.variables import KerasVariable
 from keras.backend.common.variables import get_autocast_scope
277 changes: 277 additions & 0 deletions keras/backend/common/dtypes.py
@@ -0,0 +1,277 @@
import functools

from keras import backend
from keras.api_export import keras_export
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend.common.variables import standardize_dtype

"""
We adapted the type promotion lattice from JAX. Ref:
https://github.com/google/jax/blob/main/jax/_src/dtypes.py
"""

BOOL_TYPES = ["bool"]
INT_TYPES = [
    "uint8",
    "uint16",
    "uint32",
    "uint64",
    "int8",
    "int16",
    "int32",
    "int64",
]
FLOAT_TYPES = ["bfloat16", "float16", "float32", "float64"]
WEAK_TYPES = ["int", "float"]


def _type_promotion_lattice():
    """
    Return the type promotion lattice in the form of a DAG.
    This DAG maps each type to its immediately higher type on the lattice.
    """
    (b1,) = BOOL_TYPES
    (u1, u2, u4, u8, i1, i2, i4, i8) = INT_TYPES
    bf, f2, f4, f8 = FLOAT_TYPES
    i_, f_ = WEAK_TYPES
    out = {
        b1: [i_],
        u1: [i2, u2],
        u2: [i4, u4],
        u4: [i8, u8],
        u8: [f_],
        i_: [u1, i1],
        i1: [i2],
        i2: [i4],
        i4: [i8],
        i8: [f_],
        f_: [bf, f2],
        bf: [f4],
        f2: [f4],
        f4: [f8],
        f8: [],
    }
    return out
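# How to read the lattice (annotation added for this writeup, not in the
# original file): each key lists the types it promotes to directly. For
# example, u1: [i2, u2] says uint8 may widen to either int16 or uint16, and
# the weak nodes i_ ("int") / f_ ("float") model Python scalars whose width
# is not yet pinned down.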


def _make_lattice_upper_bounds():
    lattice = _type_promotion_lattice()
    upper_bounds = {node: {node} for node in lattice}
    for n in lattice:
        while True:
            new_upper_bounds = set().union(
                *(lattice[b] for b in upper_bounds[n])
            )
            if n in new_upper_bounds:
                raise ValueError(
                    f"cycle detected in type promotion lattice for node {n}"
                )
            if new_upper_bounds.issubset(upper_bounds[n]):
                break
            upper_bounds[n] |= new_upper_bounds
    return upper_bounds


LATTICE_UPPER_BOUNDS = _make_lattice_upper_bounds()
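# For illustration (annotation, not in the original file), the transitive
# closure computed above gives, for example:
#   LATTICE_UPPER_BOUNDS["uint8"] == {
#       "uint8", "int16", "uint16", "int32", "uint32", "int64", "uint64",
#       "float", "bfloat16", "float16", "float32", "float64",
#   }
# i.e. every type uint8 can be implicitly promoted to, including the weak
# "float" node.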


@functools.lru_cache(512)
def _least_upper_bound(*nodes):
    """Compute the least upper bound of a set of nodes.

    Args:
        nodes: sequence of entries from dtypes + weak_types

    Returns:
        The type representing the least upper bound of the input nodes on the
        promotion lattice.
    """
    # This function computes the least upper bound of a set of nodes N within
    # a partially ordered set defined by the lattice generated above.
    # Given a partially ordered set S, let the set of upper bounds of n ∈ S be
    #   UB(n) ≡ {m ∈ S | n ≤ m}
    # Further, for a set of nodes N ⊆ S, let the set of common upper bounds
    # be given by
    #   CUB(N) ≡ {a ∈ S | ∀ b ∈ N: a ∈ UB(b)}
    # Then the least upper bound of N is defined as
    #   LUB(N) ≡ {c ∈ CUB(N) | ∀ d ∈ CUB(N), c ≤ d}
    # The definition of an upper bound implies that
    #   c ≤ d if and only if d ∈ UB(c),
    # so the LUB can be expressed:
    #   LUB(N) = {c ∈ CUB(N) | ∀ d ∈ CUB(N): d ∈ UB(c)}
    # or, equivalently:
    #   LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)}
    # By definition, LUB(N) has a cardinality of 1 for a partially ordered
    # set.
    # Note a potential algorithmic shortcut: from the definition of CUB(N),
    # we have
    #   ∀ c ∈ N: CUB(N) ⊆ UB(c)
    # So if N ∩ CUB(N) is nonempty, it follows that LUB(N) = N ∩ CUB(N).
    N = set(nodes)
    UB = LATTICE_UPPER_BOUNDS
    try:
        bounds = [UB[n] for n in N]
    except KeyError:
        dtype = next(n for n in N if n not in UB)
        raise ValueError(
            f"{dtype=} is not a valid dtype for Keras type promotion."
        )
    CUB = set.intersection(*bounds)
    LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
    if len(LUB) == 1:
        return LUB.pop()
    elif len(LUB) == 0:
        msg = (
            f"Input dtypes {tuple(str(n) for n in nodes)} have no available "
            "implicit dtype promotion path. Try explicitly casting inputs to "
            "the desired output type."
        )
        raise ValueError(msg)
    else:
        # If we get here, it means the lattice is ill-formed.
        raise ValueError(
            f"Internal Type Promotion error: {nodes} do not have a unique "
            f"least upper bound on the specified lattice; options are {LUB}. "
            "This is an unexpected error in Keras's internal logic; "
            "please report it to the maintainers."
        )
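# Worked example (annotation, not in the original file):
# _least_upper_bound("int8", "uint8") intersects UB("int8") and UB("uint8")
# to get CUB = {"int16", "int32", "int64", "float", "bfloat16", "float16",
# "float32", "float64"}. Neither input is in CUB, and "int16" is the unique
# member below all the others, so the result is "int16" -- matching
# NumPy/JAX promotion of int8 with uint8.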


def _dtype_and_weaktype(value):
    """Return a (dtype, weak_type) tuple for the given input."""
    is_weak_type = False
    if value is int or value is float:
        # Note that we can't use `value in [int, float]` because the dtype
        # might be equal to Python scalar types.
        # e.g., tf.float32 == float is True
        is_weak_type = True
    return standardize_dtype(value), is_weak_type


@functools.lru_cache(maxsize=None)
def _respect_weak_type(dtype, weak_type):
    """Return the weak dtype of `dtype` if `weak_type==True`."""
    if weak_type:
        if dtype == "bool":
            return dtype
        elif "float" in dtype:
            return "float"
        elif "int" in dtype:
            return "int"
        else:
            raise ValueError(
                "Invalid value for argument `dtype`. Expected one of "
                f"{ALLOWED_DTYPES}. Received: dtype={dtype}"
            )
    return dtype
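# Examples (annotation, not in the original file):
#   _respect_weak_type("float32", True)  -> "float"   (weakened)
#   _respect_weak_type("int64", True)    -> "int"     (weakened)
#   _respect_weak_type("float32", False) -> "float32" (unchanged)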


@functools.lru_cache(maxsize=None)
def _resolve_weak_type(dtype, precision="32"):
    """Resolve weak type by the precision of `backend.floatx()`."""
    extended_allowed_dtypes = ALLOWED_DTYPES.union(WEAK_TYPES)
    if dtype not in extended_allowed_dtypes:
        raise ValueError(
            "Invalid value for argument `dtype`. Expected one of "
            f"{extended_allowed_dtypes}. Received: dtype={dtype}"
        )
    if precision not in ["16", "32", "64"]:
        raise ValueError(
            f"Invalid value for argument `precision`. Expected one of "
            f"('16', '32', '64'). Received: precision={precision}"
        )
    if dtype == "bfloat16":  # special case for bfloat16
        dtype_indicator = "f"
    else:
        dtype_indicator = dtype[:1]

    if dtype_indicator == "b":
        return "bool"
    elif dtype_indicator == "i":
        return "int" + precision
    elif dtype_indicator == "u":
        return "uint" + precision
    else:
        return "float" + precision


BIT64_TO_BIT16_DTYPE = {
    "int32": "int16",
    "int64": "int16",
    "uint32": "uint16",
    "uint64": "uint16",
    "float32": "float16",
    "float64": "float16",
}
BIT64_TO_BIT32_DTYPE = {
    "int64": "int32",
    "uint64": "uint32",
    "float64": "float32",
}


def _lattice_result_type(*args):
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
    if len(dtypes) == 1:
        out_dtype = dtypes[0]
        out_weak_type = weak_types[0]
    elif len(set(dtypes)) == 1 and not all(weak_types):
        # Trivial promotion case. This allows extended dtypes through.
        out_dtype = dtypes[0]
        out_weak_type = False
    elif all(weak_types):
        # If all inputs are weakly typed, we compute the bound of the
        # strongly-typed counterparts and apply the weak type at the end.
        # This avoids returning the incorrect result with non-canonical
        # weak types (e.g. weak int16).
        out_dtype = _least_upper_bound(
            *{_respect_weak_type(d, False) for d in dtypes}
        )
        out_weak_type = True
    else:
        out_dtype = _least_upper_bound(
            *{_respect_weak_type(d, w) for d, w in zip(dtypes, weak_types)}
        )
        out_weak_type = any(out_dtype is t for t in WEAK_TYPES)

    out_weak_type = (out_dtype != "bool") and out_weak_type
    precision = backend.floatx()[-2:]
    if out_weak_type:
        out_dtype = _resolve_weak_type(out_dtype, precision=precision)
    return out_dtype
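# Note on the precision slice above (annotation, not in the original file):
# `backend.floatx()[-2:]` takes the trailing width from the configured float
# dtype ("float32" -> "32", "float16" -> "16"), so weak types like "int" and
# "float" resolve to that bit width via `_resolve_weak_type`.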


@keras_export("keras.backend.result_type")
def result_type(*dtypes):
    """Returns the type from applying the Keras type promotion rules.

    In general, each argument is first parsed by `backend.standardize_dtype`,
    and the resulting dtype is determined by the least upper bound of the
    type promotion lattice.

    Note: This function attempts to match the result of `jnp.result_type`.

    Args:
        dtypes: Input dtypes.

    Returns:
        The result dtype.

    Examples:

    >>> x = keras.ops.ones((1,), dtype="bfloat16")
    >>> keras.backend.result_type(x.dtype, int)
    "bfloat16"

    >>> x = keras.ops.ones((1,), dtype="int32")
    >>> y = keras.ops.ones((1,), dtype="float32")
    >>> keras.backend.result_type(x.dtype, y.dtype)
    "float32"
    """
    if len(dtypes) == 0:
        raise ValueError(
            "Invalid `dtypes`. At least one dtype is required. "
            f"Received: dtypes={dtypes}"
        )
    return _lattice_result_type(
        *(backend.floatx() if arg is None else arg for arg in dtypes),
    )
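A quick cross-check of the promotion rules against JAX, mirroring what `dtypes_test.py` below asserts (a minimal sketch, assuming JAX is installed; `jnp.result_type` is the reference the tests compare against):

    import jax.numpy as jnp

    from keras import backend

    # Signed/unsigned ints with no common width promote upward, as in
    # NumPy/JAX: the least upper bound of int8 and uint8 is int16.
    assert backend.result_type("int8", "uint8") == "int16"
    assert (
        backend.result_type("int8", "uint8")
        == jnp.result_type("int8", "uint8").name
    )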
64 changes: 64 additions & 0 deletions keras/backend/common/dtypes_test.py
@@ -0,0 +1,64 @@
from absl.testing import parameterized

from keras import backend
from keras import ops
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend.torch.core import to_torch_dtype
from keras.testing import test_case


class DtypesTest(test_case.TestCase, parameterized.TestCase):
    """Test the dtype to verify that the behavior matches JAX."""

    if backend.backend() == "torch":
        # TODO: torch doesn't support uint64.
        ALL_DTYPES = [
            str(to_torch_dtype(x)).split(".")[-1]
            for x in ALLOWED_DTYPES
            if x not in ["string", "uint64"]
        ] + [None]
    else:
        ALL_DTYPES = [x for x in ALLOWED_DTYPES if x != "string"] + [None]

    def setUp(self):
        from jax.experimental import enable_x64

        self.jax_enable_x64 = enable_x64()
        self.jax_enable_x64.__enter__()
        return super().setUp()

    def tearDown(self) -> None:
        self.jax_enable_x64.__exit__(None, None, None)
        return super().tearDown()

    @parameterized.product(dtype1=ALL_DTYPES, dtype2=[bool, int, float])
    def test_result_type_with_python_scalar_types(self, dtype1, dtype2):
        import jax.numpy as jnp

        out = backend.result_type(dtype1, dtype2)
        expected = jnp.result_type(dtype1, dtype2).name
        self.assertEqual(out, expected)

    @parameterized.product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES)
    def test_result_type_with_tensor(self, dtype1, dtype2):
        import jax.numpy as jnp

        x1 = ops.ones((1,), dtype=dtype1)
        x2 = ops.ones((1,), dtype=dtype2)
        x1_jax = jnp.ones((1,), dtype=dtype1)
        x2_jax = jnp.ones((1,), dtype=dtype2)

        out = backend.result_type(x1.dtype, x2.dtype)
        expected = jnp.result_type(x1_jax, x2_jax).name
        self.assertEqual(out, expected)

    def test_result_type_with_none(self):
        import jax.numpy as jnp

        self.assertEqual(
            backend.result_type(None), jnp.result_type(None).name
        )

    def test_result_type_invalid_dtypes(self):
        with self.assertRaisesRegexp(
            ValueError, "Invalid `dtypes`. At least one dtype is required."
        ):
            backend.result_type()
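To run these checks locally, something like the following should work (an assumption about the dev workflow, not stated in this commit; `KERAS_BACKEND` selects the backend under test):

    KERAS_BACKEND=jax pytest keras/backend/common/dtypes_test.py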