Add training support to Meta's EnCodec #33956

Open · wants to merge 17 commits into main
1 change: 1 addition & 0 deletions setup.py
@@ -107,6 +107,7 @@
"deepspeed>=0.9.3",
"diffusers",
"dill<0.3.5",
"einops",
"evaluate>=0.2.0",
"faiss-cpu",
"fastapi",
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -2158,6 +2158,8 @@
)
_import_structure["models.encodec"].extend(
[
"EncodecDiscriminator",
"EncodecDiscriminatorConfig",
"EncodecModel",
"EncodecPreTrainedModel",
]
@@ -6923,6 +6925,8 @@
load_tf_weights_in_electra,
)
from .models.encodec import (
EncodecDiscriminator,
EncodecDiscriminatorConfig,
EncodecModel,
EncodecPreTrainedModel,
)
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
@@ -14,6 +14,7 @@
"deepspeed": "deepspeed>=0.9.3",
"diffusers": "diffusers",
"dill": "dill<0.3.5",
"einops": "einops",
"evaluate": "evaluate>=0.2.0",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
56 changes: 56 additions & 0 deletions src/transformers/loss/loss_encodec.py
@@ -0,0 +1,56 @@
import torch
import torch.nn.functional as F

def EncodecLoss(model, input_values, audio_values):
    """
    Computes the reconstruction and commitment losses for the Encodec model.

    Args:
        model: The EncodecModel instance.
        input_values (torch.Tensor): Original input audio.
        audio_values (torch.Tensor): Reconstructed audio from the model.

    Returns:
        tuple: A tuple containing (reconstruction_loss, commitment_loss).
    """
    # Commitment loss: distance between encoder embeddings and their quantized values
    embeddings = model.encoder(input_values)
    _, quantization_steps = model.quantizer.encode(embeddings, bandwidth=None)

    commitment_loss = torch.tensor(0.0, device=input_values.device)
    for residual, quantize in quantization_steps:
        loss = F.mse_loss(quantize.permute(0, 2, 1), residual.permute(0, 2, 1))
        commitment_loss += loss
    commitment_loss *= model.commitment_weight

    # Reconstruction loss: L1 distance in the time domain
    time_loss = F.l1_loss(audio_values, input_values)

    # Frequency-domain loss: L1 + L2 distances over mel spectrograms at several scales
    scales = [2**i for i in range(5, 12)]
    frequency_loss = 0.0
    for scale in scales:
        n_fft = scale
        hop_length = scale // 4
        S_x = model.compute_mel_spectrogram(input_values, n_fft, hop_length, n_mels=64)
        S_x_hat = model.compute_mel_spectrogram(audio_values, n_fft, hop_length, n_mels=64)
        l1 = F.l1_loss(S_x_hat, S_x)
        l2 = F.mse_loss(S_x_hat, S_x)
        frequency_loss += l1 + l2

    frequency_loss = frequency_loss / (len(scales) * 2)

    # Combine the time- and frequency-domain terms
    lambda_t = 1.0  # weighting factors; adjust if needed
    lambda_f = 1.0
    reconstruction_loss = lambda_t * time_loss + lambda_f * frequency_loss

    return reconstruction_loss, commitment_loss
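
For orientation, a minimal usage sketch (not part of the diff). It assumes the interfaces this PR adds to EncodecModel — the `commitment_weight` attribute, the `compute_mel_spectrogram` helper, and a quantizer `encode` that also returns per-step (residual, quantize) pairs — so it will not run against the released library; the checkpoint name is only illustrative.

import torch
from transformers import EncodecModel
from transformers.loss.loss_encodec import EncodecLoss

model = EncodecModel.from_pretrained("facebook/encodec_24khz")
input_values = torch.randn(1, 1, 24000)           # one second of mono 24 kHz audio
audio_values = model(input_values).audio_values   # reconstructed waveform
reconstruction_loss, commitment_loss = EncodecLoss(model, input_values, audio_values)
(reconstruction_loss + commitment_loss).backward()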
2 changes: 2 additions & 0 deletions src/transformers/loss/loss_utils.py
@@ -19,6 +19,7 @@
from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss
from .loss_encodec import EncodecLoss


def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
@@ -111,4 +112,5 @@ def ForTokenClassification(logits, labels, config, **kwargs):
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
"Encodec": EncodecLoss,
}
16 changes: 16 additions & 0 deletions src/transformers/models/encodec/__init__.py
@@ -23,6 +23,12 @@
_import_structure = {
"configuration_encodec": ["EncodecConfig"],
"feature_extraction_encodec": ["EncodecFeatureExtractor"],
"loss_encodec": [
"Balancer",
"compute_discriminator_loss",
"compute_feature_matching_loss",
"compute_generator_adv_loss",
],
}

try:
@@ -34,13 +40,21 @@
_import_structure["modeling_encodec"] = [
"EncodecModel",
"EncodecPreTrainedModel",
"EncodecDiscriminatorConfig",
"EncodecDiscriminator",
]

if TYPE_CHECKING:
from .configuration_encodec import (
EncodecConfig,
)
from .feature_extraction_encodec import EncodecFeatureExtractor
from .loss_encodec import (
Balancer,
compute_discriminator_loss,
compute_feature_matching_loss,
compute_generator_adv_loss,
)

try:
if not is_torch_available():
@@ -49,6 +63,8 @@
pass
else:
from .modeling_encodec import (
EncodecDiscriminator,
EncodecDiscriminatorConfig,
EncodecModel,
EncodecPreTrainedModel,
)
1 change: 1 addition & 0 deletions src/transformers/models/encodec/configuration_encodec.py
@@ -104,6 +104,7 @@ class EncodecConfig(PretrainedConfig):
```"""

model_type = "encodec"
loss_type = "Encodec"

def __init__(
self,
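
For context, this `loss_type` class attribute is what ties the config to the `LOSS_MAPPING` entry added in loss_utils.py above. A minimal sketch of the lookup, assuming the mapping is consulted through the config the same way it is for other models:

from transformers import EncodecConfig
from transformers.loss.loss_utils import LOSS_MAPPING

config = EncodecConfig()
loss_fn = LOSS_MAPPING[config.loss_type]  # resolves "Encodec" -> EncodecLoss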
219 changes: 219 additions & 0 deletions src/transformers/models/encodec/loss_encodec.py
@@ -0,0 +1,219 @@
import typing as tp
from collections import defaultdict
from typing import List

import torch
import torch.nn.functional as F
from torch import autograd

"""
Balancer code directly copied from: https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py
"""
class Balancer:
"""Loss balancer.

The loss balancer combines losses together to compute gradients for the backward.
A call to the balancer will weight the losses according the specified weight coefficients.
A call to the backward method of the balancer will compute the gradients, combining all the losses and
potentially rescaling the gradients, which can help stabilize the training and reasonate
about multiple losses with varying scales.

Expected usage:
weights = {'loss_a': 1, 'loss_b': 4}
balancer = Balancer(weights, ...)
losses: dict = {}
losses['loss_a'] = compute_loss_a(x, y)
losses['loss_b'] = compute_loss_b(x, y)
if model.training():
balancer.backward(losses, x)

..Warning:: It is unclear how this will interact with DistributedDataParallel,
in particular if you have some losses not handled by the balancer. In that case
you can use `encodec.distrib.sync_grad(model.parameters())` and
`encodec.distrib.sync_buffwers(model.buffers())` as a safe alternative.

Args:
weights (Dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
from the backward method to match the weights keys to assign weight to each of the provided loss.
rescale_grads (bool): Whether to rescale gradients or not, without. If False, this is just
a regular weighted sum of losses.
total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
emay_decay (float): EMA decay for averaging the norms when `rescale_grads` is True.
per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
when rescaling the gradients.
epsilon (float): Epsilon value for numerical stability.
monitor (bool): Whether to store additional ratio for each loss key in metrics.
"""

def __init__(
self,
weights: tp.Dict[str, float],
rescale_grads: bool = True,
total_norm: float = 1.0,
ema_decay: float = 0.999,
per_batch_item: bool = True,
epsilon: float = 1e-12,
monitor: bool = False,
):
self.weights = weights
self.per_batch_item = per_batch_item
self.total_norm = total_norm
self.averager = averager(ema_decay)
self.epsilon = epsilon
self.monitor = monitor
self.rescale_grads = rescale_grads
self._metrics: tp.Dict[str, tp.Any] = {}

@property
def metrics(self):
return self._metrics

def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
norms = {}
grads = {}
for name, loss in losses.items():
(grad,) = autograd.grad(loss, [input], retain_graph=True, allow_unused=True)
if grad is not None:
if self.per_batch_item:
dims = tuple(range(1, grad.dim()))
norm = grad.norm(dim=dims).mean()
else:
norm = grad.norm()
norms[name] = norm
grads[name] = grad

count = 1
if self.per_batch_item:
count = len(next(iter(grads.values())))
avg_norms = average_metrics(self.averager(norms), count)
total = sum(avg_norms.values())

self._metrics = {}
if self.monitor:
for k, v in avg_norms.items():
self._metrics[f"ratio_{k}"] = v / total

total_weights = sum([self.weights[k] for k in avg_norms])
ratios = {k: w / total_weights for k, w in self.weights.items()}

out_grad: tp.Any = 0
for name, avg_norm in avg_norms.items():
if self.rescale_grads:
scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm)
grad = grads[name] * scale
else:
grad = self.weights[name] * grads[name]
out_grad += grad
input.backward(out_grad)


def world_size():
if torch.distributed.is_initialized():
return torch.distributed.get_world_size()
else:
return 1


def is_distributed():
return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
if is_distributed():
return torch.distributed.all_reduce(tensor, op)


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
"""Average a dictionary of metrics across all workers, using the optional
`count` as unnormalized weight.
"""
if not is_distributed():
return metrics
keys, values = zip(*metrics.items())
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
tensor *= count
all_reduce(tensor)
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
return dict(zip(keys, averaged))


def averager(beta: float = 1):
"""
Exponential Moving Average callback.
    Returns a single function that can be called to repeatedly update the EMA
with a dict of metrics. The callback will return
the new averaged dict of metrics.

Note that for `beta=1`, this is just plain averaging.
"""
fix: tp.Dict[str, float] = defaultdict(float)
total: tp.Dict[str, float] = defaultdict(float)

def _update(metrics: tp.Dict[str, tp.Any], weight: float = 1) -> tp.Dict[str, float]:
nonlocal total, fix
for key, value in metrics.items():
total[key] = total[key] * beta + weight * float(value)
fix[key] = fix[key] * beta + weight
return {key: tot / fix[key] for key, tot in total.items()}

return _update


def compute_discriminator_loss(
real_logits: List[torch.Tensor], fake_logits: List[torch.Tensor], num_discriminators: int
) -> torch.Tensor:
"""
Compute the discriminator loss based on real and fake logits.

Args:
real_logits (List[torch.Tensor]): List of real logits from discriminators.
fake_logits (List[torch.Tensor]): List of fake logits from discriminators.
num_discriminators (int): Number of discriminators.

Returns:
torch.Tensor: The computed discriminator loss.
"""
loss = 0.0
for real_logit, fake_logit in zip(real_logits, fake_logits):
loss += torch.mean(F.relu(1 - real_logit)) + torch.mean(F.relu(1 + fake_logit))
return loss / num_discriminators


def compute_generator_adv_loss(fake_logits: List[torch.Tensor], num_discriminators: int) -> torch.Tensor:
"""
Compute the generator adversarial loss using fake logits.

Args:
fake_logits (List[torch.Tensor]): List of fake logits from discriminators.
num_discriminators (int): Number of discriminators.

Returns:
torch.Tensor: The computed generator adversarial loss.
"""
loss = 0.0
for fake_logit in fake_logits:
loss += torch.mean(F.relu(1 - fake_logit))
return loss / num_discriminators


def compute_feature_matching_loss(
real_features: List[List[torch.Tensor]], fake_features: List[List[torch.Tensor]], num_discriminators: int
):
"""
Compute the feature matching loss between real and fake features.

Args:
real_features (List[List[torch.Tensor]]): List of lists containing real features from each discriminator.
fake_features (List[List[torch.Tensor]]): List of lists containing fake features from each discriminator.
num_discriminators (int): Number of discriminators.

Returns:
torch.Tensor: The computed feature matching loss.
"""
fm_loss = 0
for k in range(num_discriminators):
for real_feat, fake_feat in zip(real_features[k], fake_features[k]):
fm_loss += F.l1_loss(fake_feat, real_feat.detach()) / torch.mean(torch.abs(real_feat.detach()))
fm_loss /= num_discriminators * len(real_features[0])
return fm_loss
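
To tie the helpers together, here is a hypothetical single training step in the spirit of EnCodec's adversarial setup. The discriminator interface (per-scale logits plus feature lists) and the balancer weights are assumptions, not confirmed by this diff; a toy stand-in is used in place of the PR's `EncodecDiscriminator`.

import torch
from transformers import EncodecModel
from transformers.models.encodec.loss_encodec import (
    Balancer,
    compute_discriminator_loss,
    compute_feature_matching_loss,
    compute_generator_adv_loss,
)


class ToyDiscriminator(torch.nn.Module):
    # Stand-in with the interface assumed here: forward(audio) returns
    # (list of logits, list of per-discriminator feature lists).
    # The PR's EncodecDiscriminator may differ.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(1, 8, kernel_size=15, stride=4, padding=7)
        self.head = torch.nn.Conv1d(8, 1, kernel_size=3, padding=1)

    def forward(self, x):
        features = torch.tanh(self.conv(x))
        logits = self.head(features)
        return [logits], [[features]]


model = EncodecModel.from_pretrained("facebook/encodec_24khz")
discriminator = ToyDiscriminator()
balancer = Balancer(weights={"adv": 1.0, "feat": 1.0})  # illustrative weights

real_audio = torch.randn(2, 1, 24000)
fake_audio = model(real_audio).audio_values

# Discriminator step: hinge loss on real vs. detached reconstructions.
real_logits, real_features = discriminator(real_audio)
fake_logits, _ = discriminator(fake_audio.detach())
d_loss = compute_discriminator_loss(real_logits, fake_logits, len(real_logits))
d_loss.backward()

# Generator step: adversarial + feature-matching terms; the balancer rescales
# each loss's gradient at the generated waveform before propagating it back
# into the generator.
fake_logits, fake_features = discriminator(fake_audio)
losses = {
    "adv": compute_generator_adv_loss(fake_logits, len(fake_logits)),
    "feat": compute_feature_matching_loss(real_features, fake_features, len(fake_logits)),
}
balancer.backward(losses, fake_audio)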