Skip to content

Commit

Permalink
PR #672: Remove hardcoded channels_last in efficientnet_v2
Browse files Browse the repository at this point in the history
Imported from GitHub PR #672

Replace hard-coded channels_last in efficientnet_v2 with the backend data format
Copybara import of the project:

--
214fc39 by Richard Swanson <[email protected]>:

Remove hardcoded channels_last in efficientnet_v2 and replace it with the backend data format

--
c52e562 by Richard Swanson <[email protected]>:

Add channels-first testing for applications_test and clean up the testing script

--
faa997d by Richard Swanson <[email protected]>:

Fix failing channels first tests for efficientnet and mobilenet_v3

--
cf9f492 by Richard Swanson <[email protected]>:

Fix code formatting

Merging this change closes #672

FUTURE_COPYBARA_INTEGRATE_REVIEW=#672 from Inquisitive-ME:fix-efficientnet_v2-channels-first cf9f492
PiperOrigin-RevId: 580748697
  • Loading branch information
Inquisitive-ME authored and tensorflower-gardener committed Nov 9, 2023
1 parent 4fa8b74 commit eba6aee
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 51 deletions.
145 changes: 103 additions & 42 deletions tf_keras/applications/applications_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@

MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST

# Name fragments of the model families that the tests below skip when the
# image data format is "channels_first".
MODELS_UNSUPPORTED_CHANNELS_FIRST = [
    "ConvNeXt",
    "NASNet",
    "RegNetX",
    "RegNetY",
]
# One named test case per (data format, model) pair. Each case is the tuple
# ("<ModelName>_<data_format>", *model_entry, data_format); all
# "channels_first" cases come before the "channels_last" ones.
test_parameters_with_image_data_format = [
    (f"{entry[0].__name__}_{data_format}", *entry, data_format)
    for data_format in ("channels_first", "channels_last")
    for entry in MODEL_LIST
]

# Parameters for loading weights for MobileNetV3.
# (class, alpha, minimalistic, include_top)
MOBILENET_V3_FOR_WEIGHTS = [
Expand All @@ -138,7 +150,16 @@


class ApplicationsTest(tf.test.TestCase, parameterized.TestCase):
def assertShapeEqual(self, shape1, shape2):
@classmethod
def setUpClass(cls):
    """Record the backend image data format before any test changes it.

    The saved value is stored on the class as `original_image_data_format`
    and is written back by `tearDownClass`.
    """
    # Chain to the base classes so their class-level setup still runs;
    # unittest does not call overridden class fixtures automatically.
    super().setUpClass()
    cls.original_image_data_format = backend.image_data_format()

@classmethod
def tearDownClass(cls):
    """Restore the image data format that `setUpClass` recorded.

    Individual tests switch the backend between "channels_first" and
    "channels_last"; this puts the original setting back once the whole
    class has finished.
    """
    backend.set_image_data_format(cls.original_image_data_format)
    # Chain to the base classes so their class-level teardown still runs;
    # unittest does not call overridden class fixtures automatically.
    super().tearDownClass()

@classmethod
def assertShapeEqual(cls, shape1, shape2):
if len(shape1) != len(shape2):
raise AssertionError(
f"Shapes are different rank: {shape1} vs {shape2}"
Expand All @@ -147,8 +168,27 @@ def assertShapeEqual(self, shape1, shape2):
if v1 != v2:
raise AssertionError(f"Shapes differ: {shape1} vs {shape2}")

@parameterized.parameters(*MODEL_LIST)
def test_application_base(self, app, _):
def skip_if_invalid_image_data_format_for_model(
    self, app, image_data_format
):
    """Skip the current test if `app` is not tested with channels-first.

    A model counts as unsupported when its `__name__` contains, ignoring
    case, any entry of `MODELS_UNSUPPORTED_CHANNELS_FIRST`.

    Args:
        app: Application constructor; only its `__name__` is inspected.
        image_data_format: The data format the test is about to use,
            e.g. "channels_first" or "channels_last".
    """
    app_name = app.__name__.lower()
    unsupported = any(
        family.lower() in app_name
        for family in MODELS_UNSUPPORTED_CHANNELS_FIRST
    )
    if unsupported and image_data_format == "channels_first":
        self.skipTest(f"{app.__name__} does not support channels first")

@parameterized.named_parameters(test_parameters_with_image_data_format)
def test_application_base(self, app, _, image_data_format):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)
# Can be instantiated with default arguments
model = app(weights=None)
# Can be serialized and deserialized
Expand All @@ -162,36 +202,55 @@ def test_application_base(self, app, _):
self.assertEqual(len(model.weights), len(reconstructed_model.weights))
backend.clear_session()

@parameterized.parameters(*MODEL_LIST)
def test_application_notop(self, app, last_dim):
@parameterized.named_parameters(test_parameters_with_image_data_format)
def test_application_notop(self, app, last_dim, image_data_format):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)
if image_data_format == "channels_first":
input_shape = (3, None, None)
correct_output_shape = (None, last_dim, None, None)
channels_axis = 1
else:
input_shape = (None, None, 3)
correct_output_shape = (None, None, None, last_dim)
channels_axis = -1

