Skip to content

Commit

Permalink
Add the logic to assert and remove files after running test (keras-te…
Browse files Browse the repository at this point in the history
…am#20190)

* Add the logic to assert and remove files after running test

* Add the logic to assert and remove files after running test

* delete method for removing files

* declare variable named file_name

* declare variable named file_name

* not to remove plot_model method calling
  • Loading branch information
shashaka authored Sep 2, 2024
1 parent 7689bcc commit f693113
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions integration_tests/model_visualization_test.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,92 @@
from pathlib import Path

import keras
from keras.src.utils import plot_model


def plot_sequential_model():
def assert_file_exists(path):
assert Path(path).is_file(), "File does not exist"


def test_plot_sequential_model():
model = keras.Sequential(
[
keras.Input((3,)),
keras.layers.Dense(4, activation="relu"),
keras.layers.Dense(1, activation="sigmoid"),
]
)
plot_model(model, "sequential.png")
plot_model(model, "sequential-show_shapes.png", show_shapes=True)
file_name = "sequential.png"
plot_model(model, file_name)
assert_file_exists(file_name)

file_name = "sequential-show_shapes.png"
plot_model(model, file_name, show_shapes=True)
assert_file_exists(file_name)

file_name = "sequential-show_shapes-show_dtype.png"
plot_model(
model,
"sequential-show_shapes-show_dtype.png",
file_name,
show_shapes=True,
show_dtype=True,
)
assert_file_exists(file_name)

file_name = "sequential-show_shapes-show_dtype-show_layer_names.png"
plot_model(
model,
"sequential-show_shapes-show_dtype-show_layer_names.png",
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
)
assert_file_exists(file_name)

file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501
plot_model(
model,
"sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
)
assert_file_exists(file_name)

file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501
plot_model(
model,
"sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)

file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501
plot_model(
model,
"sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png", # noqa: E501
file_name,
show_shapes=True,
show_dtype=True,
show_layer_names=True,
show_layer_activations=True,
show_trainable=True,
rankdir="LR",
)
assert_file_exists(file_name)

file_name = "sequential-show_layer_activations-show_trainable.png"
plot_model(
model,
"sequential-show_layer_activations-show_trainable.png",
file_name,
show_layer_activations=True,
show_trainable=True,
)
assert_file_exists(file_name)


def plot_functional_model():
Expand Down Expand Up @@ -329,7 +358,7 @@ def call(self, xs):


if __name__ == "__main__":
plot_sequential_model()
test_plot_sequential_model()
plot_functional_model()
plot_subclassed_model()
plot_nested_functional_model()
Expand Down

0 comments on commit f693113

Please sign in to comment.