Commit e0a5f6f by mattricesound
Parent: 1b89540

Add vocalset and FAD
Files changed (6)
  1. config.yaml +10 -4
  2. config_guitarset.yaml +52 -0
  3. exp/umx.yaml +8 -1
  4. remfx/datasets.py +109 -0
  5. remfx/models.py +33 -1
  6. shell_vars.sh +1 -1
config.yaml CHANGED
@@ -19,13 +19,19 @@ callbacks:
     filename: '{epoch:02d}-{valid_loss:.3f}'
 
 datamodule:
-  _target_: remfx.datasets.Datamodule
-  dataset:
-    _target_: remfx.datasets.GuitarSet
+  _target_: remfx.datasets.VocalSetDatamodule
+  train_dataset:
+    _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     chunk_size_in_sec: 6
-  val_split: 0.2
+    mode: "train"
+  val_dataset:
+    _target_: remfx.datasets.VocalSet
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+    mode: "val"
   batch_size: 16
   num_workers: 8
   pin_memory: True
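
Since every block here carries a Hydra `_target_`, the configured objects can be built with `hydra.utils.instantiate`. A minimal sketch using Hydra's compose API (the relative `config_path` is illustrative and Hydra >= 1.2 is assumed); swapping `config_name` to "config_guitarset" loads the preserved GuitarSet setup below instead:

# Minimal sketch; assumes Hydra >= 1.2 and that DATASET_ROOT is exported.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base=None, config_path="."):
    cfg = compose(config_name="config")  # or "config_guitarset"
# Recursively builds VocalSetDatamodule and both VocalSet splits
datamodule = instantiate(cfg.datamodule)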
config_guitarset.yaml ADDED
@@ -0,0 +1,52 @@
+defaults:
+  - _self_
+  - exp: null
+seed: 12345
+train: True
+sample_rate: 48000
+logs_dir: "./logs"
+log_every_n_steps: 1000
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    monitor: "valid_loss"  # name of the logged metric which determines when model is improving
+    save_top_k: 1  # save k best models (determined by above metric)
+    save_last: True  # additionally always save model from last epoch
+    mode: "min"  # can be "max" or "min"
+    verbose: False
+    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+    filename: '{epoch:02d}-{valid_loss:.3f}'
+
+datamodule:
+  _target_: remfx.datasets.Datamodule
+  dataset:
+    _target_: remfx.datasets.GuitarSet
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+  val_split: 0.2
+  batch_size: 16
+  num_workers: 8
+  pin_memory: True
+  persistent_workers: True
+
+logger:
+  _target_: pytorch_lightning.loggers.WandbLogger
+  project: ${oc.env:WANDB_PROJECT}
+  entity: ${oc.env:WANDB_ENTITY}
+  # offline: False  # set True to store all logs only locally
+  job_type: "train"
+  group: ""
+  save_dir: "."
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  precision: 32  # precision used for tensors, default `32`
+  min_epochs: 0
+  max_epochs: -1
+  enable_model_summary: False
+  log_every_n_steps: 1  # logs metrics every N batches
+  accumulate_grad_batches: 1
+  accelerator: null
+  devices: 1
exp/umx.yaml CHANGED
@@ -15,7 +15,14 @@ model:
     alpha: 0.3
     sample_rate: ${sample_rate}
 datamodule:
-  dataset:
+  train_dataset:
+    effect_types:
+      Distortion:
+        _target_: remfx.effects.RandomPedalboardDistortion
+        sample_rate: ${sample_rate}
+        min_drive_db: -10
+        max_drive_db: 50
+  val_dataset:
     effect_types:
       Distortion:
         _target_: remfx.effects.RandomPedalboardDistortion
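
Each entry under `effect_types` is itself a Hydra target; as the `VocalSet` dataset below shows, the instantiated effect is a `torch.nn.Module` called directly on a waveform tensor. A minimal sketch of direct usage, assuming the constructor signature matches this config:

# Hypothetical direct usage; parameters mirror the train_dataset config above.
import torch
from remfx.effects import RandomPedalboardDistortion

distortion = RandomPedalboardDistortion(
    sample_rate=48000, min_drive_db=-10, max_drive_db=50
)
x = torch.randn(1, 48000 * 6)  # (channels, samples): a 6-second mono chunk
effected = distortion(x)  # drive presumably drawn at random within [min, max] per call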
remfx/datasets.py CHANGED
@@ -20,6 +20,7 @@ from pedalboard import (
 
 # https://zenodo.org/record/7044411/ -> GuitarFX
 # https://zenodo.org/record/3371780 -> GuitarSet
+# https://zenodo.org/record/1193957 -> VocalSet
 
 deterministic_effects = {
     "Distortion": Pedalboard([Distortion()]),
@@ -173,6 +174,74 @@ class GuitarSet(Dataset):
         return (normalized_input, normalized_target, effect_name)
 
 
+class VocalSet(Dataset):
+    def __init__(
+        self,
+        root: str,
+        sample_rate: int,
+        chunk_size_in_sec: int = 3,
+        effect_types: List[torch.nn.Module] = None,  # mapping of effect name -> module
+        mode: str = "train",
+    ):
+        super().__init__()
+        self.chunks = []
+        self.song_idx = []
+        self.root = Path(root)
+        self.chunk_size_in_sec = chunk_size_in_sec
+        self.sample_rate = sample_rate
+        self.mode = mode
+
+        # Index every chunk of every wav file under <root>/<mode>
+        mode_path = self.root / self.mode
+        self.files = sorted(list(mode_path.glob("./**/*.wav")))
+        for i, audio_file in enumerate(self.files):
+            chunk_starts, orig_sr = create_sequential_chunks(
+                audio_file, self.chunk_size_in_sec
+            )
+            self.chunks += chunk_starts
+            self.song_idx += [i] * len(chunk_starts)
+        print(f"Found {len(self.files)} files.\nTotal chunks: {len(self.chunks)}")
+        # NB: built from the last scanned file's sample rate; all files are
+        # assumed to share it.
+        self.resampler = T.Resample(orig_sr, sample_rate)
+        self.effect_types = effect_types
+        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+
+    def __len__(self):
+        return len(self.chunks)
+
+    def __getitem__(self, idx):
+        # Load the parent file and slice out this chunk
+        song_idx = self.song_idx[idx]
+        x, sr = torchaudio.load(self.files[song_idx])
+        chunk_start = self.chunks[idx]
+        chunk_size_in_samples = self.chunk_size_in_sec * sr
+        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
+        resampled_x = self.resampler(x)
+        # Recompute chunk size at the target sample rate
+        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+        # Pad to chunk_size if needed
+        if resampled_x.shape[-1] < chunk_size_in_samples:
+            resampled_x = F.pad(
+                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
+            )
+
+        # Apply a randomly chosen effect during training
+        if self.mode == "train":
+            random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
+            effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
+            effect = self.effect_types[effect_name]
+            effected_input = effect(resampled_x)
+        else:
+            # Deterministic static effect for eval
+            effect_idx = idx % len(self.effect_types.keys())
+            effect_name = list(self.effect_types.keys())[effect_idx]
+            effect = deterministic_effects[effect_name]
+            effected_input = torch.from_numpy(
+                effect(resampled_x.numpy(), self.sample_rate)
+            )
+        normalized_input = self.normalize(effected_input)
+        normalized_target = self.normalize(resampled_x)
+        return (normalized_input, normalized_target, effect_name)
+
+
 def create_random_chunks(
     audio_file: str, chunk_size: int, num_chunks: int
 ) -> Tuple[List[Tuple[int, int]], int]:
@@ -249,3 +318,43 @@ class Datamodule(pl.LightningDataModule):
             pin_memory=self.pin_memory,
             shuffle=False,
         )
+
+
+class VocalSetDatamodule(pl.LightningDataModule):
+    def __init__(
+        self,
+        train_dataset,
+        val_dataset,
+        *,
+        batch_size: int,
+        num_workers: int,
+        pin_memory: bool = False,
+        **kwargs: int,
+    ) -> None:
+        super().__init__()
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+
+    def setup(self, stage: Any = None) -> None:
+        pass
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.train_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=True,
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=False,
+        )
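
Wired together without Hydra, the new pieces compose as follows; a minimal sketch, assuming ./data/VocalSet/train and ./data/VocalSet/val exist and that effect_types maps effect names to modules as __getitem__ expects (parameters mirror the configs above):

# Illustrative wiring; paths and effect parameters are assumptions.
from remfx.datasets import VocalSet, VocalSetDatamodule
from remfx.effects import RandomPedalboardDistortion

effect_types = {
    "Distortion": RandomPedalboardDistortion(
        sample_rate=48000, min_drive_db=-10, max_drive_db=50
    )
}
train_ds = VocalSet("./data/VocalSet", 48000, chunk_size_in_sec=6,
                    effect_types=effect_types, mode="train")
val_ds = VocalSet("./data/VocalSet", 48000, chunk_size_in_sec=6,
                  effect_types=effect_types, mode="val")
dm = VocalSetDatamodule(train_ds, val_ds, batch_size=16, num_workers=8,
                        pin_memory=True)
x, y, effect_name = train_ds[0]  # (effected input, clean target, effect name)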
remfx/models.py CHANGED
@@ -7,11 +7,43 @@ from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
 from torch.nn import L1Loss
+from frechet_audio_distance import FrechetAudioDistance
+import numpy as np
 
 from umx.openunmix.model import OpenUnmix, Separator
 from torchaudio.models import HDemucs
 
 
+class FADLoss(torch.nn.Module):
+    def __init__(self, sample_rate: float):
+        super().__init__()
+        self.fad = FrechetAudioDistance(
+            use_pca=False, use_activation=False, verbose=False
+        )
+        self.sr = sample_rate
+
+    def forward(self, audio_background, audio_eval):
+        # Embed every sample in both sets with the FAD embedding model
+        embds_background = []
+        embds_eval = []
+        for sample in audio_background:
+            embd = self.fad.model.forward(sample.T.detach().numpy(), self.sr)
+            embds_background.append(embd.cpu().detach().numpy())
+        for sample in audio_eval:
+            embd = self.fad.model.forward(sample.T.detach().numpy(), self.sr)
+            embds_eval.append(embd.cpu().detach().numpy())
+        embds_background = np.concatenate(embds_background, axis=0)
+        embds_eval = np.concatenate(embds_eval, axis=0)
+        # Fit a Gaussian to each embedding set, then compare the two
+        mu_background, sigma_background = self.fad.calculate_embd_statistics(
+            embds_background
+        )
+        mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
+
+        fad_score = self.fad.calculate_frechet_distance(
+            mu_background, sigma_background, mu_eval, sigma_eval
+        )
+        return fad_score
+
+
 class RemFXModel(pl.LightningModule):
     def __init__(
         self,
@@ -35,7 +67,7 @@ class RemFXModel(pl.LightningModule):
             {
                 "SISDR": SISDRLoss(),
                 "STFT": STFTLoss(),
-                "L1": L1Loss(),
+                "FAD": FADLoss(sample_rate=sample_rate),
             }
         )
         # Log first batch metrics input vs output only once
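
For reference, the value returned by calculate_frechet_distance is the Fréchet distance between the two Gaussians fitted to the embedding sets, with lower scores indicating closer distributions:

\mathrm{FAD} = \lVert \mu_b - \mu_e \rVert_2^2 + \operatorname{Tr}\left( \Sigma_b + \Sigma_e - 2\,(\Sigma_b \Sigma_e)^{1/2} \right)

where (\mu_b, \Sigma_b) and (\mu_e, \Sigma_e) are the mean and covariance of the background and evaluation embeddings. FADLoss therefore measures a distance between batch-level distributions rather than a per-sample loss.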
shell_vars.sh CHANGED
@@ -1,3 +1,3 @@
-export DATASET_ROOT="./data/GuitarSet"
+export DATASET_ROOT="./data/VocalSet"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"
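
These exports feed the ${oc.env:...} interpolations in the configs above, so shell_vars.sh is presumably sourced (e.g. `source shell_vars.sh`) before launching training.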