Skip to content
This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Commit

Permalink
Removes reliance on engine Keras module, in preparation for Keras 2…
Browse files Browse the repository at this point in the history
….2.1. Maintains compatibility with Keras 2.2.0.
  • Loading branch information
fchollet committed Jul 27, 2018
1 parent d0c13ac commit 85c4954
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 12 deletions.
7 changes: 6 additions & 1 deletion keras_applications/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
_KERAS_UTILS = None


def set_keras_submodules(backend, engine, layers, models, utils):
def set_keras_submodules(backend=None,
engine=None,
layers=None,
models=None,
utils=None):
# TODO: remove `engine` argument after release of Keras 2.2.1.
global _KERAS_BACKEND
global _KERAS_ENGINE
global _KERAS_LAYERS
Expand Down
6 changes: 5 additions & 1 deletion keras_applications/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@ def DenseNet(blocks,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input

Expand Down
6 changes: 5 additions & 1 deletion keras_applications/inception_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,11 @@ def InceptionResNetV2(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input

Expand Down
6 changes: 5 additions & 1 deletion keras_applications/inception_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,11 @@ def InceptionV3(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
Expand Down
6 changes: 5 additions & 1 deletion keras_applications/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,11 @@ def MobileNet(input_shape=None,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input

Expand Down
9 changes: 7 additions & 2 deletions keras_applications/mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@
from .imagenet_utils import decode_predictions
from .imagenet_utils import _obtain_input_shape

if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs

# TODO Change path to v1.1
BASE_WEIGHT_PATH = ('https://github.com/JonathanCMitchell/mobilenet_v2_keras/'
'releases/download/v1.1/')
Expand Down Expand Up @@ -213,7 +218,7 @@ def MobileNetV2(input_shape=None,
except ValueError:
try:
is_input_t_tensor = backend.is_keras_tensor(
engine.get_source_inputs(input_tensor))
get_source_inputs(input_tensor))
except ValueError:
raise ValueError('input_tensor: ', input_tensor,
'is not type input_tensor')
Expand Down Expand Up @@ -417,7 +422,7 @@ def MobileNetV2(input_shape=None,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input

Expand Down
6 changes: 5 additions & 1 deletion keras_applications/nasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,11 @@ def NASNet(input_shape=None,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input

Expand Down
6 changes: 5 additions & 1 deletion keras_applications/resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,11 @@ def ResNet50(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
Expand Down
6 changes: 5 additions & 1 deletion keras_applications/vgg16.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ def VGG16(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
Expand Down
6 changes: 5 additions & 1 deletion keras_applications/vgg19.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ def VGG19(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
Expand Down
6 changes: 5 additions & 1 deletion keras_applications/xception.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,11 @@ def Xception(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = engine.get_source_inputs(input_tensor)
if hasattr(keras_utils, 'get_source_inputs'):
get_source_inputs = keras_utils.get_source_inputs
else:
get_source_inputs = engine.get_source_inputs
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
Expand Down

0 comments on commit 85c4954

Please sign in to comment.