Skip to content

Commit

Permalink
Add test to compare the vectorized and loop version of get_all_triple…
Browse files Browse the repository at this point in the history
…ts_indices
  • Loading branch information
KevinMusgrave committed Apr 1, 2024
1 parent cfafd3b commit ef65345
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.4.1"
__version__ = "2.5.0"
18 changes: 9 additions & 9 deletions src/pytorch_metric_learning/losses/manifold_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
if self.lambdaC != np.inf:
F = F[:N, N:]
loss_int = F - F[torch.arange(N), meta_classes].view(-1, 1) + self.margin
loss_int[
torch.arange(N), meta_classes
] = -np.inf # This way avoid numerical cancellation happening # NoQA
loss_int[torch.arange(N), meta_classes] = (
-np.inf
) # This way avoid numerical cancellation happening # NoQA
# instead with subtraction of margin term # NoQA
loss_int[
loss_int < 0
] = -np.inf # This way no loss for positive correlation with own proxy
loss_int[loss_int < 0] = (
-np.inf
) # This way no loss for positive correlation with own proxy

loss_int = torch.exp(loss_int)
loss_int = torch.log(1 + torch.sum(loss_int, dim=1))
Expand All @@ -106,9 +106,9 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
F_e, F_p.unsqueeze(1), dim=-1
).t()
loss_ctx += -loss_ctx[torch.arange(N), meta_classes].view(-1, 1) + self.margin
loss_ctx[
torch.arange(N), meta_classes
] = -np.inf # This way avoid numerical cancellation happening # NoQA
loss_ctx[torch.arange(N), meta_classes] = (
-np.inf
) # This way avoid numerical cancellation happening # NoQA
# instead with subtraction of margin term # NoQA
loss_ctx[loss_ctx < 0] = -np.inf

Expand Down
6 changes: 4 additions & 2 deletions src/pytorch_metric_learning/testers/base_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,10 @@ def test(
query_split_name,
reference_split_names,
)
self.end_of_testing_hook(self) if self.end_of_testing_hook else c_f.LOGGER.info(
self.all_accuracies
(
self.end_of_testing_hook(self)
if self.end_of_testing_hook
else c_f.LOGGER.info(self.all_accuracies)
)
del self.embeddings_and_labels
return self.all_accuracies
30 changes: 22 additions & 8 deletions src/pytorch_metric_learning/utils/loss_and_miner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,22 @@ def neg_pairs_from_tuple(indices_tuple):
def get_all_triplets_indices(labels, ref_labels=None):
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)

if (all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
< torch.iinfo(torch.int32).max):
if (
all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
< torch.iinfo(torch.int32).max
):
# torch.nonzero is not supported for tensors with more than INT_MAX elements
triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
return torch.where(triplets)
return get_all_triplets_indices_vectorized_method(all_matches, all_diffs)

return get_all_triplets_indices_loop_method(labels, all_matches, all_diffs)


def get_all_triplets_indices_vectorized_method(all_matches, all_diffs):
triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
return torch.where(triplets)


def get_all_triplets_indices_loop_method(labels, all_matches, all_diffs):
all_matches, all_diffs = all_matches.bool(), all_diffs.bool()

# Find anchors with at least a positive and a negative
Expand All @@ -101,9 +111,11 @@ def get_all_triplets_indices(labels, ref_labels=None):

# No triplets found
if len(indices) == 0:
return (torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype))
return (
torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype),
)

# Compute all triplets
anchors = []
Expand All @@ -116,7 +128,9 @@ def get_all_triplets_indices(labels, ref_labels=None):
nm = len(matches)
matches = matches.repeat_interleave(nd)
diffs = diffs.repeat(nm)
anchors.append(torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device))
anchors.append(
torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device)
)
positives.append(matches)
negatives.append(diffs)
return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)
Expand Down
8 changes: 5 additions & 3 deletions tests/utils/test_calculate_accuracies.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def test_accuracy_calculator(self):
"query_labels": query_labels,
"label_counts": label_counts,
"knn_labels": knn_labels,
"not_lone_query_mask": torch.ones(6, dtype=torch.bool)
if i == 0
else torch.zeros(6, dtype=torch.bool),
"not_lone_query_mask": (
torch.ones(6, dtype=torch.bool)
if i == 0
else torch.zeros(6, dtype=torch.bool)
),
}

function_dict = AC.get_function_dict()
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_loss_and_miner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,29 @@ def test_remove_self_comparisons_small_ref(self):
self.assertTrue(torch.equal(a1, correct_a1))
self.assertTrue(torch.equal(p, correct_p))

def test_get_all_triplets_indices(self):
torch.manual_seed(920)
for dtype in TEST_DTYPES:
for batch_size in [32, 256, 512]:
for ref_labels in [None, torch.randint(0, 5, size=(batch_size // 2,))]:
labels = torch.randint(0, 5, size=(batch_size,))

a, p, n = lmu.get_all_triplets_indices(labels, ref_labels)
matches, diffs = lmu.get_matches_and_diffs(labels, ref_labels)

a2, p2, n2 = lmu.get_all_triplets_indices_vectorized_method(
matches, diffs
)
a3, p3, n3 = lmu.get_all_triplets_indices_loop_method(
labels, matches, diffs
)
self.assertTrue(
(a == a2).all() and (p == p2).all() and (n == n2).all()
)
self.assertTrue(
(a == a3).all() and (p == p3).all() and (n == n3).all()
)


if __name__ == "__main__":
unittest.main()

0 comments on commit ef65345

Please sign in to comment.