Christian J. Steinmetz committed on
Commit
507048e
·
1 Parent(s): 29f23c3

adding a dynamic dataset for on-the-fly generation of training data

Browse files
Files changed (1) hide show
  1. remfx/datasets.py +137 -6
remfx/datasets.py CHANGED
@@ -162,6 +162,7 @@ def parallel_process_effects(
162
  sample_rate: int,
163
  target_lufs_db: float,
164
  ):
 
165
  chunk = None
166
  random_dataset_choice = random.choice(files)
167
  while chunk is None:
@@ -242,6 +243,134 @@ def parallel_process_effects(
242
  # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
243
 
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  class EffectDataset(Dataset):
246
  def __init__(
247
  self,
@@ -259,7 +388,7 @@ class EffectDataset(Dataset):
259
  render_files: bool = True,
260
  render_root: str = None,
261
  mode: str = "train",
262
- parallel: bool = True,
263
  ):
264
  super().__init__()
265
  self.chunks = []
@@ -524,7 +653,8 @@ class EffectDatamodule(pl.LightningDataModule):
524
  val_dataset,
525
  test_dataset,
526
  *,
527
- batch_size: int,
 
528
  num_workers: int,
529
  pin_memory: bool = False,
530
  **kwargs: int,
@@ -533,7 +663,8 @@ class EffectDatamodule(pl.LightningDataModule):
533
  self.train_dataset = train_dataset
534
  self.val_dataset = val_dataset
535
  self.test_dataset = test_dataset
536
- self.batch_size = batch_size
 
537
  self.num_workers = num_workers
538
  self.pin_memory = pin_memory
539
 
@@ -543,7 +674,7 @@ class EffectDatamodule(pl.LightningDataModule):
543
  def train_dataloader(self) -> DataLoader:
544
  return DataLoader(
545
  dataset=self.train_dataset,
546
- batch_size=self.batch_size,
547
  num_workers=self.num_workers,
548
  pin_memory=self.pin_memory,
549
  shuffle=True,
@@ -552,7 +683,7 @@ class EffectDatamodule(pl.LightningDataModule):
552
  def val_dataloader(self) -> DataLoader:
553
  return DataLoader(
554
  dataset=self.val_dataset,
555
- batch_size=self.batch_size,
556
  num_workers=self.num_workers,
557
  pin_memory=self.pin_memory,
558
  shuffle=False,
@@ -561,7 +692,7 @@ class EffectDatamodule(pl.LightningDataModule):
561
  def test_dataloader(self) -> DataLoader:
562
  return DataLoader(
563
  dataset=self.test_dataset,
564
- batch_size=2, # Use small, consistent batch size for testing
565
  num_workers=self.num_workers,
566
  pin_memory=self.pin_memory,
567
  shuffle=False,
 
162
  sample_rate: int,
163
  target_lufs_db: float,
164
  ):
165
+ """Note: This function has an issue with random seed. It may not fully randomize the effects."""
166
  chunk = None
167
  random_dataset_choice = random.choice(files)
168
  while chunk is None:
 
243
  # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
244
 
245
 
246
class DynamicEffectDataset(Dataset):
    """Map-style dataset that renders (wet, dry) audio pairs on the fly.

    Unlike the pre-rendered ``EffectDataset``, every ``__getitem__`` call
    picks a random source file, slices a random chunk, and applies randomly
    chosen "kept" and "removed" effect chains in memory.

    NOTE(review): ``__getitem__`` ignores its index and draws fresh random
    data each call, so samples are non-deterministic across epochs by design.
    NOTE(review): ``render_files`` is accepted but never used here —
    presumably kept for signature parity with ``EffectDataset``; confirm.
    """

    def __init__(
        self,
        root: str,
        sample_rate: int,
        chunk_size: int = 262144,
        total_chunks: int = 1000,
        effect_modules: List[Dict[str, torch.nn.Module]] = None,
        effects_to_keep: List[str] = None,
        effects_to_remove: List[str] = None,
        num_kept_effects: List[int] = None,
        num_removed_effects: List[int] = None,
        shuffle_kept_effects: bool = True,
        shuffle_removed_effects: bool = False,
        render_files: bool = True,
        render_root: str = None,
        mode: str = "train",
        parallel: bool = False,
    ) -> None:
        """Build the dataset index.

        Args:
            root: Root directory of the raw audio datasets.
            sample_rate: Target sample rate for loaded chunks.
            chunk_size: Number of samples per training chunk.
            total_chunks: Nominal dataset length (one epoch's worth of items).
            effect_modules: Mapping from effect name to effect module.
            effects_to_keep: Effect names applied to BOTH dry and wet.
            effects_to_remove: Effect names applied to wet only (to remove).
            num_kept_effects: [min, max] count of kept effects per item
                (defaults to [1, 5]).
            num_removed_effects: [min, max] count of removed effects per item
                (defaults to [1, 5]).
            shuffle_kept_effects: Randomize kept-effect order per item.
            shuffle_removed_effects: Randomize removed-effect order per item.
            render_files: Unused here (see class note).
            render_root: Root for rendered output paths (currently only used
                by the commented-out ``proc_root``).
            mode: Dataset split, e.g. "train".
            parallel: Stored flag; not used by this class's own methods.
        """
        super().__init__()
        # Bug fix: the original signature used mutable list defaults
        # ([1, 5]) for num_kept_effects / num_removed_effects, which are
        # shared across all instances. Use None sentinels instead.
        num_kept_effects = [1, 5] if num_kept_effects is None else num_kept_effects
        num_removed_effects = (
            [1, 5] if num_removed_effects is None else num_removed_effects
        )
        self.chunks = []
        self.song_idx = []
        self.root = Path(root)
        self.render_root = Path(render_root)
        self.chunk_size = chunk_size
        self.total_chunks = total_chunks
        self.sample_rate = sample_rate
        self.mode = mode
        self.num_kept_effects = num_kept_effects
        self.num_removed_effects = num_removed_effects
        self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
        self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
        self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
        self.effects = effect_modules
        self.shuffle_kept_effects = shuffle_kept_effects
        self.shuffle_removed_effects = shuffle_removed_effects
        # Retained only for the commented-out proc_root path below.
        effects_string = "_".join(
            self.effects_to_keep
            + ["_"]
            + self.effects_to_remove
            + ["_"]
            + [str(x) for x in num_kept_effects]
            + ["_"]
            + [str(x) for x in num_removed_effects]
        )
        # self.validate_effect_input()
        # self.proc_root = self.render_root / "processed" / effects_string / self.mode
        self.parallel = parallel
        self.files = locate_files(self.root, self.mode)

    def process_effects(self, dry: torch.Tensor):
        """Apply random effect chains to ``dry``.

        Kept effects are applied in place to ``dry`` (they appear in both
        signals); removed effects are then applied to a clone to make ``wet``.

        Returns:
            Tuple of (normalized_dry, normalized_wet, dry_labels_tensor,
            wet_labels_tensor), where the label tensors are multi-hot vectors
            over ``ALL_EFFECTS``.
        """
        # --- Kept effects: present in both dry and wet ---
        if self.shuffle_kept_effects:
            effect_indices = torch.randperm(len(self.effects_to_keep))
        else:
            effect_indices = torch.arange(len(self.effects_to_keep))

        r1 = self.num_kept_effects[0]
        r2 = self.num_kept_effects[1]
        # Uniform draw between the [min, max] bounds, rounded to an int count.
        num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
        effect_indices = effect_indices[:num_kept_effects]
        # Index in effect settings
        effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
        # Apply
        dry_labels = []
        for effect in effects_to_apply:
            # Normalize in-between effects
            dry = self.normalize(effect(dry))
            dry_labels.append(ALL_EFFECTS.index(type(effect)))

        # --- Removed effects: applied to wet only (the removal targets) ---
        if self.shuffle_removed_effects:
            effect_indices = torch.randperm(len(self.effects_to_remove))
        else:
            effect_indices = torch.arange(len(self.effects_to_remove))
        wet = torch.clone(dry)
        r1 = self.num_removed_effects[0]
        r2 = self.num_removed_effects[1]
        num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
        effect_indices = effect_indices[:num_removed_effects]
        # Index in effect settings
        effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
        # Apply
        wet_labels = []
        for effect in effects_to_apply:
            # Normalize in-between effects
            wet = self.normalize(effect(wet))
            wet_labels.append(ALL_EFFECTS.index(type(effect)))

        # Multi-hot label vectors over the global effect vocabulary.
        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))

        for label_idx in wet_labels:
            wet_labels_tensor[label_idx] = 1.0

        for label_idx in dry_labels:
            dry_labels_tensor[label_idx] = 1.0

        # Normalize
        normalized_dry = self.normalize(dry)
        normalized_wet = self.normalize(wet)
        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor

    def __len__(self) -> int:
        # Nominal epoch length; items are generated on demand, not indexed.
        return self.total_chunks

    def __getitem__(self, _: int):
        """Render one random example.

        Returns:
            (wet, dry, dry_effects, wet_effects) — note the wet/dry order is
            swapped relative to ``process_effects``'s return.
        """
        # Keep sampling until select_random_chunk yields a usable chunk.
        chunk = None
        random_dataset_choice = random.choice(self.files)
        while chunk is None:
            random_file_choice = random.choice(random_dataset_choice)
            chunk = select_random_chunk(
                random_file_choice, self.chunk_size, self.sample_rate
            )

        # Sum to mono
        if chunk.shape[0] > 1:
            chunk = chunk.sum(0, keepdim=True)

        dry, wet, dry_effects, wet_effects = self.process_effects(chunk)

        return wet, dry, dry_effects, wet_effects
372
+
373
+
374
  class EffectDataset(Dataset):
375
  def __init__(
376
  self,
 
388
  render_files: bool = True,
389
  render_root: str = None,
390
  mode: str = "train",
391
+ parallel: bool = False,
392
  ):
393
  super().__init__()
394
  self.chunks = []
 
653
  val_dataset,
654
  test_dataset,
655
  *,
656
+ train_batch_size: int,
657
+ test_batch_size: int,
658
  num_workers: int,
659
  pin_memory: bool = False,
660
  **kwargs: int,
 
663
  self.train_dataset = train_dataset
664
  self.val_dataset = val_dataset
665
  self.test_dataset = test_dataset
666
+ self.train_batch_size = train_batch_size
667
+ self.test_batch_size = test_batch_size
668
  self.num_workers = num_workers
669
  self.pin_memory = pin_memory
670
 
 
674
  def train_dataloader(self) -> DataLoader:
675
  return DataLoader(
676
  dataset=self.train_dataset,
677
+ batch_size=self.train_batch_size,
678
  num_workers=self.num_workers,
679
  pin_memory=self.pin_memory,
680
  shuffle=True,
 
683
  def val_dataloader(self) -> DataLoader:
684
  return DataLoader(
685
  dataset=self.val_dataset,
686
+ batch_size=self.train_batch_size,
687
  num_workers=self.num_workers,
688
  pin_memory=self.pin_memory,
689
  shuffle=False,
 
692
  def test_dataloader(self) -> DataLoader:
693
  return DataLoader(
694
  dataset=self.test_dataset,
695
+ batch_size=self.test_batch_size, # Use small, consistent batch size for testing
696
  num_workers=self.num_workers,
697
  pin_memory=self.pin_memory,
698
  shuffle=False,