Skip to content

Commit

Permalink
Fix testing errors cause by recent TF updates.
Browse files Browse the repository at this point in the history
- Remove exact check for metric_embedding config dict. This class is a
  thin wrapper around the layers.Dense class and doesn't take any custom
  args.
- Update the deserialize from identifier functions to raise the
  ValueError if we are not able to cover the identifier to the target
  object type.
  • Loading branch information
owenvallis committed May 6, 2024
1 parent d36d6be commit 3fb19ac
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 38 deletions.
11 changes: 6 additions & 5 deletions tensorflow_similarity/distances/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ def get(identifier) -> Distance:
Raises:
ValueError: If `identifier` cannot be interpreted.
"""
if isinstance(identifier, Distance):
return identifier
elif isinstance(identifier, dict):
return deserialize(identifier)
if isinstance(identifier, dict):
identifier = deserialize(identifier)
elif isinstance(identifier, str):
config = {"class_name": str(identifier), "config": {}}
return deserialize(config)
identifier = deserialize(config)

if isinstance(identifier, Distance):
return identifier
else:
raise ValueError("Could not interpret search identifier: {}".format(identifier))
11 changes: 6 additions & 5 deletions tensorflow_similarity/search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ def get(identifier, **kwargs) -> Search:
Raises:
ValueError: If `identifier` cannot be interpreted.
"""
if isinstance(identifier, Search):
return identifier
elif isinstance(identifier, dict):
return deserialize(identifier)
if isinstance(identifier, dict):
identifier = deserialize(identifier)
elif isinstance(identifier, str):
config = {"class_name": str(identifier), "config": kwargs}
return deserialize(config)
identifier = deserialize(config)

if isinstance(identifier, Search):
return identifier
else:
raise ValueError("Could not interpret search identifier: {}".format(identifier))
11 changes: 6 additions & 5 deletions tensorflow_similarity/stores/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,13 @@ def get(identifier) -> Store:
Raises:
ValueError: If `identifier` cannot be interpreted.
"""
if isinstance(identifier, Store):
return identifier
elif isinstance(identifier, dict):
return deserialize(identifier)
if isinstance(identifier, dict):
identifier = deserialize(identifier)
elif isinstance(identifier, str):
config = {"class_name": str(identifier), "config": {}}
return deserialize(config)
identifier = deserialize(config)

if isinstance(identifier, Store):
return identifier
else:
raise ValueError("Could not interpret Store identifier: {}".format(identifier))
24 changes: 1 addition & 23 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math

import tensorflow as tf
from tensorflow.keras import layers

from tensorflow_similarity.layers import (
GeneralizedMeanPooling1D,
Expand Down Expand Up @@ -160,29 +161,6 @@ def test_metric_embedding(self):
expected_result = tf.constant([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]])
self.assertAllClose(result, expected_result, rtol=1e-06)

def test_metric_embedding_get_config(self):
me_layer = MetricEmbedding(32)
config = me_layer.get_config()
expected_config = {
"name": "metric_embedding",
"trainable": True,
"dtype": "float32",
"units": 32,
"activation": "linear",
"use_bias": True,
"kernel_initializer": {
"class_name": "GlorotUniform",
"config": {"seed": None},
},
"bias_initializer": {"class_name": "Zeros", "config": {}},
"kernel_regularizer": None,
"bias_regularizer": None,
"activity_regularizer": None,
"kernel_constraint": None,
"bias_constraint": None,
}
self.assertEqual(expected_config, config)


if __name__ == "__main__":
tf.test.main()

0 comments on commit 3fb19ac

Please sign in to comment.