Spaces:
Runtime error
Runtime error
Commit
·
7e4b346
1
Parent(s):
c9ec30f
Remove umx
Browse files- remfx/models.py +0 -49
remfx/models.py
CHANGED
@@ -6,7 +6,6 @@ from torch import Tensor, nn
|
|
6 |
from torchaudio.models import HDemucs
|
7 |
from auraloss.time import SISDRLoss
|
8 |
from auraloss.freq import MultiResolutionSTFTLoss
|
9 |
-
from umx.openunmix.model import OpenUnmix, Separator
|
10 |
|
11 |
from remfx.utils import spectrogram
|
12 |
from remfx.tcn import TCN
|
@@ -256,54 +255,6 @@ class RemFX(pl.LightningModule):
|
|
256 |
return loss
|
257 |
|
258 |
|
259 |
-
class OpenUnmixModel(nn.Module):
|
260 |
-
def __init__(
|
261 |
-
self,
|
262 |
-
n_fft: int = 2048,
|
263 |
-
hop_length: int = 512,
|
264 |
-
n_channels: int = 1,
|
265 |
-
alpha: float = 0.3,
|
266 |
-
sample_rate: int = 22050,
|
267 |
-
):
|
268 |
-
super().__init__()
|
269 |
-
self.n_channels = n_channels
|
270 |
-
self.n_fft = n_fft
|
271 |
-
self.hop_length = hop_length
|
272 |
-
self.alpha = alpha
|
273 |
-
window = torch.hann_window(n_fft)
|
274 |
-
self.register_buffer("window", window)
|
275 |
-
|
276 |
-
self.num_bins = self.n_fft // 2 + 1
|
277 |
-
self.sample_rate = sample_rate
|
278 |
-
self.model = OpenUnmix(
|
279 |
-
nb_channels=self.n_channels,
|
280 |
-
nb_bins=self.num_bins,
|
281 |
-
)
|
282 |
-
self.separator = Separator(
|
283 |
-
target_models={"other": self.model},
|
284 |
-
nb_channels=self.n_channels,
|
285 |
-
sample_rate=self.sample_rate,
|
286 |
-
n_fft=self.n_fft,
|
287 |
-
n_hop=self.hop_length,
|
288 |
-
)
|
289 |
-
self.mrstftloss = MultiResolutionSTFTLoss(
|
290 |
-
n_bins=self.num_bins, sample_rate=self.sample_rate
|
291 |
-
)
|
292 |
-
self.l1loss = nn.L1Loss()
|
293 |
-
|
294 |
-
def forward(self, batch):
|
295 |
-
x, target = batch
|
296 |
-
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
297 |
-
Y = self.model(X)
|
298 |
-
sep_out = self.separator(x).squeeze(1)
|
299 |
-
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
|
300 |
-
|
301 |
-
return loss, sep_out
|
302 |
-
|
303 |
-
def sample(self, x: Tensor) -> Tensor:
|
304 |
-
return self.separator(x).squeeze(1)
|
305 |
-
|
306 |
-
|
307 |
class DemucsModel(nn.Module):
|
308 |
def __init__(self, sample_rate, **kwargs) -> None:
|
309 |
super().__init__()
|
|
|
6 |
from torchaudio.models import HDemucs
|
7 |
from auraloss.time import SISDRLoss
|
8 |
from auraloss.freq import MultiResolutionSTFTLoss
|
|
|
9 |
|
10 |
from remfx.utils import spectrogram
|
11 |
from remfx.tcn import TCN
|
|
|
255 |
return loss
|
256 |
|
257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
class DemucsModel(nn.Module):
|
259 |
def __init__(self, sample_rate, **kwargs) -> None:
|
260 |
super().__init__()
|