diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index 54499df3..50062f87 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "2.4.1" +__version__ = "2.5.0" diff --git a/src/pytorch_metric_learning/losses/manifold_loss.py b/src/pytorch_metric_learning/losses/manifold_loss.py index 8622ddfd..a0b4460d 100644 --- a/src/pytorch_metric_learning/losses/manifold_loss.py +++ b/src/pytorch_metric_learning/losses/manifold_loss.py @@ -90,13 +90,13 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): if self.lambdaC != np.inf: F = F[:N, N:] loss_int = F - F[torch.arange(N), meta_classes].view(-1, 1) + self.margin - loss_int[ - torch.arange(N), meta_classes - ] = -np.inf # This way avoid numerical cancellation happening # NoQA + loss_int[torch.arange(N), meta_classes] = ( + -np.inf + ) # This way avoid numerical cancellation happening # NoQA # instead with subtraction of margin term # NoQA - loss_int[ - loss_int < 0 - ] = -np.inf # This way no loss for positive correlation with own proxy + loss_int[loss_int < 0] = ( + -np.inf + ) # This way no loss for positive correlation with own proxy loss_int = torch.exp(loss_int) loss_int = torch.log(1 + torch.sum(loss_int, dim=1)) @@ -106,9 +106,9 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): F_e, F_p.unsqueeze(1), dim=-1 ).t() loss_ctx += -loss_ctx[torch.arange(N), meta_classes].view(-1, 1) + self.margin - loss_ctx[ - torch.arange(N), meta_classes - ] = -np.inf # This way avoid numerical cancellation happening # NoQA + loss_ctx[torch.arange(N), meta_classes] = ( + -np.inf + ) # This way avoid numerical cancellation happening # NoQA # instead with subtraction of margin term # NoQA loss_ctx[loss_ctx < 0] = -np.inf diff --git a/src/pytorch_metric_learning/testers/base_tester.py b/src/pytorch_metric_learning/testers/base_tester.py index d6a8e8f8..9813f15e 100644 --- a/src/pytorch_metric_learning/testers/base_tester.py +++ b/src/pytorch_metric_learning/testers/base_tester.py @@ -306,8 +306,10 @@ def test( query_split_name, reference_split_names, ) - self.end_of_testing_hook(self) if self.end_of_testing_hook else c_f.LOGGER.info( - self.all_accuracies + ( + self.end_of_testing_hook(self) + if self.end_of_testing_hook + else c_f.LOGGER.info(self.all_accuracies) ) del self.embeddings_and_labels return self.all_accuracies diff --git a/src/pytorch_metric_learning/utils/loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/loss_and_miner_utils.py index a6d90fba..4f2337a9 100644 --- a/src/pytorch_metric_learning/utils/loss_and_miner_utils.py +++ b/src/pytorch_metric_learning/utils/loss_and_miner_utils.py @@ -87,12 +87,22 @@ def neg_pairs_from_tuple(indices_tuple): def get_all_triplets_indices(labels, ref_labels=None): all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels) - if (all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1] - < torch.iinfo(torch.int32).max): + if ( + all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1] + < torch.iinfo(torch.int32).max + ): # torch.nonzero is not supported for tensors with more than INT_MAX elements - triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1) - return torch.where(triplets) + return get_all_triplets_indices_vectorized_method(all_matches, all_diffs) + return get_all_triplets_indices_loop_method(labels, all_matches, all_diffs) + + +def get_all_triplets_indices_vectorized_method(all_matches, all_diffs): + triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1) + return torch.where(triplets) + + +def get_all_triplets_indices_loop_method(labels, all_matches, all_diffs): all_matches, all_diffs = all_matches.bool(), all_diffs.bool() # Find anchors with at least a positive and a negative @@ -101,9 +111,11 @@ def get_all_triplets_indices(labels, ref_labels=None): # No triplets found if len(indices) == 0: - return (torch.tensor([], device=labels.device, dtype=labels.dtype), - torch.tensor([], device=labels.device, dtype=labels.dtype), - torch.tensor([], device=labels.device, dtype=labels.dtype)) + return ( + torch.tensor([], device=labels.device, dtype=labels.dtype), + torch.tensor([], device=labels.device, dtype=labels.dtype), + torch.tensor([], device=labels.device, dtype=labels.dtype), + ) # Compute all triplets anchors = [] @@ -116,7 +128,9 @@ def get_all_triplets_indices(labels, ref_labels=None): nm = len(matches) matches = matches.repeat_interleave(nd) diffs = diffs.repeat(nm) - anchors.append(torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device)) + anchors.append( + torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device) + ) positives.append(matches) negatives.append(diffs) return torch.cat(anchors), torch.cat(positives), torch.cat(negatives) diff --git a/tests/utils/test_calculate_accuracies.py b/tests/utils/test_calculate_accuracies.py index ff83b989..5e26a58c 100644 --- a/tests/utils/test_calculate_accuracies.py +++ b/tests/utils/test_calculate_accuracies.py @@ -67,9 +67,11 @@ def test_accuracy_calculator(self): "query_labels": query_labels, "label_counts": label_counts, "knn_labels": knn_labels, - "not_lone_query_mask": torch.ones(6, dtype=torch.bool) - if i == 0 - else torch.zeros(6, dtype=torch.bool), + "not_lone_query_mask": ( + torch.ones(6, dtype=torch.bool) + if i == 0 + else torch.zeros(6, dtype=torch.bool) + ), } function_dict = AC.get_function_dict() diff --git a/tests/utils/test_loss_and_miner_utils.py b/tests/utils/test_loss_and_miner_utils.py index 6bd559b3..c5137aad 100644 --- a/tests/utils/test_loss_and_miner_utils.py +++ b/tests/utils/test_loss_and_miner_utils.py @@ -291,6 +291,29 @@ def test_remove_self_comparisons_small_ref(self): self.assertTrue(torch.equal(a1, correct_a1)) self.assertTrue(torch.equal(p, correct_p)) + def test_get_all_triplets_indices(self): + torch.manual_seed(920) + for dtype in TEST_DTYPES: + for batch_size in [32, 256, 512]: + for ref_labels in [None, torch.randint(0, 5, size=(batch_size // 2,))]: + labels = torch.randint(0, 5, size=(batch_size,)) + + a, p, n = lmu.get_all_triplets_indices(labels, ref_labels) + matches, diffs = lmu.get_matches_and_diffs(labels, ref_labels) + + a2, p2, n2 = lmu.get_all_triplets_indices_vectorized_method( + matches, diffs + ) + a3, p3, n3 = lmu.get_all_triplets_indices_loop_method( + labels, matches, diffs + ) + self.assertTrue( + (a == a2).all() and (p == p2).all() and (n == n2).all() + ) + self.assertTrue( + (a == a3).all() and (p == p3).all() and (n == n3).all() + ) + if __name__ == "__main__": unittest.main()