Skip to content

Commit

Permalink
Switch from tf.nest to dm-tree (#1199)
Browse files Browse the repository at this point in the history
keras-core already did this, makes sense to follow suit for us.
  • Loading branch information
mattdangerw authored Aug 9, 2023
1 parent 26a57c3 commit 8f1bc4e
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 27 deletions.
3 changes: 2 additions & 1 deletion keras_nlp/layers/preprocessing/preprocessing_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import tensorflow as tf
import tree

from keras_nlp.backend import config
from keras_nlp.backend import keras
Expand Down Expand Up @@ -43,7 +44,7 @@ def __call__(self, *args, **kwargs):
is_tf_backend = config.backend() == "tensorflow"
is_in_tf_graph = not tf.executing_eagerly()
if not is_tf_backend and not is_in_tf_graph:
outputs = tf.nest.map_structure(
outputs = tree.map_structure(
convert_to_backend_tensor_or_python_list, outputs
)

Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/models/generative_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import itertools

import tensorflow as tf
import tree

from keras_nlp.backend import config
from keras_nlp.backend import keras
Expand Down Expand Up @@ -118,7 +119,7 @@ def wrapped_generate_function(
self.trainable_variables,
self.non_trainable_variables,
)
inputs = tf.nest.map_structure(ops.convert_to_tensor, inputs)
inputs = tree.map_structure(ops.convert_to_tensor, inputs)
outputs, state = compiled_generate_function(
inputs,
end_token_id,
Expand Down
3 changes: 1 addition & 2 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import os

import keras_core
import tensorflow as tf
from rich import console as rich_console
from rich import markup
from rich import table as rich_table
Expand Down Expand Up @@ -45,7 +44,7 @@ def _check_for_loss_mismatch(self, loss):
loss, and a `None` or `"softmax"` activation.
"""
# Only handle a single loss.
if tf.nest.is_nested(loss):
if isinstance(loss, (dict, list, tuple)):
return
# Only handle tasks with activation.
if not hasattr(self, "activation"):
Expand Down
5 changes: 3 additions & 2 deletions keras_nlp/samplers/beam_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Beam Sampler."""

import tensorflow as tf
import tree

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
Expand Down Expand Up @@ -141,7 +142,7 @@ def unflatten_beams(x):
cache = cache if has_cache else ()
# Add extra sequences for each beam.
prompt, mask = create_beams(prompt), create_beams(mask)
cache = tf.nest.map_structure(create_beams, cache)
cache = tree.map_structure(create_beams, cache)
# Setup the initial beam log-likelihoods.
# On the first loop, make sure only the original beam is considered.
log_probs = ops.array(
Expand Down Expand Up @@ -192,7 +193,7 @@ def gather_beams(x):

prompt = gather_beams(prompt)
if has_cache:
cache = tf.nest.map_structure(gather_beams, cache)
cache = tree.map_structure(gather_beams, cache)

# Update each beam with the next token.
next_token = ops.cast(next_token, prompt.dtype)
Expand Down
8 changes: 4 additions & 4 deletions keras_nlp/samplers/contrastive_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Contrastive Sampler."""

import tensorflow as tf
import tree

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
Expand Down Expand Up @@ -143,7 +143,7 @@ def body(prompt, cache, index, logits, hidden_states):
hidden_states_beams = create_beams(hidden_states)
cache_beams = None
if has_cache:
cache_beams = tf.nest.map_structure(create_beams, cache)
cache_beams = tree.map_structure(create_beams, cache)

# Get top-k candidate tokens and their probabilities.
top_k_probabilities, top_k_indices = ops.top_k(
Expand Down Expand Up @@ -213,8 +213,8 @@ def gather_best_token(beams):
logits = gather_best_token(unflat_next_logits)
next_hidden_states = gather_best_token(unflat_next_hidden_states)
if has_cache:
cache = tf.nest.map_structure(unflatten_beams, cache_beams)
cache = tf.nest.map_structure(gather_best_token, cache)
cache = tree.map_structure(unflatten_beams, cache_beams)
cache = tree.map_structure(gather_best_token, cache)

hidden_states = ops.slice_update(
hidden_states,
Expand Down
21 changes: 14 additions & 7 deletions keras_nlp/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import tree
from absl.testing import parameterized

from keras_nlp.backend import ops
Expand All @@ -28,7 +29,7 @@ def convert_to_comparible_type(x):
x = x.to_list()
if isinstance(x, tf.Tensor):
x = x.numpy() if x.shape.rank == 0 else x.numpy().tolist()
return tf.nest.map_structure(lambda x: x.decode("utf-8"), x)
return tree.map_structure(lambda x: x.decode("utf-8"), x)
if isinstance(x, (tf.Tensor, tf.RaggedTensor)):
return x
if ops.is_tensor(x):
Expand All @@ -40,16 +41,22 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
"""Base test case class for KerasNLP."""

def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
x1 = tf.nest.map_structure(convert_to_comparible_type, x1)
x2 = tf.nest.map_structure(convert_to_comparible_type, x2)
# This metric dict hack is only needed for tf.keras, and can be
# removed after we fully migrate to keras-core/Keras 3.
if x1.__class__.__name__ == "_MetricDict":
x1 = dict(x1)
if x2.__class__.__name__ == "_MetricDict":
x2 = dict(x2)
x1 = tree.map_structure(convert_to_comparible_type, x1)
x2 = tree.map_structure(convert_to_comparible_type, x2)
super().assertAllClose(x1, x2, atol=atol, rtol=rtol, msg=msg)

def assertEqual(self, x1, x2, msg=None):
x1 = tf.nest.map_structure(convert_to_comparible_type, x1)
x2 = tf.nest.map_structure(convert_to_comparible_type, x2)
x1 = tree.map_structure(convert_to_comparible_type, x1)
x2 = tree.map_structure(convert_to_comparible_type, x2)
super().assertEqual(x1, x2, msg=msg)

def assertAllEqual(self, x1, x2, msg=None):
x1 = tf.nest.map_structure(convert_to_comparible_type, x1)
x2 = tf.nest.map_structure(convert_to_comparible_type, x2)
x1 = tree.map_structure(convert_to_comparible_type, x1)
x2 = tree.map_structure(convert_to_comparible_type, x2)
super().assertAllEqual(x1, x2, msg=msg)
9 changes: 5 additions & 4 deletions keras_nlp/utils/pipeline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import math

import tensorflow as tf
import tree

from keras_nlp.backend import keras
from keras_nlp.backend import ops
Expand Down Expand Up @@ -65,7 +66,7 @@ def convert(x):
return ops.convert_to_numpy(x)
return x

inputs = tf.nest.map_structure(convert, inputs)
inputs = tree.map_structure(convert, inputs)
ds = tf.data.Dataset.from_tensor_slices(inputs)
except ValueError as e:
# If our inputs are unbatched, re-raise with a more friendly error
Expand All @@ -92,7 +93,7 @@ def _train_validation_split(arrays, validation_split):
def _can_split(t):
return is_tensor_type(t) or t is None

flat_arrays = tf.nest.flatten(arrays)
flat_arrays = tree.flatten(arrays)
unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
if unsplitable:
raise ValueError(
Expand Down Expand Up @@ -129,10 +130,10 @@ def _split(t, start, end):
return t
return t[start:end]

train_arrays = tf.nest.map_structure(
train_arrays = tree.map_structure(
functools.partial(_split, start=0, end=split_at), arrays
)
val_arrays = tf.nest.map_structure(
val_arrays = tree.map_structure(
functools.partial(_split, start=split_at, end=batch_dim), arrays
)

Expand Down
12 changes: 6 additions & 6 deletions requirements-common.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# Library deps.
keras-core
# Tooling.
dm-tree
regex
rich
# Tooling deps.
astor
numpy~=1.23.2 # Numpy 1.24 breaks tests on ragged tensors
packaging
black>=22
black[jupyter]
black[jupyter]>=22
flake8
isort
pytest
pytest-cov
build
namex
regex
rich
# Optional deps.
rouge-score
sentencepiece
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def get_version(rel_path):
"packaging",
"regex",
"rich",
"dm-tree",
# Don't require tensorflow-text on MacOS, there are no binaries for ARM.
# Also, we rely on tensorflow *transitively* through tensorflow-text.
# This avoid a slowdown during `pip install keras-nlp` where pip would
Expand Down

0 comments on commit 8f1bc4e

Please sign in to comment.