if "NASNet" in app.__name__:
only_check_last_dim = True
else:
only_check_last_dim = False
output_shape = _get_output_shape(
lambda: app(weights=None, include_top=False)
)
output_shape = app(
weights=None, include_top=False, input_shape=input_shape
).output_shape
if only_check_last_dim:
self.assertEqual(output_shape[-1], last_dim)
self.assertEqual(output_shape[channels_axis], last_dim)
else:
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
self.assertShapeEqual(output_shape, correct_output_shape)
backend.clear_session()

@parameterized.parameters(*MODEL_LIST)
def test_application_notop_custom_input_shape(self, app, last_dim):
output_shape = _get_output_shape(
lambda: app(
weights="imagenet", include_top=False, input_shape=(224, 224, 3)
)
)
@parameterized.named_parameters(test_parameters_with_image_data_format)
def test_application_notop_custom_input_shape(
self, app, last_dim, image_data_format
):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)
if image_data_format == "channels_first":
input_shape = (3, 224, 224)
channels_axis = 1
else:
input_shape = (224, 224, 3)
channels_axis = -1
output_shape = app(
weights="imagenet", include_top=False, input_shape=input_shape
).output_shape

self.assertEqual(output_shape[-1], last_dim)
self.assertEqual(output_shape[channels_axis], last_dim)

@parameterized.parameters(MODEL_LIST)
def test_application_pooling(self, app, last_dim):
output_shape = _get_output_shape(
lambda: app(weights=None, include_top=False, pooling="avg")
)
output_shape = app(
weights=None, include_top=False, pooling="avg"
).output_shape
self.assertShapeEqual(output_shape, (None, last_dim))

@parameterized.parameters(MODEL_LIST)
Expand All @@ -204,30 +263,34 @@ def test_application_classifier_activation(self, app, _):
last_layer_act = model.layers[-1].activation.__name__
self.assertEqual(last_layer_act, "softmax")

@parameterized.parameters(*MODEL_LIST_NO_NASNET)
def test_application_variable_input_channels(self, app, last_dim):
@parameterized.named_parameters(test_parameters_with_image_data_format)
def test_application_variable_input_channels(
self, app, last_dim, image_data_format
):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)
if backend.image_data_format() == "channels_first":
input_shape = (1, None, None)
correct_output_shape = (None, last_dim, None, None)
else:
input_shape = (None, None, 1)
output_shape = _get_output_shape(
lambda: app(
weights=None, include_top=False, input_shape=input_shape
)
)
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
correct_output_shape = (None, None, None, last_dim)
output_shape = app(
weights=None, include_top=False, input_shape=input_shape
).output_shape

self.assertShapeEqual(output_shape, correct_output_shape)
backend.clear_session()

