mattricesound commited on
Commit
7e4b346
·
1 Parent(s): c9ec30f

Remove umx

Browse files
Files changed (1) hide show
  1. remfx/models.py +0 -49
remfx/models.py CHANGED
@@ -6,7 +6,6 @@ from torch import Tensor, nn
6
  from torchaudio.models import HDemucs
7
  from auraloss.time import SISDRLoss
8
  from auraloss.freq import MultiResolutionSTFTLoss
9
- from umx.openunmix.model import OpenUnmix, Separator
10
 
11
  from remfx.utils import spectrogram
12
  from remfx.tcn import TCN
@@ -256,54 +255,6 @@ class RemFX(pl.LightningModule):
256
  return loss
257
 
258
 
259
- class OpenUnmixModel(nn.Module):
260
- def __init__(
261
- self,
262
- n_fft: int = 2048,
263
- hop_length: int = 512,
264
- n_channels: int = 1,
265
- alpha: float = 0.3,
266
- sample_rate: int = 22050,
267
- ):
268
- super().__init__()
269
- self.n_channels = n_channels
270
- self.n_fft = n_fft
271
- self.hop_length = hop_length
272
- self.alpha = alpha
273
- window = torch.hann_window(n_fft)
274
- self.register_buffer("window", window)
275
-
276
- self.num_bins = self.n_fft // 2 + 1
277
- self.sample_rate = sample_rate
278
- self.model = OpenUnmix(
279
- nb_channels=self.n_channels,
280
- nb_bins=self.num_bins,
281
- )
282
- self.separator = Separator(
283
- target_models={"other": self.model},
284
- nb_channels=self.n_channels,
285
- sample_rate=self.sample_rate,
286
- n_fft=self.n_fft,
287
- n_hop=self.hop_length,
288
- )
289
- self.mrstftloss = MultiResolutionSTFTLoss(
290
- n_bins=self.num_bins, sample_rate=self.sample_rate
291
- )
292
- self.l1loss = nn.L1Loss()
293
-
294
- def forward(self, batch):
295
- x, target = batch
296
- X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
297
- Y = self.model(X)
298
- sep_out = self.separator(x).squeeze(1)
299
- loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
300
-
301
- return loss, sep_out
302
-
303
- def sample(self, x: Tensor) -> Tensor:
304
- return self.separator(x).squeeze(1)
305
-
306
-
307
  class DemucsModel(nn.Module):
308
  def __init__(self, sample_rate, **kwargs) -> None:
309
  super().__init__()
 
6
  from torchaudio.models import HDemucs
7
  from auraloss.time import SISDRLoss
8
  from auraloss.freq import MultiResolutionSTFTLoss
 
9
 
10
  from remfx.utils import spectrogram
11
  from remfx.tcn import TCN
 
255
  return loss
256
 
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  class DemucsModel(nn.Module):
259
  def __init__(self, sample_rate, **kwargs) -> None:
260
  super().__init__()