
Commit

warn in non-distributed setting
elisim committed Jul 17, 2024
1 parent a9f13b1 commit 946eeb8
Showing 1 changed file with 8 additions and 0 deletions.
src/pytorch_metric_learning/utils/distributed.py (8 additions, 0 deletions)
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 
 from ..losses import BaseMetricLossFunction, CrossBatchMemory
@@ -100,6 +102,12 @@ def forward(
         ref_labels=None,
         enqueue_mask=None,
     ):
+        if not is_distributed():
+            warnings.warn(
+                "DistributedLossWrapper is being used in a non-distributed setting. Returning the loss as is."
+            )
+            return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)
+
         world_size = torch.distributed.get_world_size()
         common_args = [embeddings, labels, indices_tuple, ref_emb, ref_labels, world_size]
         if isinstance(self.loss, CrossBatchMemory):
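For context, a minimal sketch of what this change means for callers, assuming the standard pytorch-metric-learning API (ContrastiveLoss is only an example of a wrapped loss, and is_distributed() is assumed to check whether a torch.distributed process group is available and initialized): calling the wrapper outside a distributed run now emits a UserWarning and returns the wrapped loss directly instead of attempting any cross-process gathering.

# Minimal usage sketch (assumption: standard pytorch-metric-learning API;
# ContrastiveLoss is just an example loss to wrap).
import warnings

import torch
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

loss_fn = pml_dist.DistributedLossWrapper(loss=losses.ContrastiveLoss())

embeddings = torch.randn(32, 128)
labels = torch.randint(0, 10, (32,))

# With no torch.distributed process group initialized, the wrapper now warns
# and falls back to calling the wrapped loss on the local batch as is.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss = loss_fn(embeddings, labels)

print(loss)                              # plain ContrastiveLoss value
print([str(w.message) for w in caught])  # includes the new warning message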
