mattricesound committed
Commit 7173e65 · 1 Parent(s): 78dbfc8

Render effected chunks to avoid bottlenecks

README.md CHANGED
@@ -32,4 +32,7 @@ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' train
   - `chorus`
   - `compressor`
   - `distortion`
-  - `reverb`
+  - `reverb`
+
+  ## Misc.
+  To skip rendering files, add `+datamodule.train_dataset.render_files=False +datamodule.val_dataset.render_files=False` to the command-line
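
For example, combined with the README's existing training command, a re-run that reuses chunks rendered by an earlier run might look like the sketch below (the experiment name is just the one from the README example; it assumes the chunks already exist under the configured output root):

    # Sketch: reuse previously rendered chunks instead of re-rendering them.
    # Assumes an earlier run already wrote the chunks to OUTPUT_ROOT.
    python scripts/train.py +exp=umx_distortion \
        +datamodule.train_dataset.render_files=False \
        +datamodule.val_dataset.render_files=False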
cfg/config.yaml CHANGED
@@ -26,6 +26,7 @@ datamodule:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
+    output_root: ${oc.env:OUTPUT_ROOT}
     chunk_size_in_sec: 6
     mode: "train"
     effect_types: ${effects.train_effects}
@@ -33,6 +34,7 @@ datamodule:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
+    output_root: ${oc.env:OUTPUT_ROOT}
     chunk_size_in_sec: 6
     mode: "val"
     effect_types: ${effects.val_effects}
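
Since `output_root` is now read from the `OUTPUT_ROOT` environment variable via `${oc.env:OUTPUT_ROOT}`, that variable has to be exported before launching training. A minimal sketch (the path shown is only an example, not part of this commit):

    # Sketch: define OUTPUT_ROOT so ${oc.env:OUTPUT_ROOT} can resolve.
    export OUTPUT_ROOT="./data/VocalSet/processed"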
cfg/effects/{compressor.yaml → compression.yaml} RENAMED
File without changes
cfg/effects/distortion.yaml CHANGED
@@ -10,5 +10,5 @@ effects:
   Distortion:
     _target_: remfx.effects.RandomPedalboardDistortion
     sample_rate: ${sample_rate}
-    min_drive_db: 25
-    max_drive_db: 25
+    min_drive_db: 30
+    max_drive_db: 30
remfx/datasets.py CHANGED
@@ -17,6 +17,7 @@ from pedalboard import (
     Distortion,
     Limiter,
 )
+from tqdm import tqdm
 
 # https://zenodo.org/record/7044411/ -> GuitarFX
 # https://zenodo.org/record/3371780 -> GuitarSet
@@ -181,6 +182,8 @@ class VocalSet(Dataset):
         sample_rate: int,
         chunk_size_in_sec: int = 3,
         effect_types: List[torch.nn.Module] = None,
+        render_files: bool = True,
+        output_root: str = "processed",
         mode: str = "train",
     ):
         super().__init__()
@@ -193,49 +196,68 @@ class VocalSet(Dataset):
 
         mode_path = self.root / self.mode
         self.files = sorted(list(mode_path.glob("./**/*.wav")))
-        for i, audio_file in enumerate(self.files):
-            chunk_starts, orig_sr = create_sequential_chunks(
-                audio_file, self.chunk_size_in_sec
-            )
-            self.chunks += chunk_starts
-            self.song_idx += [i] * len(chunk_starts)
-        print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
-        self.resampler = T.Resample(orig_sr, sample_rate)
-        self.effect_types = effect_types
         self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
-
-    def __len__(self):
-        return len(self.chunks)
-
-    def __getitem__(self, idx):
-        # Load and effect audio
-        song_idx = self.song_idx[idx]
-        x, sr = torchaudio.load(self.files[song_idx])
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        resampled_x = self.resampler(x)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-
-        # Add random effect if train
-        if self.mode == "train":
-            effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
-        else:
-            # deterministic effect for eval
-            effect_idx = idx % len(self.effect_types.keys())
-        effect_name = list(self.effect_types.keys())[int(effect_idx)]
-        effect = self.effect_types[effect_name]
-        effected_input = effect(resampled_x)
-
-        normalized_input = self.normalize(effected_input)
-        normalized_target = self.normalize(resampled_x)
-        return (normalized_input, normalized_target, effect_name)
+        self.effect_types = effect_types
+
+        self.output_root = Path(output_root)
+        output_mode_path = output_root / self.mode
+
+        self.num_chunks = 0
+        print("Total files:", len(self.files))
+        print("Processing files...")
+        if render_files:
+            if not output_root.exists():
+                output_root.mkdir()
+            if not output_mode_path.exists():
+                output_mode_path.mkdir()
+            for i, audio_file in tqdm(enumerate(self.files)):
+                chunks, orig_sr = create_sequential_chunks(
+                    audio_file, self.chunk_size_in_sec
+                )
+                for chunk in chunks:
+                    resampled_chunk = torchaudio.functional.resample(
+                        chunk, orig_sr, sample_rate
+                    )
+                    chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+                    if resampled_chunk.shape[-1] < chunk_size_in_samples:
+                        resampled_chunk = F.pad(
+                            resampled_chunk,
+                            (0, chunk_size_in_samples - resampled_chunk.shape[1]),
+                        )
+                    effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
+                    effect_name = list(self.effect_types.keys())[int(effect_idx)]
+                    effect = self.effect_types[effect_name]
+                    effected_input = effect(resampled_chunk)
+                    normalized_input = self.normalize(effected_input)
+                    normalized_target = self.normalize(resampled_chunk)
+
+                    output_dir = output_mode_path / str(self.num_chunks)
+                    output_dir.mkdir(exist_ok=True)
+                    torchaudio.save(
+                        output_dir / "input.wav", normalized_input, self.sample_rate
+                    )
+                    torchaudio.save(
+                        output_dir / "target.wav", normalized_target, self.sample_rate
+                    )
+                    self.num_chunks += 1
+        else:
+            self.num_chunks = len(list(output_mode_path.glob("./**/*.wav")))
+
+        print(
+            f"Found {len(self.files)} {self.mode} files .\n"
+            f"Total chunks: {self.num_chunks}"
+        )
+
+    def __len__(self):
+        return self.num_chunks
+
+    def __getitem__(self, idx):
+        # Load audio
+        input_file = self.root / "processed" / self.mode / str(idx) / "input.wav"
+        target_file = self.root / "processed" / self.mode / str(idx) / "target.wav"
+        input, sr = torchaudio.load(input_file)
+        target, sr = torchaudio.load(target_file)
+        return (input, target, "")
 
 
 def create_random_chunks(
@@ -262,10 +284,15 @@ def create_sequential_chunks(
     """Create sequential chunks of size chunk_size (seconds) from an audio file.
     Return sample_index of start of each chunk and original sr
     """
+    chunks = []
     audio, sr = torchaudio.load(audio_file)
     chunk_size_in_samples = chunk_size * sr
     chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
-    return chunk_starts, sr
+    for start in chunk_starts:
+        if start + chunk_size_in_samples > audio.shape[-1]:
+            break
+        chunks.append(audio[:, start : start + chunk_size_in_samples])
+    return chunks, sr
 
 
 class Datamodule(pl.LightningDataModule):
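
For reference, the rendering loop above writes one numbered directory per chunk, each holding an effected input and a clean target. A sketch of the resulting layout (assuming the default `output_root="processed"` and `mode="train"`; the chunk indices are illustrative):

    # processed/train/0/input.wav    <- effected, loudness-normalized chunk
    # processed/train/0/target.wav   <- clean, loudness-normalized chunk
    # processed/train/1/input.wav
    # processed/train/1/target.wav
    # ...
    ls processed/train/0
    # input.wav  target.wav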
remfx/models.py CHANGED
@@ -20,6 +20,7 @@ class FADLoss(torch.nn.Module):
         self.fad = FrechetAudioDistance(
             use_pca=False, use_activation=False, verbose=False
         )
+        self.fad.model = self.fad.model.to("cpu")
         self.sr = sample_rate
 
     def forward(self, audio_background, audio_eval):
shell_vars.sh CHANGED
@@ -1,3 +1,4 @@
 export DATASET_ROOT="./data/VocalSet"
+export DATASET_ROOT="/scratch/VocalSet/processed"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"