Skip to content

Commit

Permalink
Removed CategoryEncoding that was implemented in keras_core.ops.nn an…
Browse files Browse the repository at this point in the history
…d changed the implementation of keras_core.layers.CategoryEncoding to be backend-agnostic instead of TensorFlow implementation
  • Loading branch information
hazemessamm committed Jul 15, 2023
1 parent 10725e3 commit 8f4f857
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 39 deletions.
8 changes: 8 additions & 0 deletions keras_core/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,14 @@ def one_hot(x, num_classes, axis=-1, dtype="float32"):
return jnn.one_hot(x, num_classes, axis=axis, dtype=dtype)


def multi_hot(x, num_classes, axis=-1, dtype='float32'):
return jax.numpy.max(one_hot(x, num_classes, axis=axis, dtype=dtype), axis=1)


def count(x, num_classes, axis=-1, dtype='float32'):
return jax.numpy.sum(one_hot(x, num_classes, axis=axis, dtype=dtype), axis=1)


def categorical_crossentropy(target, output, from_logits=False, axis=-1):
target = jnp.array(target)
output = jnp.array(output)
Expand Down
8 changes: 8 additions & 0 deletions keras_core/backend/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,14 @@ def one_hot(x, num_classes, axis=-1, dtype="float32"):
return tf.one_hot(x, num_classes, axis=axis, dtype=dtype)


def multi_hot(x, num_classes, axis=-1, dtype='float32'):
return tf.reduce_max(one_hot(x, num_classes, axis=axis, dtype=dtype), axis=1)


def count(x, num_classes, axis=-1, dtype='float32'):
return tf.reduce_sum(one_hot(x, num_classes, axis=axis, dtype=dtype), axis=1)


def _get_logits(output, from_logits, op_type, fn_name):
"""Retrieves logits tensor from maybe-softmax or maybe-sigmoid tensor."""
output_ = output
Expand Down
8 changes: 8 additions & 0 deletions keras_core/backend/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,14 @@ def one_hot(x, num_classes, axis=-1, dtype="float32"):
return output


def multi_hot(x, num_classes, axis=-1, dtype='float32'):
return torch.amax(one_hot(x, num_classes, axis=axis, dtype=dtype), dim=1)


def count(x, num_classes, axis=-1, dtype='float32'):
return torch.sum(one_hot(x, num_classes, axis=axis, dtype=dtype), dim=1)


def categorical_crossentropy(target, output, from_logits=False, axis=-1):
target = convert_to_tensor(target)
output = convert_to_tensor(output)
Expand Down
26 changes: 11 additions & 15 deletions keras_core/layers/preprocessing/category_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
from keras_core.utils.module_utils import tensorflow as tf



@keras_core_export("keras_core.layers.CategoryEncoding")
Expand Down Expand Up @@ -84,21 +84,23 @@ class CategoryEncoding(Layer):
"""

def __init__(self, num_tokens=None, output_mode="multi_hot", **kwargs):
if not tf.available:
raise ImportError(
"Layer CategoryEncoding requires TensorFlow. "
"Install it via `pip install tensorflow`."
)

super().__init__(**kwargs)

# Support deprecated names for output_modes.
if output_mode == "binary":
output_mode = "multi_hot"

# 'output_mode' must be one of ("count", "one_hot", "multi_hot")
if output_mode not in ("count", "one_hot", "multi_hot"):
raise ValueError(f"Unknown arg for output_mode: {output_mode}")

if output_mode == 'multi_hot':
self.output_function = backend.nn.multi_hot
elif output_mode == 'count':
self.output_function = backend.nn.count
elif output_mode == 'one_hot':
self.output_function = backend.nn.one_hot

if num_tokens is None:
raise ValueError(
"num_tokens must be set to use this layer. If the "
Expand All @@ -111,17 +113,11 @@ def __init__(self, num_tokens=None, output_mode="multi_hot", **kwargs):
)
self.num_tokens = num_tokens
self.output_mode = output_mode

self.layer = tf.keras.layers.CategoryEncoding(
num_tokens=num_tokens,
output_mode=output_mode,
**kwargs,
)
self._allow_non_tensor_positional_args = True
self._convert_input_args = False

def compute_output_shape(self, input_shape):
return tuple(self.layer.compute_output_shape(input_shape))
return tuple(input_shape + (self.num_classes,))

def get_config(self):
config = {
Expand All @@ -132,7 +128,7 @@ def get_config(self):
return {**base_config, **config}

def call(self, inputs):
outputs = self.layer.call(inputs)
outputs = self.output_function(inputs, self.num_tokens, dtype=self.dtype)
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
Expand Down
24 changes: 0 additions & 24 deletions keras_core/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,27 +1089,3 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
return backend.nn.sparse_categorical_crossentropy(
target, output, from_logits=from_logits, axis=axis
)


class CategoryEncoding(Operation):
def __init__(self, num_classes, name=None):
super().__init__(name)
self.num_classes = num_classes

def call(self, inputs):
return backend.nn.one_hot(inputs, self.num_classes)

def compute_output_spec(self, inputs):
return KerasTensor(
inputs.shape + (self.num_classes,), dtype=inputs.dtype
)


@keras_core_export(
["keras_core.ops.category_encoding", "keras_core.ops.nn.category_encoding"]
)
def category_encoding(inputs, num_classes=None):
# TODO: add docstring
if any_symbolic_tensors((inputs,)):
return CategoryEncoding(num_classes=num_classes).symbolic_call(inputs)
return backend.nn.one_hot(inputs, num_classes)

0 comments on commit 8f4f857

Please sign in to comment.