Spaces:
Runtime error
Runtime error
Commit
·
1ff07dc
1
Parent(s):
d8d3e30
Add new loss to umx
Browse files- remfx/models.py +2 -2
remfx/models.py
CHANGED
@@ -237,7 +237,7 @@ class OpenUnmixModel(torch.nn.Module):
|
|
237 |
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
238 |
Y = self.model(X)
|
239 |
sep_out = self.separator(x).squeeze(1)
|
240 |
-
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
|
241 |
|
242 |
return loss, sep_out
|
243 |
|
@@ -258,7 +258,7 @@ class DemucsModel(torch.nn.Module):
|
|
258 |
def forward(self, batch):
|
259 |
x, target, label = batch
|
260 |
output = self.model(x).squeeze(1)
|
261 |
-
loss = self.mrstftloss(output, target) + self.l1loss(output, target)
|
262 |
return loss, output
|
263 |
|
264 |
def sample(self, x: Tensor) -> Tensor:
|
|
|
237 |
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
238 |
Y = self.model(X)
|
239 |
sep_out = self.separator(x).squeeze(1)
|
240 |
+
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
|
241 |
|
242 |
return loss, sep_out
|
243 |
|
|
|
258 |
def forward(self, batch):
|
259 |
x, target, label = batch
|
260 |
output = self.model(x).squeeze(1)
|
261 |
+
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
262 |
return loss, output
|
263 |
|
264 |
def sample(self, x: Tensor) -> Tensor:
|