Spaces:
Sleeping
Sleeping
Commit · 1e53a02
1
Parent(s): 8125531
Add L1Loss to loss
Browse files
- remfx/models.py +9 -8
remfx/models.py
CHANGED
@@ -5,8 +5,7 @@ from einops import rearrange
|
|
5 |
import wandb
|
6 |
from audio_diffusion_pytorch import DiffusionModel
|
7 |
from auraloss.time import SISDRLoss
|
8 |
-
from auraloss.freq import MultiResolutionSTFTLoss
|
9 |
-
from torch.nn import L1Loss
|
10 |
from remfx.utils import FADLoss
|
11 |
|
12 |
from umx.openunmix.model import OpenUnmix, Separator
|
@@ -35,7 +34,7 @@ class RemFXModel(pl.LightningModule):
|
|
35 |
self.metrics = torch.nn.ModuleDict(
|
36 |
{
|
37 |
"SISDR": SISDRLoss(),
|
38 |
-
"STFT":
|
39 |
"FAD": FADLoss(sample_rate=sample_rate),
|
40 |
}
|
41 |
)
|
@@ -189,16 +188,17 @@ class OpenUnmixModel(torch.nn.Module):
|
|
189 |
n_fft=self.n_fft,
|
190 |
n_hop=self.hop_length,
|
191 |
)
|
192 |
-
self.
|
193 |
n_bins=self.num_bins, sample_rate=self.sample_rate
|
194 |
)
|
|
|
195 |
|
196 |
def forward(self, batch):
|
197 |
x, target, label = batch
|
198 |
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
199 |
Y = self.model(X)
|
200 |
sep_out = self.separator(x).squeeze(1)
|
201 |
-
loss = self.
|
202 |
|
203 |
return loss, sep_out
|
204 |
|
@@ -211,14 +211,15 @@ class DemucsModel(torch.nn.Module):
|
|
211 |
super().__init__()
|
212 |
self.model = HDemucs(**kwargs)
|
213 |
self.num_bins = kwargs["nfft"] // 2 + 1
|
214 |
-
self.
|
215 |
-
n_bins=self.num_bins, sample_rate=sample_rate
|
216 |
)
|
|
|
217 |
|
218 |
def forward(self, batch):
|
219 |
x, target, label = batch
|
220 |
output = self.model(x).squeeze(1)
|
221 |
-
loss = self.
|
222 |
return loss, output
|
223 |
|
224 |
def sample(self, x: Tensor) -> Tensor:
|
|
|
5 |
import wandb
|
6 |
from audio_diffusion_pytorch import DiffusionModel
|
7 |
from auraloss.time import SISDRLoss
|
8 |
+
from auraloss.freq import MultiResolutionSTFTLoss
|
|
|
9 |
from remfx.utils import FADLoss
|
10 |
|
11 |
from umx.openunmix.model import OpenUnmix, Separator
|
|
|
34 |
self.metrics = torch.nn.ModuleDict(
|
35 |
{
|
36 |
"SISDR": SISDRLoss(),
|
37 |
+
"STFT": MultiResolutionSTFTLoss(),
|
38 |
"FAD": FADLoss(sample_rate=sample_rate),
|
39 |
}
|
40 |
)
|
|
|
188 |
n_fft=self.n_fft,
|
189 |
n_hop=self.hop_length,
|
190 |
)
|
191 |
+
self.mrstftloss = MultiResolutionSTFTLoss(
|
192 |
n_bins=self.num_bins, sample_rate=self.sample_rate
|
193 |
)
|
194 |
+
self.l1loss = torch.nn.L1Loss()
|
195 |
|
196 |
def forward(self, batch):
|
197 |
x, target, label = batch
|
198 |
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
199 |
Y = self.model(X)
|
200 |
sep_out = self.separator(x).squeeze(1)
|
201 |
+
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
|
202 |
|
203 |
return loss, sep_out
|
204 |
|
|
|
211 |
super().__init__()
|
212 |
self.model = HDemucs(**kwargs)
|
213 |
self.num_bins = kwargs["nfft"] // 2 + 1
|
214 |
+
self.mrstftloss = MultiResolutionSTFTLoss(
|
215 |
+
n_bins=self.num_bins, sample_rate=self.sample_rate
|
216 |
)
|
217 |
+
self.l1loss = torch.nn.L1Loss()
|
218 |
|
219 |
def forward(self, batch):
|
220 |
x, target, label = batch
|
221 |
output = self.model(x).squeeze(1)
|
222 |
+
loss = self.mrstftloss(output, target) + self.l1loss(output, target)
|
223 |
return loss, output
|
224 |
|
225 |
def sample(self, x: Tensor) -> Tensor:
|