Skip to content

Commit

Permalink
Hacky fix for dictionary output with tf 2.14
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Sep 20, 2023
1 parent a465816 commit 4a040bb
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions keras_core/ops/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from keras_core.api_export import keras_core_export
from keras_core.backend import KerasTensor
from keras_core.backend.config import backend
from keras_core.ops.operation import Operation
from keras_core.utils.nest import pack_sequence_as

Expand Down Expand Up @@ -46,10 +47,20 @@ class Function(Operation):
def __init__(self, inputs, outputs, name=None):
super().__init__(name=name)

if backend() == "tensorflow":
# Temporary work around for https://github.com/keras-team/keras-core/issues/931
# This stop tensorflow from wrapping tf.function output in a
# _DictWrapper object.
_self_setattr_tracking = getattr(

Check warning on line 54 in keras_core/ops/function.py

View check run for this annotation

Codecov / codecov/patch

keras_core/ops/function.py#L54

Added line #L54 was not covered by tests
self, "_self_setattr_tracking", True
)
self._self_setattr_tracking = False

Check warning on line 57 in keras_core/ops/function.py

View check run for this annotation

Codecov / codecov/patch

keras_core/ops/function.py#L57

Added line #L57 was not covered by tests
self._inputs_struct = tree.map_structure(lambda x: x, inputs)
self._outputs_struct = tree.map_structure(lambda x: x, outputs)
self._inputs = tree.flatten(inputs)
self._outputs = tree.flatten(outputs)
if backend() == "tensorflow":
self._self_setattr_tracking = _self_setattr_tracking

Check warning on line 63 in keras_core/ops/function.py

View check run for this annotation

Codecov / codecov/patch

keras_core/ops/function.py#L63

Added line #L63 was not covered by tests

(nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
self._inputs, self._outputs
Expand Down

0 comments on commit 4a040bb

Please sign in to comment.