mattricesound commited on
Commit
1ff07dc
·
1 Parent(s): d8d3e30

Add new loss to umx

Browse files
Files changed (1) hide show
  1. remfx/models.py +2 -2
remfx/models.py CHANGED
@@ -237,7 +237,7 @@ class OpenUnmixModel(torch.nn.Module):
237
  X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
238
  Y = self.model(X)
239
  sep_out = self.separator(x).squeeze(1)
240
- loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
241
 
242
  return loss, sep_out
243
 
@@ -258,7 +258,7 @@ class DemucsModel(torch.nn.Module):
258
  def forward(self, batch):
259
  x, target, label = batch
260
  output = self.model(x).squeeze(1)
261
- loss = self.mrstftloss(output, target) + self.l1loss(output, target)
262
  return loss, output
263
 
264
  def sample(self, x: Tensor) -> Tensor:
 
237
  X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
238
  Y = self.model(X)
239
  sep_out = self.separator(x).squeeze(1)
240
+ loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
241
 
242
  return loss, sep_out
243
 
 
258
  def forward(self, batch):
259
  x, target, label = batch
260
  output = self.model(x).squeeze(1)
261
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
262
  return loss, output
263
 
264
  def sample(self, x: Tensor) -> Tensor: