mattricesound commited on
Commit
848b108
·
2 Parent(s): 3d26e07 7f36717

Merge pull request #24 from mhrice/new-metrics

Browse files
README.md CHANGED
@@ -22,7 +22,7 @@ Models and effects detailed below.
22
 
23
  To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
24
 
25
- Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=-1`
26
 
27
  ### Current Models
28
  - `umx`
 
22
 
23
  To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
24
 
25
+ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=1`
26
 
27
  ### Current Models
28
  - `umx`
cfg/config.yaml CHANGED
@@ -6,8 +6,8 @@ defaults:
6
  seed: 12345
7
  train: True
8
  sample_rate: 48000
 
9
  logs_dir: "./logs"
10
- log_every_n_steps: 1000
11
  render_files: True
12
  render_root: "./data/processed"
13
 
@@ -21,6 +21,9 @@ callbacks:
21
  verbose: False
22
  dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
23
  filename: '{epoch:02d}-{valid_loss:.3f}'
 
 
 
24
 
25
  datamodule:
26
  _target_: remfx.datasets.VocalSetDatamodule
@@ -28,27 +31,27 @@ datamodule:
28
  _target_: remfx.datasets.VocalSet
29
  sample_rate: ${sample_rate}
30
  root: ${oc.env:DATASET_ROOT}
31
- chunk_size_in_sec: 6
32
  mode: "train"
33
- effect_types: ${effects.train_effects}
34
  render_files: ${render_files}
35
  render_root: ${render_root}
36
  val_dataset:
37
  _target_: remfx.datasets.VocalSet
38
  sample_rate: ${sample_rate}
39
  root: ${oc.env:DATASET_ROOT}
40
- chunk_size_in_sec: 6
41
  mode: "val"
42
- effect_types: ${effects.val_effects}
43
  render_files: ${render_files}
44
  render_root: ${render_root}
45
  test_dataset:
46
  _target_: remfx.datasets.VocalSet
47
  sample_rate: ${sample_rate}
48
  root: ${oc.env:DATASET_ROOT}
49
- chunk_size_in_sec: 6
50
  mode: "test"
51
- effect_types: ${effects.val_effects}
52
  render_files: ${render_files}
53
  render_root: ${render_root}
54
 
@@ -76,3 +79,5 @@ trainer:
76
  accumulate_grad_batches: 1
77
  accelerator: null
78
  devices: 1
 
 
 
6
  seed: 12345
7
  train: True
8
  sample_rate: 48000
9
+ chunk_size: 262144 # 5.5s
10
  logs_dir: "./logs"
 
11
  render_files: True
12
  render_root: "./data/processed"
13
 
 
21
  verbose: False
22
  dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
23
  filename: '{epoch:02d}-{valid_loss:.3f}'
24
+ learning_rate_monitor:
25
+ _target_: pytorch_lightning.callbacks.LearningRateMonitor
26
+ logging_interval: "step"
27
 
28
  datamodule:
29
  _target_: remfx.datasets.VocalSetDatamodule
 
31
  _target_: remfx.datasets.VocalSet
32
  sample_rate: ${sample_rate}
33
  root: ${oc.env:DATASET_ROOT}
34
+ chunk_size: ${chunk_size}
35
  mode: "train"
36
+ effect_types: ${effects}
37
  render_files: ${render_files}
38
  render_root: ${render_root}
39
  val_dataset:
40
  _target_: remfx.datasets.VocalSet
41
  sample_rate: ${sample_rate}
42
  root: ${oc.env:DATASET_ROOT}
43
+ chunk_size: ${chunk_size}
44
  mode: "val"
45
+ effect_types: ${effects}
46
  render_files: ${render_files}
47
  render_root: ${render_root}
48
  test_dataset:
49
  _target_: remfx.datasets.VocalSet
50
  sample_rate: ${sample_rate}
51
  root: ${oc.env:DATASET_ROOT}
52
+ chunk_size: ${chunk_size}
53
  mode: "test"
54
+ effect_types: ${effects}
55
  render_files: ${render_files}
56
  render_root: ${render_root}
57
 
 
79
  accumulate_grad_batches: 1
80
  accelerator: null
81
  devices: 1
82
+ gradient_clip_val: 10.0
83
+ max_steps: 50000
cfg/effects/all.yaml CHANGED
@@ -1,70 +1,31 @@
1
  # @package _global_
2
  effects:
3
- train_effects:
4
- Chorus:
5
- _target_: remfx.effects.RandomPedalboardChorus
6
- sample_rate: ${sample_rate}
7
- Distortion:
8
- _target_: remfx.effects.RandomPedalboardDistortion
9
- sample_rate: ${sample_rate}
10
- min_drive_db: -10
11
- max_drive_db: 50
12
- Compressor:
13
- _target_: remfx.effects.RandomPedalboardCompressor
14
- sample_rate: ${sample_rate}
15
- min_threshold_db: -42.0
16
- max_threshold_db: -20.0
17
- min_ratio: 1.5
18
- max_ratio: 6.0
19
- Reverb:
20
- _target_: remfx.effects.RandomPedalboardReverb
21
- sample_rate: ${sample_rate}
22
- min_room_size: 0.3
23
- max_room_size: 1.0
24
- min_damping: 0.2
25
- max_damping: 1.0
26
- min_wet_dry: 0.2
27
- max_wet_dry: 0.8
28
- min_width: 0.2
29
- max_width: 1.0
30
- val_effects:
31
- Chorus:
32
- _target_: remfx.effects.RandomPedalboardChorus
33
- sample_rate: ${sample_rate}
34
- min_rate_hz: 1.0
35
- max_rate_hz: 1.0
36
- min_depth: 0.3
37
- max_depth: 0.3
38
- min_centre_delay_ms: 7.5
39
- max_centre_delay_ms: 7.5
40
- min_feedback: 0.4
41
- max_feedback: 0.4
42
- min_mix: 0.4
43
- max_mix: 0.4
44
- Distortion:
45
- _target_: remfx.effects.RandomPedalboardDistortion
46
- sample_rate: ${sample_rate}
47
- min_drive_db: 30
48
- max_drive_db: 30
49
- Compressor:
50
- _target_: remfx.effects.RandomPedalboardCompressor
51
- sample_rate: ${sample_rate}
52
- min_threshold_db: -32
53
- max_threshold_db: -32
54
- min_ratio: 3.0
55
- max_ratio: 3.0
56
- min_attack_ms: 10.0
57
- max_attack_ms: 10.0
58
- min_release_ms: 40.0
59
- max_release_ms: 40.0
60
- Reverb:
61
- _target_: remfx.effects.RandomPedalboardReverb
62
- sample_rate: ${sample_rate}
63
- min_room_size: 0.5
64
- max_room_size: 0.5
65
- min_damping: 0.5
66
- max_damping: 0.5
67
- min_wet_dry: 0.4
68
- max_wet_dry: 0.4
69
- min_width: 0.5
70
- max_width: 0.5
 
1
  # @package _global_
2
  effects:
3
+ Chorus:
4
+ _target_: remfx.effects.RandomPedalboardChorus
5
+ sample_rate: ${sample_rate}
6
+ min_depth: 0.2
7
+ min_mix: 0.3
8
+ Distortion:
9
+ _target_: remfx.effects.RandomPedalboardDistortion
10
+ sample_rate: ${sample_rate}
11
+ min_drive_db: 10
12
+ max_drive_db: 50
13
+ Compressor:
14
+ _target_: remfx.effects.RandomPedalboardCompressor
15
+ sample_rate: ${sample_rate}
16
+ min_threshold_db: -42.0
17
+ max_threshold_db: -20.0
18
+ min_ratio: 1.5
19
+ max_ratio: 6.0
20
+ Reverb:
21
+ _target_: remfx.effects.RandomPedalboardReverb
22
+ sample_rate: ${sample_rate}
23
+ min_room_size: 0.3
24
+ max_room_size: 1.0
25
+ min_damping: 0.2
26
+ max_damping: 1.0
27
+ min_wet_dry: 0.2
28
+ max_wet_dry: 0.8
29
+ min_width: 0.2
30
+ max_width: 1.0
31
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/effects/chorus.yaml CHANGED
@@ -1,20 +1,7 @@
1
  # @package _global_
2
  effects:
3
- train_effects:
4
- Chorus:
5
- _target_: remfx.effects.RandomPedalboardChorus
6
- sample_rate: ${sample_rate}
7
- val_effects:
8
- Chorus:
9
- _target_: remfx.effects.RandomPedalboardChorus
10
- sample_rate: ${sample_rate}
11
- min_rate_hz: 1.0
12
- max_rate_hz: 1.0
13
- min_depth: 0.3
14
- max_depth: 0.3
15
- min_centre_delay_ms: 7.5
16
- max_centre_delay_ms: 7.5
17
- min_feedback: 0.4
18
- max_feedback: 0.4
19
- min_mix: 0.4
20
- max_mix: 0.4
 
1
  # @package _global_
2
  effects:
3
+ Chorus:
4
+ _target_: remfx.effects.RandomPedalboardChorus
5
+ sample_rate: ${sample_rate}
6
+ min_depth: 0.2
7
+ min_mix: 0.3
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/effects/compression.yaml DELETED
@@ -1,22 +0,0 @@
1
- # @package _global_
2
- effects:
3
- train_effects:
4
- Compressor:
5
- _target_: remfx.effects.RandomPedalboardCompressor
6
- sample_rate: ${sample_rate}
7
- min_threshold_db: -42.0
8
- max_threshold_db: -20.0
9
- min_ratio: 1.5
10
- max_ratio: 6.0
11
- val_effects:
12
- Compressor:
13
- _target_: remfx.effects.RandomPedalboardCompressor
14
- sample_rate: ${sample_rate}
15
- min_threshold_db: -32
16
- max_threshold_db: -32
17
- min_ratio: 3.0
18
- max_ratio: 3.0
19
- min_attack_ms: 10.0
20
- max_attack_ms: 10.0
21
- min_release_ms: 40.0
22
- max_release_ms: 40.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/effects/compressor.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ effects:
3
+ Compressor:
4
+ _target_: remfx.effects.RandomPedalboardCompressor
5
+ sample_rate: ${sample_rate}
6
+ min_threshold_db: -42.0
7
+ max_threshold_db: -20.0
8
+ min_ratio: 1.5
9
+ max_ratio: 6.0
cfg/effects/distortion.yaml CHANGED
@@ -1,14 +1,7 @@
1
  # @package _global_
2
  effects:
3
- train_effects:
4
- Distortion:
5
- _target_: remfx.effects.RandomPedalboardDistortion
6
- sample_rate: ${sample_rate}
7
- min_drive_db: -10
8
- max_drive_db: 50
9
- val_effects:
10
- Distortion:
11
- _target_: remfx.effects.RandomPedalboardDistortion
12
- sample_rate: ${sample_rate}
13
- min_drive_db: 30
14
- max_drive_db: 30
 
1
  # @package _global_
2
  effects:
3
+ Distortion:
4
+ _target_: remfx.effects.RandomPedalboardDistortion
5
+ sample_rate: ${sample_rate}
6
+ min_drive_db: 10
7
+ max_drive_db: 50
 
 
 
 
 
 
 
cfg/effects/reverb.yaml CHANGED
@@ -1,26 +1,13 @@
1
  # @package _global_
2
  effects:
3
- train_effects:
4
- Reverb:
5
- _target_: remfx.effects.RandomPedalboardReverb
6
- sample_rate: ${sample_rate}
7
- min_room_size: 0.3
8
- max_room_size: 1.0
9
- min_damping: 0.2
10
- max_damping: 1.0
11
- min_wet_dry: 0.2
12
- max_wet_dry: 0.8
13
- min_width: 0.2
14
- max_width: 1.0
15
- val_effects:
16
- Reverb:
17
- _target_: remfx.effects.RandomPedalboardReverb
18
- sample_rate: ${sample_rate}
19
- min_room_size: 0.5
20
- max_room_size: 0.5
21
- min_damping: 0.5
22
- max_damping: 0.5
23
- min_wet_dry: 0.4
24
- max_wet_dry: 0.4
25
- min_width: 0.5
26
- max_width: 0.5
 
1
  # @package _global_
2
  effects:
3
+ Reverb:
4
+ _target_: remfx.effects.RandomPedalboardReverb
5
+ sample_rate: ${sample_rate}
6
+ min_room_size: 0.3
7
+ max_room_size: 1.0
8
+ min_damping: 0.2
9
+ max_damping: 1.0
10
+ min_wet_dry: 0.2
11
+ max_wet_dry: 0.8
12
+ min_width: 0.2
13
+ max_width: 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/exp/{demucs_compression.yaml → demucs_compressor.yaml} RENAMED
@@ -1,4 +1,4 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
- - override /effects: compression
 
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
+ - override /effects: compressor
cfg/exp/{umx_compression.yaml → umx_compressor.yaml} RENAMED
@@ -1,4 +1,4 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
- - override /effects: compression
 
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
+ - override /effects: compressor
remfx/datasets.py CHANGED
@@ -17,7 +17,7 @@ class VocalSet(Dataset):
17
  self,
18
  root: str,
19
  sample_rate: int,
20
- chunk_size_in_sec: int = 3,
21
  effect_types: List[torch.nn.Module] = None,
22
  render_files: bool = True,
23
  render_root: str = None,
@@ -28,7 +28,7 @@ class VocalSet(Dataset):
28
  self.song_idx = []
29
  self.root = Path(root)
30
  self.render_root = Path(render_root)
31
- self.chunk_size_in_sec = chunk_size_in_sec
32
  self.sample_rate = sample_rate
33
  self.mode = mode
34
 
@@ -36,9 +36,11 @@ class VocalSet(Dataset):
36
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
37
  self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
38
  self.effect_types = effect_types
39
-
40
- self.processed_root = self.render_root / "processed" / self.mode
41
-
 
 
42
  self.num_chunks = 0
43
  print("Total files:", len(self.files))
44
  print("Processing files...")
@@ -46,19 +48,14 @@ class VocalSet(Dataset):
46
  # Split audio file into chunks, resample, then apply random effects
47
  self.processed_root.mkdir(parents=True, exist_ok=True)
48
  for audio_file in tqdm(self.files, total=len(self.files)):
49
- chunks, orig_sr = create_sequential_chunks(
50
- audio_file, self.chunk_size_in_sec
51
- )
52
  for chunk in chunks:
53
  resampled_chunk = torchaudio.functional.resample(
54
  chunk, orig_sr, sample_rate
55
  )
56
- chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
57
- if resampled_chunk.shape[-1] < chunk_size_in_samples:
58
- resampled_chunk = F.pad(
59
- resampled_chunk,
60
- (0, chunk_size_in_samples - resampled_chunk.shape[1]),
61
- )
62
  # Apply effect
63
  effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
64
  effect_name = list(self.effect_types.keys())[int(effect_idx)]
 
17
  self,
18
  root: str,
19
  sample_rate: int,
20
+ chunk_size: int = 3,
21
  effect_types: List[torch.nn.Module] = None,
22
  render_files: bool = True,
23
  render_root: str = None,
 
28
  self.song_idx = []
29
  self.root = Path(root)
30
  self.render_root = Path(render_root)
31
+ self.chunk_size = chunk_size
32
  self.sample_rate = sample_rate
33
  self.mode = mode
34
 
 
36
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
37
  self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
38
  self.effect_types = effect_types
39
+ effect_str = "_".join([e for e in self.effect_types])
40
+ self.processed_root = self.render_root / "processed" / effect_str / self.mode
41
+ if self.processed_root.exists():
42
+ print("Found processed files.")
43
+ render_files = False
44
  self.num_chunks = 0
45
  print("Total files:", len(self.files))
46
  print("Processing files...")
 
48
  # Split audio file into chunks, resample, then apply random effects
49
  self.processed_root.mkdir(parents=True, exist_ok=True)
50
  for audio_file in tqdm(self.files, total=len(self.files)):
51
+ chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
 
 
52
  for chunk in chunks:
53
  resampled_chunk = torchaudio.functional.resample(
54
  chunk, orig_sr, sample_rate
55
  )
56
+ if resampled_chunk.shape[-1] < chunk_size:
57
+ # Skip if chunk is too small
58
+ continue
 
 
 
59
  # Apply effect
60
  effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
61
  effect_name = list(self.effect_types.keys())[int(effect_idx)]
remfx/models.py CHANGED
@@ -55,6 +55,29 @@ class RemFXModel(pl.LightningModule):
55
  )
56
  return optimizer
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def training_step(self, batch, batch_idx):
59
  loss = self.common_step(batch, batch_idx, mode="train")
60
  return loss
@@ -215,7 +238,7 @@ class OpenUnmixModel(torch.nn.Module):
215
  X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
216
  Y = self.model(X)
217
  sep_out = self.separator(x).squeeze(1)
218
- loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
219
 
220
  return loss, sep_out
221
 
@@ -236,7 +259,7 @@ class DemucsModel(torch.nn.Module):
236
  def forward(self, batch):
237
  x, target, label = batch
238
  output = self.model(x).squeeze(1)
239
- loss = self.mrstftloss(output, target) + self.l1loss(output, target)
240
  return loss, output
241
 
242
  def sample(self, x: Tensor) -> Tensor:
@@ -264,10 +287,13 @@ def log_wandb_audio_batch(
264
  samples: Tensor,
265
  sampling_rate: int,
266
  caption: str = "",
 
267
  ):
268
  num_items = samples.shape[0]
269
  samples = rearrange(samples, "b c t -> b t c")
270
  for idx in range(num_items):
 
 
271
  logger.experiment.log(
272
  {
273
  f"{id}_{idx}": wandb.Audio(
 
55
  )
56
  return optimizer
57
 
58
+ # Add step-based learning rate scheduler
59
+ def optimizer_step(
60
+ self,
61
+ epoch,
62
+ batch_idx,
63
+ optimizer,
64
+ optimizer_idx,
65
+ optimizer_closure,
66
+ on_tpu,
67
+ using_native_amp,
68
+ using_lbfgs,
69
+ ):
70
+ # update params
71
+ optimizer.step(closure=optimizer_closure)
72
+
73
+ # update learning rate. Reduce by factor of 10 at 80% and 95% of training
74
+ if self.trainer.global_step == 0.8 * self.trainer.max_steps:
75
+ for pg in optimizer.param_groups:
76
+ pg["lr"] = 0.1 * pg["lr"]
77
+ if self.trainer.global_step == 0.95 * self.trainer.max_steps:
78
+ for pg in optimizer.param_groups:
79
+ pg["lr"] = 0.1 * pg["lr"]
80
+
81
  def training_step(self, batch, batch_idx):
82
  loss = self.common_step(batch, batch_idx, mode="train")
83
  return loss
 
238
  X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
239
  Y = self.model(X)
240
  sep_out = self.separator(x).squeeze(1)
241
+ loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
242
 
243
  return loss, sep_out
244
 
 
259
  def forward(self, batch):
260
  x, target, label = batch
261
  output = self.model(x).squeeze(1)
262
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
263
  return loss, output
264
 
265
  def sample(self, x: Tensor) -> Tensor:
 
287
  samples: Tensor,
288
  sampling_rate: int,
289
  caption: str = "",
290
+ max_items: int = 10,
291
  ):
292
  num_items = samples.shape[0]
293
  samples = rearrange(samples, "b c t -> b t c")
294
  for idx in range(num_items):
295
+ if idx >= max_items:
296
+ break
297
  logger.experiment.log(
298
  {
299
  f"{id}_{idx}": wandb.Audio(
remfx/utils.py CHANGED
@@ -132,10 +132,9 @@ def create_sequential_chunks(
132
  """
133
  chunks = []
134
  audio, sr = torchaudio.load(audio_file)
135
- chunk_size_in_samples = chunk_size * sr
136
- chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
137
  for start in chunk_starts:
138
- if start + chunk_size_in_samples > audio.shape[-1]:
139
  break
140
- chunks.append(audio[:, start : start + chunk_size_in_samples])
141
  return chunks, sr
 
132
  """
133
  chunks = []
134
  audio, sr = torchaudio.load(audio_file)
135
+ chunk_starts = torch.arange(0, audio.shape[-1], chunk_size)
 
136
  for start in chunk_starts:
137
+ if start + chunk_size > audio.shape[-1]:
138
  break
139
+ chunks.append(audio[:, start : start + chunk_size])
140
  return chunks, sr
scripts/test.py CHANGED
@@ -14,7 +14,6 @@ def main(cfg: DictConfig):
14
  # Apply seed for reproducibility
15
  if cfg.seed:
16
  pl.seed_everything(cfg.seed)
17
- cfg.render_files = False
18
  log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
19
  datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
20
  log.info(f"Instantiating model <{cfg.model._target_}>.")
 
14
  # Apply seed for reproducibility
15
  if cfg.seed:
16
  pl.seed_everything(cfg.seed)
 
17
  log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
18
  datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
19
  log.info(f"Instantiating model <{cfg.model._target_}>.")
scripts/train.py CHANGED
@@ -42,6 +42,7 @@ def main(cfg: DictConfig):
42
  summary = ModelSummary(model)
43
  print(summary)
44
  trainer.fit(model=model, datamodule=datamodule)
 
45
 
46
 
47
  if __name__ == "__main__":
 
42
  summary = ModelSummary(model)
43
  print(summary)
44
  trainer.fit(model=model, datamodule=datamodule)
45
+ trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
46
 
47
 
48
  if __name__ == "__main__":