Skip to content

Commit

Permalink
Make it possible to run Keras Core without TensorFlow. (#475)
Browse files Browse the repository at this point in the history
* Make it possible to run Keras Core without TensorFlow.

* Update requirements
  • Loading branch information
fchollet authored Jul 14, 2023
1 parent e4dec5a commit 649f67f
Show file tree
Hide file tree
Showing 84 changed files with 823 additions and 457 deletions.
10 changes: 8 additions & 2 deletions examples/demo_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
from keras_core import optimizers

inputs = layers.Input((100,))
x = layers.Dense(128, activation="relu")(inputs)
x = layers.Dense(512, activation="relu")(inputs)
residual = x
x = layers.Dense(128, activation="relu")(x)
x = layers.Dense(512, activation="relu")(x)
x = layers.Dense(512, activation="relu")(x)
x += residual
x = layers.Dense(512, activation="relu")(x)
residual = x
x = layers.Dense(512, activation="relu")(x)
x = layers.Dense(512, activation="relu")(x)
x += residual
outputs = layers.Dense(16)(x)
model = Model(inputs, outputs)
Expand Down
3 changes: 1 addition & 2 deletions keras_core/applications/convnext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from tensorflow.io import gfile

from keras_core import backend
from keras_core import initializers
Expand Down Expand Up @@ -390,7 +389,7 @@ def ConvNeXt(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/densenet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -182,7 +180,7 @@ def DenseNet(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/efficientnet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import copy
import math

from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -271,7 +269,7 @@ def EfficientNet(
if blocks_args == "default":
blocks_args = DEFAULT_BLOCKS_ARGS

if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/efficientnet_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import copy
import math

from tensorflow.io import gfile

from keras_core import backend
from keras_core import initializers
from keras_core import layers
Expand Down Expand Up @@ -892,7 +890,7 @@ def EfficientNetV2(
if blocks_args == "default":
blocks_args = DEFAULT_BLOCKS_ARGS[model_name]

if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/inception_resnet_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -94,7 +92,7 @@ def InceptionResNetV2(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/inception_v3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -97,7 +95,7 @@ def InceptionV3(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/mobilenet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import warnings

from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -107,7 +105,7 @@ def MobileNet(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), 'imagenet' "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/mobilenet_v2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import warnings

from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -109,7 +107,7 @@ def MobileNetV2(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/mobilenet_v3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import warnings

from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -165,7 +163,7 @@ def MobileNetV3(
classifier_activation="softmax",
include_preprocessing=True,
):
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/nasnet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import warnings

from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -111,7 +109,7 @@ def NASNet(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/resnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -107,7 +105,7 @@ def ResNet(
A Model instance.
"""

if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), 'imagenet' "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/vgg16.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -94,7 +92,7 @@ def VGG16(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), 'imagenet' "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/vgg19.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -94,7 +92,7 @@ def VGG19(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), 'imagenet' "
Expand Down
4 changes: 1 addition & 3 deletions keras_core/applications/xception.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from tensorflow.io import gfile

from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
Expand Down Expand Up @@ -92,7 +90,7 @@ def Xception(
Returns:
A model instance.
"""
if not (weights in {"imagenet", None} or gfile.exists(weights)):
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), 'imagenet' "
Expand Down
4 changes: 2 additions & 2 deletions keras_core/backend/common/keras_tensor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tensorflow import nest
import tree

from keras_core.api_export import keras_core_export
from keras_core.utils.naming import auto_name
Expand Down Expand Up @@ -232,7 +232,7 @@ def __getitem__(self, key):
def any_symbolic_tensors(args=None, kwargs=None):
args = args or ()
kwargs = kwargs or {}
for x in nest.flatten((args, kwargs)):
for x in tree.flatten((args, kwargs)):
if isinstance(x, KerasTensor):
return True
return False
Expand Down
19 changes: 10 additions & 9 deletions keras_core/backend/jax/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow import nest
import tree

from keras_core.backend.common import KerasVariable
from keras_core.backend.common import standardize_dtype
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.utils.nest import pack_sequence_as

DYNAMIC_SHAPES_OK = True

Expand Down Expand Up @@ -89,7 +90,7 @@ def index_all_ktensors(x):
return x

# Third, find out if there are dynamic shapes
maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
maybe_symbolic_args, maybe_symbolic_kwargs = tree.map_structure(
index_all_ktensors, (maybe_symbolic_args, maybe_symbolic_kwargs)
)
none_count = 0
Expand All @@ -115,24 +116,24 @@ def wrapped_fn(*args, **kwargs):
jax_out = None
if none_count:
try:
ms_args_1, ms_kwargs_1 = nest.map_structure(
ms_args_1, ms_kwargs_1 = tree.map_structure(
lambda x: convert_keras_tensor_to_jax(x, fill_value=83),
(maybe_symbolic_args, maybe_symbolic_kwargs),
)
_, jax_out_1 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
*ms_args_1, **ms_kwargs_1
)

ms_args_2, ms_kwargs_2 = nest.map_structure(
ms_args_2, ms_kwargs_2 = tree.map_structure(
lambda x: convert_keras_tensor_to_jax(x, fill_value=89),
(maybe_symbolic_args, maybe_symbolic_kwargs),
)
_, jax_out_2 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
*ms_args_2, **ms_kwargs_2
)

flat_out_1 = nest.flatten(jax_out_1)
flat_out_2 = nest.flatten(jax_out_2)
flat_out_1 = tree.flatten(jax_out_1)
flat_out_2 = tree.flatten(jax_out_2)

flat_out = []
for x1, x2 in zip(flat_out_1, flat_out_2):
Expand All @@ -148,7 +149,7 @@ def wrapped_fn(*args, **kwargs):
)
else:
flat_out.append(x1)
jax_out = nest.pack_sequence_as(jax_out_1, flat_out)
jax_out = pack_sequence_as(jax_out_1, flat_out)
except:
# Errors can happen when the filled dimensions
# are not compatible with the function
Expand All @@ -162,7 +163,7 @@ def wrapped_fn(*args, **kwargs):
pass

if jax_out is None:
maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
maybe_symbolic_args, maybe_symbolic_kwargs = tree.map_structure(
convert_keras_tensor_to_jax,
(maybe_symbolic_args, maybe_symbolic_kwargs),
)
Expand All @@ -175,7 +176,7 @@ def convert_jax_spec_to_keras_tensor(x):
return KerasTensor(x.shape, x.dtype)
return x

output_shape = nest.map_structure(
output_shape = tree.map_structure(
convert_jax_spec_to_keras_tensor, jax_out
)
return output_shape
Expand Down
Loading

0 comments on commit 649f67f

Please sign in to comment.