Skip to content

Commit

Permalink
Add assert statement to model_visualization_test (keras-team#20201)
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka authored Sep 3, 2024
1 parent f693113 commit fa6be07
Showing 1 changed file with 129 additions and 38 deletions.
167 changes: 129 additions & 38 deletions integration_tests/model_visualization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_plot_sequential_model():
assert_file_exists(file_name)


def plot_functional_model():
def test_plot_functional_model():
inputs = keras.Input((3,))
x = keras.layers.Dense(4, activation="relu", trainable=False)(inputs)
residual = x
Expand All @@ -106,64 +106,93 @@ def plot_functional_model():
outputs = keras.layers.Dense(1, activation="sigmoid")(x)

model = keras.Model(inputs, outputs)
plot_model(model, "functional.png")
plot_model(model, "functional-show_shapes.png", show_shapes=True)

file_name = "functional.png"
plot_model(model, file_name)
assert_file_exists(file_name)

file_name = "functional-show_shapes.png"
plot_model(model, file_name, show_shapes=True)
assert_file_exists(file_name)

file_name = "functional-show_shapes-show_dtype.png"
plot_model(
model,
"functional-show_shapes-show_dtype.png",
file_name,
show_shapes=True,
show_dtype=True,
)
assert_file_exists(file_name)

file_name = "functional-show_shapes-show_dtype-show_layer_names.png"
plot_model(
model,
"functional-show_shapes-show_dtype-show_layer_names.png",
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
)
assert_file_exists(file_name)

file_name = "functional-show_shapes-show_dtype-show_layer_activations.png"
plot_model(
model,
"functional-show_shapes-show_dtype-show_layer_names-show_layer_activations.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
)
assert_file_exists(file_name)

file_name = "functional-show_shapes-show_dtype-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
model,
"functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)

file_name = "functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
plot_model(
model,
"functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
show_trainable=True,
rankdir="LR",
)
assert_file_exists(file_name)

file_name = "functional-show_layer_activations-show_trainable.png"
plot_model(
model,
"functional-show_layer_activations-show_trainable.png",
file_name,
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)

file_name = (
"functional-show_shapes-show_layer_activations-show_trainable.png"
)
plot_model(
model,
"functional-show_shapes-show_layer_activations-show_trainable.png",
file_name,
show_shapes=True,
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)


def plot_subclassed_model():
def test_plot_subclassed_model():
class MyModel(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand All @@ -176,64 +205,92 @@ def call(self, x):
model = MyModel()
model.build((None, 3))

plot_model(model, "subclassed.png")
plot_model(model, "subclassed-show_shapes.png", show_shapes=True)
file_name = "subclassed.png"
plot_model(model, file_name)
assert_file_exists(file_name)

file_name = "subclassed-show_shapes.png"
plot_model(model, file_name, show_shapes=True)
assert_file_exists(file_name)

file_name = "subclassed-show_shapes-show_dtype.png"
plot_model(
model,
"subclassed-show_shapes-show_dtype.png",
file_name,
show_shapes=True,
show_dtype=True,
)
assert_file_exists(file_name)

file_name = "subclassed-show_shapes-show_dtype-show_layer_names.png"
plot_model(
model,
"subclassed-show_shapes-show_dtype-show_layer_names.png",
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
)
assert_file_exists(file_name)

file_name = "subclassed-show_shapes-show_dtype-show_layer_activations.png"
plot_model(
model,
"subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
)
assert_file_exists(file_name)

file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
model,
"subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)

file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
plot_model(
model,
"subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
show_trainable=True,
rankdir="LR",
)
assert_file_exists(file_name)

file_name = "subclassed-show_layer_activations-show_trainable.png"
plot_model(
model,
"subclassed-show_layer_activations-show_trainable.png",
file_name,
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)

file_name = (
"subclassed-show_shapes-show_layer_activations-show_trainable.png"
)
plot_model(
model,
"subclassed-show_shapes-show_layer_activations-show_trainable.png",
file_name,
show_shapes=True,
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)


def plot_nested_functional_model():
def test_plot_nested_functional_model():
inputs = keras.Input((3,))
x = keras.layers.Dense(4, activation="relu")(inputs)
x = keras.layers.Dense(4, activation="relu")(x)
Expand All @@ -254,50 +311,69 @@ def plot_nested_functional_model():
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

plot_model(model, "nested-functional.png", expand_nested=True)
file_name = "nested-functional.png"
plot_model(model, file_name, expand_nested=True)
assert_file_exists(file_name)

file_name = "nested-functional-show_shapes.png"
plot_model(
model,
"nested-functional-show_shapes.png",
file_name,
show_shapes=True,
expand_nested=True,
)
assert_file_exists(file_name)

file_name = "nested-functional-show_shapes-show_dtype.png"
plot_model(
model,
"nested-functional-show_shapes-show_dtype.png",
file_name,
show_shapes=True,
show_dtype=True,
expand_nested=True,
)
assert_file_exists(file_name)

file_name = "nested-functional-show_shapes-show_dtype-show_layer_names.png"
plot_model(
model,
"nested-functional-show_shapes-show_dtype-show_layer_names.png",
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
expand_nested=True,
)
assert_file_exists(file_name)

file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501
plot_model(
model,
"nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
expand_nested=True,
)
assert_file_exists(file_name)

file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
model,
"nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
show_trainable=True,
expand_nested=True,
)
assert_file_exists(file_name)

file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
plot_model(
model,
"nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
Expand All @@ -306,24 +382,31 @@ def plot_nested_functional_model():
rankdir="LR",
expand_nested=True,
)
assert_file_exists(file_name)

file_name = "nested-functional-show_layer_activations-show_trainable.png"
plot_model(
model,
"nested-functional-show_layer_activations-show_trainable.png",
file_name,
show_layer_activations=True,
show_trainable=True,
expand_nested=True,
)
assert_file_exists(file_name)

file_name = "nested-functional-show_shapes-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
model,
"nested-functional-show_shapes-show_layer_activations-show_trainable.png", # noqa: E501
file_name,
show_shapes=True,
show_layer_activations=True,
show_trainable=True,
expand_nested=True,
)
assert_file_exists(file_name)


def plot_functional_model_with_splits_and_merges():
def test_plot_functional_model_with_splits_and_merges():
class SplitLayer(keras.Layer):
def call(self, x):
return list(keras.ops.split(x, 2, axis=1))
Expand All @@ -341,25 +424,33 @@ def call(self, xs):
outputs = ConcatLayer()([a, b])
model = keras.Model(inputs, outputs)

plot_model(model, "split-functional.png", expand_nested=True)
file_name = "split-functional.png"
plot_model(model, file_name, expand_nested=True)
assert_file_exists(file_name)

file_name = "split-functional-show_shapes.png"
plot_model(
model,
"split-functional-show_shapes.png",
file_name,
show_shapes=True,
expand_nested=True,
)
assert_file_exists(file_name)

file_name = "split-functional-show_shapes-show_dtype.png"
plot_model(
model,
"split-functional-show_shapes-show_dtype.png",
file_name,
show_shapes=True,
show_dtype=True,
expand_nested=True,
)
assert_file_exists(file_name)


if __name__ == "__main__":
test_plot_sequential_model()
plot_functional_model()
plot_subclassed_model()
plot_nested_functional_model()
plot_functional_model_with_splits_and_merges()
test_plot_functional_model()
test_plot_subclassed_model()
test_plot_nested_functional_model()
test_plot_functional_model_with_splits_and_merges()

0 comments on commit fa6be07

Please sign in to comment.