v2.4.0 #677

Merged
merged 30 commits on Dec 12, 2023
Changes from all commits
Commits
30 commits
dadb6d9
First implementation
domenicoMuscill0 Jul 22, 2023
113b430
Implemented DSM loss
domenicoMuscill0 Aug 5, 2023
60743f8
Implemented DSM loss
domenicoMuscill0 Aug 5, 2023
f421716
Merge branch 'KevinMusgrave:master' into dynamic-soft-margin-loss
domenicoMuscill0 Aug 20, 2023
6638b0a
Implement RLL
domenicoMuscill0 Aug 27, 2023
9805fbe
Edit loss docstring
domenicoMuscill0 Aug 27, 2023
125d9b0
format code
domenicoMuscill0 Aug 27, 2023
25f800f
Fix PNP loss to make it work with negatives without related positive …
Puzer Sep 12, 2023
663ba7a
Fix bug & add warning
domenicoMuscill0 Sep 19, 2023
0d66ca8
fix test_histogram_loss
GaetanLepage Oct 12, 2023
47ca9d0
Use pytorch 2.1 in github workflow
KevinMusgrave Oct 18, 2023
2f6548b
Bump tensorboard version
KevinMusgrave Oct 18, 2023
def8b17
Revert previous change
KevinMusgrave Oct 18, 2023
b08259e
Manually install six in workflow
KevinMusgrave Oct 18, 2023
67439e9
Bump python to 3.9 for workflows
KevinMusgrave Oct 18, 2023
ac19166
Versions compatible with python 3.9
KevinMusgrave Oct 18, 2023
3ed5f2b
Update base_test_workflow.yml
KevinMusgrave Oct 18, 2023
0e3fd1b
Update base_test_workflow.yml
KevinMusgrave Oct 18, 2023
6ee9c1d
Update base_test_workflow.yml
KevinMusgrave Oct 18, 2023
cdc83e1
Update base_test_workflow.yml
KevinMusgrave Oct 18, 2023
395cc0b
Update base_test_workflow.yml
KevinMusgrave Oct 18, 2023
a36ee78
Merge pull request #668 from GaetanLepage/master
KevinMusgrave Oct 18, 2023
6df3168
bug fix
domenicoMuscill0 Oct 18, 2023
5ba07ba
Added a test and fixed the denominator
KevinMusgrave Nov 11, 2023
dc772a8
Merge pull request #660 from Puzer/pnp_loss_nan_fix
KevinMusgrave Nov 11, 2023
6a6c201
test_dynamic_soft_margin_loss: skip float16, use seed
KevinMusgrave Dec 12, 2023
6284698
test_ranked_list_loss: add seed
KevinMusgrave Dec 12, 2023
9ce8249
Add basic docs
KevinMusgrave Dec 12, 2023
a69a551
Merge pull request #659 from domenicoMuscill0/dynamic-soft-margin-loss
KevinMusgrave Dec 12, 2023
649e110
Skip float16 for a couple of tests
KevinMusgrave Dec 12, 2023
18 changes: 9 additions & 9 deletions .github/workflows/base_test_workflow.yml
@@ -13,15 +13,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
pytorch-version: [1.6, 1.11]
torchvision-version: [0.7.0, 0.12.0]
with-collect-stats: [false]
exclude:
- pytorch-version: 1.6
torchvision-version: 0.12.0
- pytorch-version: 1.11
torchvision-version: 0.7.0
include:
- python-version: 3.8
pytorch-version: 1.6
torchvision-version: 0.7
- python-version: 3.9
pytorch-version: 2.1
torchvision-version: 0.16

steps:
- uses: actions/checkout@v2
@@ -34,6 +32,8 @@ jobs:
pip install .[with-hooks-cpu]
pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
pip install --upgrade protobuf==3.20.1
pip install six
pip install packaging
- name: Run unit tests
run: |
TEST_DTYPES=float32,float64 TEST_DEVICE=cpu WITH_COLLECT_STATS=${{ matrix.with-collect-stats }} python -m unittest discover -t . -s tests/${{ inputs.module-to-test }}
27 changes: 27 additions & 0 deletions docs/losses.md
@@ -345,6 +345,19 @@ The queue can be cleared like this:
loss_fn.reset_queue()
```

## DynamicSoftMarginLoss
[Learning Local Descriptors With a CDF-Based Dynamic Soft Margin](https://openaccess.thecvf.com/content_ICCV_2019/papers/Zhang_Learning_Local_Descriptors_With_a_CDF-Based_Dynamic_Soft_Margin_ICCV_2019_paper.pdf)
```python
losses.DynamicSoftMarginLoss(min_val=-2.0, num_bins=10, momentum=0.01, **kwargs)
```

**Parameters**:

* **min_val**: minimum significant value for `d_pos - d_neg`
* **num_bins**: number of equally spaced bins for the partition of the interval `[min_val, ∞)`
* **momentum**: weight assigned to the histogram computed from the current batch

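A minimal usage sketch (the batch size, embedding dimension, and label values are illustrative; the loss is called like any other loss in this library):

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.DynamicSoftMarginLoss(min_val=-2.0, num_bins=10, momentum=0.01)

embeddings = torch.randn(32, 128)     # 32 embeddings of dimension 128
labels = torch.randint(0, 8, (32,))   # 8 arbitrary classes
loss = loss_fn(embeddings, labels)
```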

## FastAPLoss
[Deep Metric Learning to Rank](http://openaccess.thecvf.com/content_CVPR_2019/papers/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.pdf){target=_blank}

@@ -969,6 +982,20 @@ loss_optimizer.step()

* **loss**: The loss per element in the batch that results in a non-zero exponent in the cross-entropy expression. Reduction type is ```"element"```.

## RankedListLoss
[Ranked List Loss for Deep Metric Learning](https://arxiv.org/abs/1903.03238)
```python
losses.RankedListLoss(margin, Tn, imbalance=0.5, alpha=None, Tp=0, **kwargs)
```

**Parameters**:

* **margin** (float): margin between the positive and negative sets
* **imbalance** (float): trade-off weight between the positive-set and negative-set terms, accounting for the imbalance between positive and negative samples in the dataset
* **alpha** (float): smallest allowed distance between the anchor and negative points
* **Tp & Tn** (float): temperatures for weighting positive and negative pairs, respectively

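A minimal usage sketch (the margin and temperature values are illustrative, not recommendations from the paper):

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.RankedListLoss(margin=0.4, Tn=0.1)

embeddings = torch.randn(32, 128)     # 32 embeddings of dimension 128
labels = torch.randint(0, 8, (32,))   # 8 arbitrary classes
loss = loss_fn(embeddings, labels)
```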

## SelfSupervisedLoss

2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
__version__ = "2.3.0"
__version__ = "2.4.0"
2 changes: 2 additions & 0 deletions src/pytorch_metric_learning/losses/__init__.py
@@ -6,6 +6,7 @@
from .contrastive_loss import ContrastiveLoss
from .cosface_loss import CosFaceLoss
from .cross_batch_memory import CrossBatchMemory
from .dynamic_soft_margin_loss import DynamicSoftMarginLoss
from .fast_ap_loss import FastAPLoss
from .generic_pair_loss import GenericPairLoss
from .histogram_loss import HistogramLoss
@@ -26,6 +27,7 @@
from .pnp_loss import PNPLoss
from .proxy_anchor_loss import ProxyAnchorLoss
from .proxy_losses import ProxyNCALoss
from .ranked_list_loss import RankedListLoss
from .self_supervised_loss import SelfSupervisedLoss
from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
from .soft_triple_loss import SoftTripleLoss
125 changes: 125 additions & 0 deletions src/pytorch_metric_learning/losses/dynamic_soft_margin_loss.py
@@ -0,0 +1,125 @@
import numpy as np
import torch

from ..distances import LpDistance
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


def find_hard_negatives(dmat):
"""
a = A * P'
A: N * ndim
P: N * ndim

a1p1 a1p2 a1p3 a1p4 ...
a2p1 a2p2 a2p3 a2p4 ...
a3p1 a3p2 a3p3 a3p4 ...
a4p1 a4p2 a4p3 a4p4 ...
... ... ... ...
"""

pos = dmat.diag()
dmat.fill_diagonal_(np.inf)

min_a, _ = torch.min(dmat, dim=0)
min_p, _ = torch.min(dmat, dim=1)
neg = torch.min(min_a, min_p)
return pos, neg


class DynamicSoftMarginLoss(BaseMetricLossFunction):
r"""Loss function with dynamical margin parameter introduced in https://openaccess.thecvf.com/content_ICCV_2019/papers/Zhang_Learning_Local_Descriptors_With_a_CDF-Based_Dynamic_Soft_Margin_ICCV_2019_paper.pdf

Args:
min_val: minimum significant value for `d_pos - d_neg`
num_bins: number of equally spaced bins for the partition of the interval [min_val, :math:`+\infty`)
momentum: weight assigned to the histogram computed from the current batch
"""

def __init__(self, min_val=-2.0, num_bins=10, momentum=0.01, **kwargs):
super().__init__(**kwargs)
c_f.assert_distance_type(self, LpDistance, normalize_embeddings=True, p=2)
self.min_val = min_val
self.num_bins = int(num_bins)
self.delta = 2 * abs(min_val) / num_bins
self.momentum = momentum
self.hist_ = torch.zeros((num_bins,))
self.add_to_recordable_attributes(list_of_names=["num_bins"], is_stat=False)

def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
self.hist_ = c_f.to_device(
self.hist_, tensor=embeddings, dtype=embeddings.dtype
)

if labels is None:
loss = self.compute_loss_without_labels(
embeddings, labels, indices_tuple, ref_emb, ref_labels
)
else:
loss = self.compute_loss_with_labels(
embeddings, labels, indices_tuple, ref_emb, ref_labels
)

if len(loss) == 0:
return self.zero_losses()

self.update_histogram(loss)
loss = self.weigh_loss(loss)
loss = loss.mean()
return {
"loss": {
"losses": loss,
"indices": None,
"reduction_type": "already_reduced",
}
}

def compute_loss_without_labels(
self, embeddings, labels, indices_tuple, ref_emb, ref_labels
):
mat = self.distance(embeddings, ref_emb)
r, c = mat.size()

d_pos = torch.zeros(max(r, c))
d_pos = c_f.to_device(d_pos, tensor=embeddings, dtype=embeddings.dtype)
d_pos[: min(r, c)] = mat.diag()
mat.fill_diagonal_(np.inf)

min_a, min_p = torch.zeros(max(r, c)), torch.zeros(
max(r, c)
) # Check for unequal number of anchors and positives
min_a = c_f.to_device(min_a, tensor=embeddings, dtype=embeddings.dtype)
min_p = c_f.to_device(min_p, tensor=embeddings, dtype=embeddings.dtype)
min_a[:c], _ = torch.min(mat, dim=0)
min_p[:r], _ = torch.min(mat, dim=1)

d_neg = torch.min(min_a, min_p)
return d_pos - d_neg

def compute_loss_with_labels(
self, embeddings, labels, indices_tuple, ref_emb, ref_labels
):
anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(
indices_tuple, labels, ref_labels, t_per_anchor="all"
) # Use all instead of t_per_anchor=1 to be deterministic
mat = self.distance(embeddings, ref_emb)
d_pos, d_neg = mat[anchor_idx, positive_idx], mat[anchor_idx, negative_idx]
return d_pos - d_neg

def update_histogram(self, data):
idx, alpha = torch.floor((data - self.min_val) / self.delta).to(
dtype=torch.long
), torch.frac((data - self.min_val) / self.delta)
momentum = self.momentum if self.hist_.sum() != 0 else 1.0
self.hist_ = torch.scatter_add(
(1.0 - momentum) * self.hist_, 0, idx, momentum * (1 - alpha)
)
self.hist_ = torch.scatter_add(self.hist_, 0, idx + 1, momentum * alpha)
self.hist_ /= self.hist_.sum()

def weigh_loss(self, data):
CDF = torch.cumsum(self.hist_, 0)
idx = torch.floor((data - self.min_val) / self.delta).to(dtype=torch.long)
return CDF[idx] * data
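
For intuition, here is a self-contained sketch of the histogram/CDF weighting that `update_histogram` and `weigh_loss` implement; the `d_pos - d_neg` values are hypothetical and the running-average momentum is omitted (a single batch fills the histogram):

```python
import torch

# Hypothetical d_pos - d_neg values for one batch.
data = torch.tensor([-1.5, -0.2, 0.3, 1.1])

min_val, num_bins = -2.0, 10
delta = 2 * abs(min_val) / num_bins            # bin width, as in DynamicSoftMarginLoss

# Soft (linear) binning: each value contributes to its bin and the next one.
idx = torch.floor((data - min_val) / delta).long()
alpha = torch.frac((data - min_val) / delta)

hist = torch.zeros(num_bins)
hist = torch.scatter_add(hist, 0, idx, 1 - alpha)
hist = torch.scatter_add(hist, 0, idx + 1, alpha)
hist /= hist.sum()

# Each sample's loss is scaled by the empirical CDF at its own bin,
# so "easy" samples (very negative d_pos - d_neg) receive small weights.
cdf = torch.cumsum(hist, 0)
weighted = cdf[idx] * data
```
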
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/losses/histogram_loss.py
@@ -25,7 +25,7 @@ def __init__(self, n_bins: int = None, delta: float = None, **kwargs):
n_bins = 100

self.delta = delta if delta is not None else 2 / n_bins
self.add_to_recordable_attributes(name="delta", is_stat=True)
self.add_to_recordable_attributes(name="delta", is_stat=False)

def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
c_f.labels_or_indices_tuple_required(labels, indices_tuple)
4 changes: 2 additions & 2 deletions src/pytorch_metric_learning/losses/pnp_loss.py
@@ -68,8 +68,8 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
else:
raise Exception(f"variant <{self.variant}> not available!")

loss = torch.sum(sim_all_rk * I_pos, dim=-1) / N_pos.reshape(-1)
loss = torch.sum(loss) / N
loss = torch.sum(sim_all_rk * I_pos, dim=-1)[safe_N] / N_pos[safe_N].reshape(-1)
loss = torch.sum(loss) / torch.sum(safe_N)
if self.variant == "Dq":
loss = 1 - loss

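The two changed lines above avoid a division by zero for anchors that have no positives. A toy illustration of the masking idea (the `safe_N` name is modeled on the diff; the tensor values are hypothetical):

```python
import torch

per_anchor = torch.tensor([0.8, 0.0, 0.5, 0.0])   # summed ranked similarities per anchor
N_pos = torch.tensor([2.0, 0.0, 1.0, 0.0])        # positives available per anchor
safe_N = N_pos > 0                                 # anchors that actually have positives

# Before the fix: dividing by N_pos produces NaNs for anchors without positives.
# After the fix: only anchors with positives contribute, and the mean is taken over them.
loss = torch.sum(per_anchor[safe_N] / N_pos[safe_N]) / torch.sum(safe_N)
print(loss)  # finite value instead of NaN
```
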
96 changes: 96 additions & 0 deletions src/pytorch_metric_learning/losses/ranked_list_loss.py
@@ -0,0 +1,96 @@
import warnings

import torch

from ..distances import LpDistance
from ..utils import common_functions as c_f
from .base_metric_loss_function import BaseMetricLossFunction


class RankedListLoss(BaseMetricLossFunction):
r"""Ranked List Loss described in https://arxiv.org/abs/1903.03238
Default parameters correspond to RLL-Simpler, preferred for exploratory analysis.

Args:
* margin (float): margin between the positive and negative sets
* imbalance (float): trade-off weight between the positive-set and negative-set terms, accounting for the
imbalance between positive and negative samples in the dataset
* alpha (float): smallest allowed distance between the anchor and negative points
* Tp & Tn (float): temperatures for weighting positive and negative pairs, respectively.
"""

def __init__(self, margin, Tn, imbalance=0.5, alpha=None, Tp=0, **kwargs):
super().__init__(**kwargs)

self.margin = margin

assert 0 <= imbalance <= 1, "Imbalance must be between 0 and 1"
self.imbalance = imbalance

if alpha is not None:
self.alpha = alpha
else:
self.alpha = 1 + margin / 2

if Tp > 5 or Tn > 5:
warnings.warn(
"Values of Tp or Tn are too high. Too large temperature values may lead to overflow."
)

self.Tp = Tp
self.Tn = Tn
self.add_to_recordable_attributes(
list_of_names=["imbalance", "alpha", "margin", "Tp", "Tn"], is_stat=False
)

def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
c_f.labels_required(labels)
c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
c_f.indices_tuple_not_supported(indices_tuple)

mat = self.distance(embeddings, embeddings)
# mat.fill_diagonal_(0)
mat = mat - mat * torch.eye(len(mat), device=embeddings.device)
mat = c_f.to_device(mat, device=embeddings.device, dtype=embeddings.dtype)
y = labels.unsqueeze(1) == labels.unsqueeze(0)

P_star = torch.zeros_like(mat)
N_star = torch.zeros_like(mat)
w_p = torch.zeros_like(mat)
w_n = torch.zeros_like(mat)

N_star[(~y) * (mat < self.alpha)] = mat[(~y) * (mat < self.alpha)]
y.fill_diagonal_(False)
P_star[y * (mat > (self.alpha - self.margin))] = mat[
y * (mat > (self.alpha - self.margin))
]

w_p[P_star > 0] = torch.exp(
self.Tp * (P_star[P_star > 0] - (self.alpha - self.margin))
)
w_n[N_star > 0] = torch.exp(self.Tn * (self.alpha - N_star[N_star > 0]))

loss_P = torch.sum(
w_p * (P_star - (self.alpha - self.margin)), dim=1
) / torch.sum(w_p + 1e-5, dim=1)

loss_N = torch.sum(w_n * (self.alpha - N_star), dim=1) / torch.sum(
w_n + 1e-5, dim=1
)

# with torch.no_grad():
# loss_P[loss_P.isnan()] = 0
# loss_N[loss_N.isnan()] = 0

loss_RLL = (1 - self.imbalance) * loss_P + self.imbalance * loss_N

return {
"loss": {
"losses": loss_RLL,
"indices": c_f.torch_arange_from_size(loss_RLL),
"reduction_type": "element",
}
}

def get_default_distance(self):
return LpDistance()