mattricesound commited on
Commit
1e53a02
·
1 Parent(s): 8125531

Add L1Loss to loss

Browse files
Files changed (1) hide show
  1. remfx/models.py +9 -8
remfx/models.py CHANGED
@@ -5,8 +5,7 @@ from einops import rearrange
5
  import wandb
6
  from audio_diffusion_pytorch import DiffusionModel
7
  from auraloss.time import SISDRLoss
8
- from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
9
- from torch.nn import L1Loss
10
  from remfx.utils import FADLoss
11
 
12
  from umx.openunmix.model import OpenUnmix, Separator
@@ -35,7 +34,7 @@ class RemFXModel(pl.LightningModule):
35
  self.metrics = torch.nn.ModuleDict(
36
  {
37
  "SISDR": SISDRLoss(),
38
- "STFT": STFTLoss(),
39
  "FAD": FADLoss(sample_rate=sample_rate),
40
  }
41
  )
@@ -189,16 +188,17 @@ class OpenUnmixModel(torch.nn.Module):
189
  n_fft=self.n_fft,
190
  n_hop=self.hop_length,
191
  )
192
- self.loss_fn = MultiResolutionSTFTLoss(
193
  n_bins=self.num_bins, sample_rate=self.sample_rate
194
  )
 
195
 
196
  def forward(self, batch):
197
  x, target, label = batch
198
  X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
199
  Y = self.model(X)
200
  sep_out = self.separator(x).squeeze(1)
201
- loss = self.loss_fn(sep_out, target)
202
 
203
  return loss, sep_out
204
 
@@ -211,14 +211,15 @@ class DemucsModel(torch.nn.Module):
211
  super().__init__()
212
  self.model = HDemucs(**kwargs)
213
  self.num_bins = kwargs["nfft"] // 2 + 1
214
- self.loss_fn = MultiResolutionSTFTLoss(
215
- n_bins=self.num_bins, sample_rate=sample_rate
216
  )
 
217
 
218
  def forward(self, batch):
219
  x, target, label = batch
220
  output = self.model(x).squeeze(1)
221
- loss = self.loss_fn(output, target)
222
  return loss, output
223
 
224
  def sample(self, x: Tensor) -> Tensor:
 
5
  import wandb
6
  from audio_diffusion_pytorch import DiffusionModel
7
  from auraloss.time import SISDRLoss
8
+ from auraloss.freq import MultiResolutionSTFTLoss
 
9
  from remfx.utils import FADLoss
10
 
11
  from umx.openunmix.model import OpenUnmix, Separator
 
34
  self.metrics = torch.nn.ModuleDict(
35
  {
36
  "SISDR": SISDRLoss(),
37
+ "STFT": MultiResolutionSTFTLoss(),
38
  "FAD": FADLoss(sample_rate=sample_rate),
39
  }
40
  )
 
188
  n_fft=self.n_fft,
189
  n_hop=self.hop_length,
190
  )
191
+ self.mrstftloss = MultiResolutionSTFTLoss(
192
  n_bins=self.num_bins, sample_rate=self.sample_rate
193
  )
194
+ self.l1loss = torch.nn.L1Loss()
195
 
196
  def forward(self, batch):
197
  x, target, label = batch
198
  X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
199
  Y = self.model(X)
200
  sep_out = self.separator(x).squeeze(1)
201
+ loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
202
 
203
  return loss, sep_out
204
 
 
211
  super().__init__()
212
  self.model = HDemucs(**kwargs)
213
  self.num_bins = kwargs["nfft"] // 2 + 1
214
+ self.mrstftloss = MultiResolutionSTFTLoss(
215
+ n_bins=self.num_bins, sample_rate=self.sample_rate
216
  )
217
+ self.l1loss = torch.nn.L1Loss()
218
 
219
  def forward(self, batch):
220
  x, target, label = batch
221
  output = self.model(x).squeeze(1)
222
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target)
223
  return loss, output
224
 
225
  def sample(self, x: Tensor) -> Tensor: