mattricesound commited on
Commit
1dd1464
·
1 Parent(s): e3f4ef0

Add demucs model

Browse files
Files changed (4) hide show
  1. README.md +17 -0
  2. exp/demucs.yaml +19 -0
  3. remfx/models.py +20 -0
  4. scripts/train.py +3 -0
README.md CHANGED
@@ -19,3 +19,20 @@ To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-li
19
 
20
  Ex. `python train.py exp=umx trainer.accelerator='gpu' trainer.devices=-1`
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  Ex. `python train.py exp=umx trainer.accelerator='gpu' trainer.devices=-1`
21
 
22
+ ### Effects
23
+ Default effect is RAT (distortion). Effect choices:
24
+ - BluesDriver
25
+ - Clean
26
+ - Flanger
27
+ - Phaser
28
+ - RAT
29
+ - Sweep Echo
30
+ - TubeScreamer
31
+ - Chorus
32
+ - Digital Delay
33
+ - Hall Reverb
34
+ - Plate Reverb
35
+ - Spring Reverb
36
+ - TapeEcho
37
+
38
+ Change effect by adding `+datamodule.dataset.effect_types=["{Effect}"]` to the command-line
exp/demucs.yaml CHANGED
@@ -1 +1,20 @@
1
  # @package _global_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # @package _global_
2
+ model:
3
+ _target_: remfx.models.RemFXModel
4
+ lr: 1e-4
5
+ lr_beta1: 0.95
6
+ lr_beta2: 0.999
7
+ lr_eps: 1e-6
8
+ lr_weight_decay: 1e-3
9
+ sample_rate: ${sample_rate}
10
+ network:
11
+ _target_: remfx.models.DemucsModel
12
+ sources: ["other"]
13
+ audio_channels: 1
14
+ nfft: 4096
15
+ sample_rate: ${sample_rate}
16
+
17
+
18
+ datamodule:
19
+ dataset:
20
+ effect_types: ["RAT"]
remfx/models.py CHANGED
@@ -9,6 +9,7 @@ from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
9
  from torch.nn import L1Loss
10
 
11
  from umx.openunmix.model import OpenUnmix, Separator
 
12
 
13
 
14
  class RemFXModel(pl.LightningModule):
@@ -148,6 +149,25 @@ class OpenUnmixModel(torch.nn.Module):
148
  return self.separator(x).squeeze(1)
149
 
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  class DiffusionGenerationModel(nn.Module):
152
  def __init__(self, n_channels: int = 1):
153
  super().__init__()
 
9
  from torch.nn import L1Loss
10
 
11
  from umx.openunmix.model import OpenUnmix, Separator
12
+ from torchaudio.models import HDemucs
13
 
14
 
15
  class RemFXModel(pl.LightningModule):
 
149
  return self.separator(x).squeeze(1)
150
 
151
 
152
+ class DemucsModel(torch.nn.Module):
153
+ def __init__(self, sample_rate, **kwargs) -> None:
154
+ super().__init__()
155
+ self.model = HDemucs(**kwargs)
156
+ self.num_bins = kwargs["nfft"] // 2 + 1
157
+ self.loss_fn = MultiResolutionSTFTLoss(
158
+ n_bins=self.num_bins, sample_rate=sample_rate
159
+ )
160
+
161
+ def forward(self, batch):
162
+ x, target, label = batch
163
+ output = self.model(x).squeeze(1)
164
+ loss = self.loss_fn(output, target)
165
+ return loss, output
166
+
167
+ def sample(self, x: Tensor) -> Tensor:
168
+ return self.model(x).squeeze(1)
169
+
170
+
171
  class DiffusionGenerationModel(nn.Module):
172
  def __init__(self, n_channels: int = 1):
173
  super().__init__()
scripts/train.py CHANGED
@@ -2,6 +2,7 @@ import pytorch_lightning as pl
2
  import hydra
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
 
5
 
6
  log = utils.get_logger(__name__)
7
 
@@ -39,6 +40,8 @@ def main(cfg: DictConfig):
39
  callbacks=callbacks,
40
  logger=logger,
41
  )
 
 
42
  trainer.fit(model=model, datamodule=datamodule)
43
 
44
 
 
2
  import hydra
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
5
+ from pytorch_lightning.utilities.model_summary import ModelSummary
6
 
7
  log = utils.get_logger(__name__)
8
 
 
40
  callbacks=callbacks,
41
  logger=logger,
42
  )
43
+ summary = ModelSummary(model)
44
+ print(summary)
45
  trainer.fit(model=model, datamodule=datamodule)
46
 
47