import torch
from torch import nn

from .. import AudioSignal


class L1Loss(nn.L1Loss):
| """L1 Loss between AudioSignals. Defaults | |
| to comparing ``audio_data``, but any | |
| attribute of an AudioSignal can be used. | |
| Parameters | |
| ---------- | |
| attribute : str, optional | |
| Attribute of signal to compare, defaults to ``audio_data``. | |
| weight : float, optional | |
| Weight of this loss, defaults to 1.0. | |
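
    Examples
    --------
    A minimal usage sketch, illustrative only. It assumes ``AudioSignal``
    accepts a ``(batch, channels, samples)`` tensor plus a sample rate, as
    elsewhere in this package:

    >>> import torch
    >>> reference = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
    >>> estimate = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
    >>> loss = L1Loss()
    >>> value = loss(estimate, reference)  # doctest: +SKIP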
| """ | |

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        self.attribute = attribute
        self.weight = weight
        super().__init__(**kwargs)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal
        y : AudioSignal
            Reference AudioSignal

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        if isinstance(x, AudioSignal):
            x = getattr(x, self.attribute)
            y = getattr(y, self.attribute)
        return super().forward(x, y)


class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use the scale-invariant SDR (True) or
        a plain signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch (either 'mean',
        'sum', or 'none'), by default 'mean'
    zero_mean : bool, optional
        Zero-mean the references and estimates before
        computing the loss, by default True
    clip_min : float, optional
        The minimum possible loss value. Helps the network
        avoid focusing on making already-good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.
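
    Examples
    --------
    A minimal usage sketch, illustrative only. Raw tensors are shown here
    since ``forward`` also accepts them, shaped ``(batch, channels,
    samples)``; note that the first argument is the reference:

    >>> import torch
    >>> loss = SISDRLoss()
    >>> references = torch.randn(4, 1, 16000)
    >>> estimates = references + 0.1 * torch.randn(4, 1, 16000)
    >>> value = loss(references, estimates)  # negative SI-SDR  # doctest: +SKIP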
| """ | |

    def __init__(
        self,
        scaling: bool = True,
        reduction: str = "mean",
        zero_mean: bool = True,
        clip_min: float = None,
        weight: float = 1.0,
    ):
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight
        super().__init__()
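
    # Per batch item, ``forward`` computes
    #   SI-SDR = 10 * log10(||e_true||^2 / ||e_res||^2),
    # where e_true is the reference scaled by the projection of the
    # estimate onto it, <estimate, reference> / ||reference||^2 (the scale
    # is fixed to 1 when ``scaling`` is False, giving a plain SNR), and
    # e_res = estimate - e_true is the residual. The sign is flipped so
    # that minimizing the loss maximizes SI-SDR.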
    def forward(self, x: AudioSignal, y: AudioSignal):
        eps = 1e-8
        # nb, nc, nt
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
        # samples now on axis 1
        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        sdr = -10 * torch.log10(signal / noise + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr