Christian J. Steinmetz committed on
Commit
279d167
·
1 Parent(s): 1b6bb59

adding support for parallel processing in dataset generation

Browse files
Files changed (1) hide show
  1. remfx/datasets.py +144 -22
remfx/datasets.py CHANGED
@@ -8,15 +8,16 @@ import pytorch_lightning as pl
8
  import random
9
  from tqdm import tqdm
10
  from pathlib import Path
11
- from remfx import effects
12
  from typing import Any, List, Dict
13
  from torch.utils.data import Dataset, DataLoader
14
  from remfx.utils import select_random_chunk
 
15
 
16
 
17
  # https://zenodo.org/record/1193957 -> VocalSet
18
 
19
- ALL_EFFECTS = effects.Pedalboard_Effects
20
  # print(ALL_EFFECTS)
21
 
22
 
@@ -146,6 +147,101 @@ def locate_files(root: str, mode: str):
146
  return file_list
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  class EffectDataset(Dataset):
150
  def __init__(
151
  self,
@@ -163,6 +259,7 @@ class EffectDataset(Dataset):
163
  render_files: bool = True,
164
  render_root: str = None,
165
  mode: str = "train",
 
166
  ):
167
  super().__init__()
168
  self.chunks = []
@@ -177,7 +274,7 @@ class EffectDataset(Dataset):
177
  self.num_removed_effects = num_removed_effects
178
  self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
179
  self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
180
- self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
181
  self.effects = effect_modules
182
  self.shuffle_kept_effects = shuffle_kept_effects
183
  self.shuffle_removed_effects = shuffle_removed_effects
@@ -192,6 +289,7 @@ class EffectDataset(Dataset):
192
  )
193
  self.validate_effect_input()
194
  self.proc_root = self.render_root / "processed" / effects_string / self.mode
 
195
 
196
  self.files = locate_files(self.root, self.mode)
197
 
@@ -212,26 +310,50 @@ class EffectDataset(Dataset):
212
  if render_files:
213
  # Split audio file into chunks, resample, then apply random effects
214
  self.proc_root.mkdir(parents=True, exist_ok=True)
215
- for num_chunk in tqdm(range(self.total_chunks)):
216
- chunk = None
217
- random_dataset_choice = random.choice(self.files)
218
- while chunk is None:
219
- random_file_choice = random.choice(random_dataset_choice)
220
- chunk = select_random_chunk(
221
- random_file_choice, self.chunk_size, self.sample_rate
222
- )
223
 
224
- # Sum to mono
225
- if chunk.shape[0] > 1:
226
- chunk = chunk.sum(0, keepdim=True)
227
-
228
- dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
229
- output_dir = self.proc_root / str(num_chunk)
230
- output_dir.mkdir(exist_ok=True)
231
- torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
232
- torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
233
- torch.save(dry_effects, output_dir / "dry_effects.pt")
234
- torch.save(wet_effects, output_dir / "wet_effects.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  print("Finished rendering")
237
  else:
 
8
  import random
9
  from tqdm import tqdm
10
  from pathlib import Path
11
+ from remfx import effects as effect_lib
12
  from typing import Any, List, Dict
13
  from torch.utils.data import Dataset, DataLoader
14
  from remfx.utils import select_random_chunk
15
+ import multiprocessing
16
 
17
 
18
  # https://zenodo.org/record/1193957 -> VocalSet
19
 
20
+ ALL_EFFECTS = effect_lib.Pedalboard_Effects
21
  # print(ALL_EFFECTS)
22
 
23
 
 
147
  return file_list
148
 
149
 
150
def parallel_process_effects(
    chunk_idx: int,
    proc_root: str,
    files: list,
    chunk_size: int,
    effects: list,
    effects_to_keep: list,
    num_kept_effects: tuple,
    shuffle_kept_effects: bool,
    effects_to_remove: list,
    num_removed_effects: tuple,
    shuffle_removed_effects: bool,
    sample_rate: int,
    target_lufs_db: float,
):
    """Render one random chunk with random effects and write it to disk.

    A module-level function (rather than a method) so it can be handed to
    ``multiprocessing.Pool.starmap``.

    Args:
        chunk_idx: Index of the chunk; used as the output sub-directory name.
        proc_root: Root output directory (``str`` or ``Path``).
        files: Collection of file collections; one inner collection is picked
            at random, then files are drawn from it until a chunk is returned.
        chunk_size: Passed through to ``select_random_chunk``.
        effects: Lookup from effect name to effect callable (indexed by the
            entries of ``effects_to_keep`` / ``effects_to_remove``).
        effects_to_keep: Names of effects applied to both dry and wet signals.
        num_kept_effects: Two-element range from which the number of kept
            effects is drawn.
        shuffle_kept_effects: If true, the kept effects are randomly ordered.
        effects_to_remove: Names of effects applied only to the wet signal.
        num_removed_effects: Two-element range from which the number of
            removed effects is drawn.
        shuffle_removed_effects: If true, the removed effects are randomly
            ordered.
        sample_rate: Sample rate used for chunk selection, normalization and
            the saved wav files.
        target_lufs_db: Target loudness handed to ``LoudnessNormalize``.

    Returns:
        None. Writes ``input.wav`` (wet), ``target.wav`` (dry),
        ``dry_effects.pt`` and ``wet_effects.pt`` into
        ``proc_root / str(chunk_idx)``.
    """
    # NOTE(review): torch's global RNG state is presumably shared by workers
    # created with the "fork" start method, so torch.rand/randperm draws may be
    # correlated across workers -- confirm and consider per-chunk seeding.
    chunk = None
    random_dataset_choice = random.choice(files)
    while chunk is None:
        random_file_choice = random.choice(random_dataset_choice)
        chunk = select_random_chunk(random_file_choice, chunk_size, sample_rate)

    # Sum to mono
    if chunk.shape[0] > 1:
        chunk = chunk.sum(0, keepdim=True)

    # Loudness normalization, applied in-between effects and on the outputs
    normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db)

    # Kept effects end up in both the dry (target) and wet (input) signals
    dry, dry_labels = _apply_random_effects(
        chunk,
        effects,
        effects_to_keep,
        num_kept_effects,
        shuffle_kept_effects,
        normalize,
    )

    # Removed effects are applied on top of a copy of the dry signal only
    wet, wet_labels = _apply_random_effects(
        torch.clone(dry),
        effects,
        effects_to_remove,
        num_removed_effects,
        shuffle_removed_effects,
        normalize,
    )

    # Multi-hot label vectors over ALL_EFFECTS
    wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
    dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))

    for label_idx in wet_labels:
        wet_labels_tensor[label_idx] = 1.0

    for label_idx in dry_labels:
        dry_labels_tensor[label_idx] = 1.0

    # Normalize
    normalized_dry = normalize(dry)
    normalized_wet = normalize(wet)

    # Accept both str and Path roots; create missing parents so a worker never
    # fails just because the root has not been created yet.
    output_dir = Path(proc_root) / str(chunk_idx)
    output_dir.mkdir(parents=True, exist_ok=True)
    torchaudio.save(output_dir / "input.wav", normalized_wet, sample_rate)
    torchaudio.save(output_dir / "target.wav", normalized_dry, sample_rate)
    torch.save(dry_labels_tensor, output_dir / "dry_effects.pt")
    torch.save(wet_labels_tensor, output_dir / "wet_effects.pt")


def _apply_random_effects(audio, effects, effect_names, num_range, shuffle, normalize):
    """Apply a randomly sized prefix of ``effect_names`` to ``audio``.

    Returns the processed audio and the list of ``ALL_EFFECTS`` indices of the
    effects that were applied. The signal is normalized after every effect.
    """
    # Shuffle effects if specified
    if shuffle:
        order = torch.randperm(len(effect_names))
    else:
        order = torch.arange(len(effect_names))

    # Draw how many effects to apply from the given two-element range
    r1 = num_range[0]
    r2 = num_range[1]
    count = int(torch.round((r1 - r2) * torch.rand(1) + r2).item())

    labels = []
    # .tolist() yields plain ints so the name list is never indexed by a tensor
    for idx in order[:count].tolist():
        effect = effects[effect_names[idx]]
        # Normalize in-between effects
        audio = normalize(effect(audio))
        labels.append(ALL_EFFECTS.index(type(effect)))
    return audio, labels
241
+
242
+ # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
243
+
244
+
245
  class EffectDataset(Dataset):
246
  def __init__(
247
  self,
 
259
  render_files: bool = True,
260
  render_root: str = None,
261
  mode: str = "train",
262
+ parallel: bool = True,
263
  ):
264
  super().__init__()
265
  self.chunks = []
 
274
  self.num_removed_effects = num_removed_effects
275
  self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
276
  self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
277
+ self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
278
  self.effects = effect_modules
279
  self.shuffle_kept_effects = shuffle_kept_effects
280
  self.shuffle_removed_effects = shuffle_removed_effects
 
289
  )
290
  self.validate_effect_input()
291
  self.proc_root = self.render_root / "processed" / effects_string / self.mode
292
+ self.parallel = parallel
293
 
294
  self.files = locate_files(self.root, self.mode)
295
 
 
310
  if render_files:
311
  # Split audio file into chunks, resample, then apply random effects
312
  self.proc_root.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
313
 
314
+ if self.parallel:
315
+ items = [
316
+ (
317
+ chunk_idx,
318
+ self.proc_root,
319
+ self.files,
320
+ self.chunk_size,
321
+ self.effects,
322
+ self.effects_to_keep,
323
+ self.num_kept_effects,
324
+ self.shuffle_kept_effects,
325
+ self.effects_to_remove,
326
+ self.num_removed_effects,
327
+ self.shuffle_removed_effects,
328
+ self.sample_rate,
329
+ -20.0,
330
+ )
331
+ for chunk_idx in range(self.total_chunks)
332
+ ]
333
+ with multiprocessing.Pool(processes=32) as pool:
334
+ pool.starmap(parallel_process_effects, items)
335
+ print(f"Done proccessing {self.total_chunks}", flush=True)
336
+ else:
337
+ for num_chunk in tqdm(range(self.total_chunks)):
338
+ chunk = None
339
+ random_dataset_choice = random.choice(self.files)
340
+ while chunk is None:
341
+ random_file_choice = random.choice(random_dataset_choice)
342
+ chunk = select_random_chunk(
343
+ random_file_choice, self.chunk_size, self.sample_rate
344
+ )
345
+
346
+ # Sum to mono
347
+ if chunk.shape[0] > 1:
348
+ chunk = chunk.sum(0, keepdim=True)
349
+
350
+ dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
351
+ output_dir = self.proc_root / str(num_chunk)
352
+ output_dir.mkdir(exist_ok=True)
353
+ torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
354
+ torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
355
+ torch.save(dry_effects, output_dir / "dry_effects.pt")
356
+ torch.save(wet_effects, output_dir / "wet_effects.pt")
357
 
358
  print("Finished rendering")
359
  else: