From 8f4f8574c78d60b8ab28891b182d459963d6f8fd Mon Sep 17 00:00:00 2001 From: Hazem Date: Sat, 15 Jul 2023 21:25:10 +0300 Subject: [PATCH] Removed CategoryEncoding that was implemented in keras_core.ops.nn and changed the implementation of keras_core.layers.CategoryEncoding to be backend-agnostic instead of TensorFlow implementation --- keras_core/backend/jax/nn.py | 8 ++++++ keras_core/backend/tensorflow/nn.py | 8 ++++++ keras_core/backend/torch/nn.py | 8 ++++++ .../layers/preprocessing/category_encoding.py | 26 ++++++++----------- keras_core/ops/nn.py | 24 ----------------- 5 files changed, 35 insertions(+), 39 deletions(-) diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py index db10f451d..50d1eb7c1 100644 --- a/keras_core/backend/jax/nn.py +++ b/keras_core/backend/jax/nn.py @@ -391,6 +391,14 @@ def one_hot(x, num_classes, axis=-1, dtype="float32"): return jnn.one_hot(x, num_classes, axis=axis, dtype=dtype) +def multi_hot(x, num_classes, axis=-1, dtype='float32'): + return jax.numpy.max(one_hot(x, num_classes, axis=axis, dtype=dtype), axis=1) + + +def count(x, num_classes, axis=-1, dtype='float32'): + return jax.numpy.sum(one_hot(x, num_classes, axis=axis, dtype=dtype), axis=1) + + def categorical_crossentropy(target, output, from_logits=False, axis=-1): target = jnp.array(target) output = jnp.array(output) diff --git a/keras_core/backend/tensorflow/nn.py b/keras_core/backend/tensorflow/nn.py index 1879e9e52..b020470ac 100644 --- a/keras_core/backend/tensorflow/nn.py +++ b/keras_core/backend/tensorflow/nn.py @@ -420,6 +420,14 @@ def one_hot(x, num_classes, axis=-1, dtype="float32"): return tf.one_hot(x, num_classes, axis=axis, dtype=dtype) +def multi_hot(x, num_classes, axis=-1, dtype='float32'): + return tf.reduce_max(one_hot(x, num_classes, axis=axis, dtype=dtype), axis=1) + + +def count(x, num_classes, axis=-1, dtype='float32'): + return tf.reduce_sum(one_hot(x, num_classes, axis=axis, dtype=dtype), axis=1) + + def _get_logits(output, from_logits, op_type, fn_name): """Retrieves logits tensor from maybe-softmax or maybe-sigmoid tensor.""" output_ = output diff --git a/keras_core/backend/torch/nn.py b/keras_core/backend/torch/nn.py index 59c694a80..8bdcf739a 100644 --- a/keras_core/backend/torch/nn.py +++ b/keras_core/backend/torch/nn.py @@ -538,6 +538,14 @@ def one_hot(x, num_classes, axis=-1, dtype="float32"): return output +def multi_hot(x, num_classes, axis=-1, dtype='float32'): + return torch.amax(one_hot(x, num_classes, axis=axis, dtype=dtype), dim=1) + + +def count(x, num_classes, axis=-1, dtype='float32'): + return torch.sum(one_hot(x, num_classes, axis=axis, dtype=dtype), dim=1) + + def categorical_crossentropy(target, output, from_logits=False, axis=-1): target = convert_to_tensor(target) output = convert_to_tensor(output) diff --git a/keras_core/layers/preprocessing/category_encoding.py b/keras_core/layers/preprocessing/category_encoding.py index acf6e2d44..d89ca5c23 100644 --- a/keras_core/layers/preprocessing/category_encoding.py +++ b/keras_core/layers/preprocessing/category_encoding.py @@ -2,7 +2,7 @@ from keras_core.api_export import keras_core_export from keras_core.layers.layer import Layer from keras_core.utils import backend_utils -from keras_core.utils.module_utils import tensorflow as tf + @keras_core_export("keras_core.layers.CategoryEncoding") @@ -84,21 +84,23 @@ class CategoryEncoding(Layer): """ def __init__(self, num_tokens=None, output_mode="multi_hot", **kwargs): - if not tf.available: - raise ImportError( - "Layer CategoryEncoding requires TensorFlow. " - "Install it via `pip install tensorflow`." - ) - super().__init__(**kwargs) # Support deprecated names for output_modes. if output_mode == "binary": output_mode = "multi_hot" + # 'output_mode' must be one of ("count", "one_hot", "multi_hot") if output_mode not in ("count", "one_hot", "multi_hot"): raise ValueError(f"Unknown arg for output_mode: {output_mode}") + if output_mode == 'multi_hot': + self.output_function = backend.nn.multi_hot + elif output_mode == 'count': + self.output_function = backend.nn.count + elif output_mode == 'one_hot': + self.output_function = backend.nn.one_hot + if num_tokens is None: raise ValueError( "num_tokens must be set to use this layer. If the " @@ -111,17 +113,11 @@ def __init__(self, num_tokens=None, output_mode="multi_hot", **kwargs): ) self.num_tokens = num_tokens self.output_mode = output_mode - - self.layer = tf.keras.layers.CategoryEncoding( - num_tokens=num_tokens, - output_mode=output_mode, - **kwargs, - ) self._allow_non_tensor_positional_args = True self._convert_input_args = False def compute_output_shape(self, input_shape): - return tuple(self.layer.compute_output_shape(input_shape)) + return tuple(input_shape + (self.num_classes,)) def get_config(self): config = { @@ -132,7 +128,7 @@ def get_config(self): return {**base_config, **config} def call(self, inputs): - outputs = self.layer.call(inputs) + outputs = self.output_function(inputs, self.num_tokens, dtype=self.dtype) if ( backend.backend() != "tensorflow" and not backend_utils.in_tf_graph() diff --git a/keras_core/ops/nn.py b/keras_core/ops/nn.py index e3513f601..5ac48c88a 100644 --- a/keras_core/ops/nn.py +++ b/keras_core/ops/nn.py @@ -1089,27 +1089,3 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): return backend.nn.sparse_categorical_crossentropy( target, output, from_logits=from_logits, axis=axis ) - - -class CategoryEncoding(Operation): - def __init__(self, num_classes, name=None): - super().__init__(name) - self.num_classes = num_classes - - def call(self, inputs): - return backend.nn.one_hot(inputs, self.num_classes) - - def compute_output_spec(self, inputs): - return KerasTensor( - inputs.shape + (self.num_classes,), dtype=inputs.dtype - ) - - -@keras_core_export( - ["keras_core.ops.category_encoding", "keras_core.ops.nn.category_encoding"] -) -def category_encoding(inputs, num_classes=None): - # TODO: add docstring - if any_symbolic_tensors((inputs,)): - return CategoryEncoding(num_classes=num_classes).symbolic_call(inputs) - return backend.nn.one_hot(inputs, num_classes)