if backend.image_data_format() == "channels_first":
input_shape = (4, None, None)
else:
input_shape = (None, None, 4)
output_shape = _get_output_shape(
lambda: app(
weights=None, include_top=False, input_shape=input_shape
)
)
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
output_shape = app(
weights=None, include_top=False, input_shape=input_shape
).output_shape

self.assertShapeEqual(output_shape, correct_output_shape)
backend.clear_session()

@parameterized.parameters(*MOBILENET_V3_FOR_WEIGHTS)
Expand All @@ -242,9 +305,12 @@ def test_mobilenet_v3_load_weights(
include_top=include_top,
)

@parameterized.parameters(MODEL_LIST)
@parameterized.named_parameters(test_parameters_with_image_data_format)
@test_utils.run_v2_only
def test_model_checkpoint(self, app, _):
def test_model_checkpoint(self, app, _, image_data_format):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)

model = app(weights=None)

checkpoint = tf.train.Checkpoint(model=model)
Expand All @@ -256,10 +322,5 @@ def test_model_checkpoint(self, app, _):
checkpoint_manager.save(checkpoint_number=1)


def _get_output_shape(model_fn):
model = model_fn()
return model.output_shape


if __name__ == "__main__":
tf.test.main()
14 changes: 11 additions & 3 deletions tf_keras/applications/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,17 @@ def round_repeats(repeats):
# original implementation.
# See https://github.com/tensorflow/tensorflow/issues/49930 for more
# details
x = layers.Rescaling(
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB]
)(x)
if backend.image_data_format() == "channels_first":
shape_for_multiply = [1, 3, 1, 1]
else:
shape_for_multiply = [1, 1, 1, 3]
x = tf.math.multiply(
x,
tf.reshape(
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB],
shape_for_multiply,
),
)

x = layers.ZeroPadding2D(
padding=imagenet_utils.correct_pad(x, 3), name="stem_conv_pad"
Expand Down
10 changes: 5 additions & 5 deletions tf_keras/applications/efficientnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def apply(inputs):
strides=1,
kernel_initializer=CONV_KERNEL_INITIALIZER,
padding="same",
data_format="channels_last",
data_format=backend.image_data_format(),
use_bias=False,
name=name + "expand_conv",
)(inputs)
Expand All @@ -677,7 +677,7 @@ def apply(inputs):
strides=strides,
depthwise_initializer=CONV_KERNEL_INITIALIZER,
padding="same",
data_format="channels_last",
data_format=backend.image_data_format(),
use_bias=False,
name=name + "dwconv2",
)(x)
Expand Down Expand Up @@ -722,7 +722,7 @@ def apply(inputs):
strides=1,
kernel_initializer=CONV_KERNEL_INITIALIZER,
padding="same",
data_format="channels_last",
data_format=backend.image_data_format(),
use_bias=False,
name=name + "project_conv",
)(x)
Expand Down Expand Up @@ -771,7 +771,7 @@ def apply(inputs):
kernel_size=kernel_size,
strides=strides,
kernel_initializer=CONV_KERNEL_INITIALIZER,
data_format="channels_last",
data_format=backend.image_data_format(),
padding="same",
use_bias=False,
name=name + "expand_conv",
Expand Down Expand Up @@ -1052,7 +1052,7 @@ def EfficientNetV2(
strides=1,
kernel_initializer=CONV_KERNEL_INITIALIZER,
padding="same",
data_format="channels_last",
data_format=backend.image_data_format(),
use_bias=False,
name="top_conv",
)(x)
Expand Down
5 changes: 4 additions & 1 deletion tf_keras/applications/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,10 @@ def MobileNetV3(
input_shape = (cols, rows, 3)
# If input_shape is None and input_tensor is None using standard shape
if input_shape is None and input_tensor is None:
input_shape = (None, None, 3)
if backend.image_data_format() == "channels_last":
input_shape = (None, None, 3)
else:
input_shape = (3, None, None)

if backend.image_data_format() == "channels_last":
row_axis, col_axis = (0, 1)
Expand Down

0 comments on commit eba6aee

Please sign in to comment.