mattricesound committed on
Commit
7173f20
·
1 Parent(s): 7fc4de1

Re-sample effects if STFT too low

Browse files
Files changed (1) hide show
  1. remfx/datasets.py +49 -42
remfx/datasets.py CHANGED
@@ -13,10 +13,10 @@ from typing import Any, List, Dict
13
  from torch.utils.data import Dataset, DataLoader
14
  from remfx.utils import select_random_chunk
15
  import multiprocessing
 
16
 
17
 
18
- # https://zenodo.org/record/1193957 -> VocalSet
19
-
20
  ALL_EFFECTS = effect_lib.Pedalboard_Effects
21
  # print(ALL_EFFECTS)
22
 
@@ -275,6 +275,7 @@ class EffectDataset(Dataset):
275
  self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
276
  self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
277
  self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
 
278
  self.effects = effect_modules
279
  self.shuffle_kept_effects = shuffle_kept_effects
280
  self.shuffle_removed_effects = shuffle_removed_effects
@@ -438,46 +439,52 @@ class EffectDataset(Dataset):
438
  # Index in effect settings
439
  effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
440
  effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
441
- # Apply
442
- dry_labels = []
443
- for effect in effects_to_apply:
444
- # Normalize in-between effects
445
- dry = self.normalize(effect(dry))
446
- dry_labels.append(ALL_EFFECTS.index(type(effect)))
447
-
448
- # Apply effects_to_remove
449
- # Shuffle effects if specified
450
- if self.shuffle_removed_effects:
451
- effect_indices = torch.randperm(len(self.effects_to_remove))
452
- else:
453
- effect_indices = torch.arange(len(self.effects_to_remove))
454
- wet = torch.clone(dry)
455
- r1 = self.num_removed_effects[0]
456
- r2 = self.num_removed_effects[1]
457
- num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
458
- effect_indices = effect_indices[:num_removed_effects]
459
- # Index in effect settings
460
- effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
461
- effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
462
- # Apply
463
- wet_labels = []
464
- for effect in effects_to_apply:
465
- # Normalize in-between effects
466
- wet = self.normalize(effect(wet))
467
- wet_labels.append(ALL_EFFECTS.index(type(effect)))
468
-
469
- wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
470
- dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
471
-
472
- for label_idx in wet_labels:
473
- wet_labels_tensor[label_idx] = 1.0
474
-
475
- for label_idx in dry_labels:
476
- dry_labels_tensor[label_idx] = 1.0
477
-
478
- # Normalize
479
- normalized_dry = self.normalize(dry)
480
- normalized_wet = self.normalize(wet)
 
 
 
 
 
 
481
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
482
 
483
 
 
13
  from torch.utils.data import Dataset, DataLoader
14
  from remfx.utils import select_random_chunk
15
  import multiprocessing
16
+ from auraloss.freq import MultiResolutionSTFTLoss
17
 
18
 
19
+ STFT_THRESH = 1e-3
 
20
  ALL_EFFECTS = effect_lib.Pedalboard_Effects
21
  # print(ALL_EFFECTS)
22
 
 
275
  self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
276
  self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
277
  self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
278
+ self.mrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate)
279
  self.effects = effect_modules
280
  self.shuffle_kept_effects = shuffle_kept_effects
281
  self.shuffle_removed_effects = shuffle_removed_effects
 
439
  # Index in effect settings
440
  effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
441
  effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
442
+ # Repeat the effect application until the multi-resolution STFT distance between wet and dry is at least STFT_THRESH
443
+ stft = 0
444
+ while stft < STFT_THRESH:
445
+ # Apply
446
+ dry_labels = []
447
+ for effect in effects_to_apply:
448
+ # Normalize in-between effects
449
+ dry = self.normalize(effect(dry))
450
+ dry_labels.append(ALL_EFFECTS.index(type(effect)))
451
+
452
+ # Apply effects_to_remove
453
+ # Shuffle effects if specified
454
+ if self.shuffle_removed_effects:
455
+ effect_indices = torch.randperm(len(self.effects_to_remove))
456
+ else:
457
+ effect_indices = torch.arange(len(self.effects_to_remove))
458
+ wet = torch.clone(dry)
459
+ r1 = self.num_removed_effects[0]
460
+ r2 = self.num_removed_effects[1]
461
+ num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
462
+ effect_indices = effect_indices[:num_removed_effects]
463
+ # Index in effect settings
464
+ effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
465
+ effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
466
+ # Apply
467
+ wet_labels = []
468
+ for effect in effects_to_apply:
469
+ # Normalize in-between effects
470
+ wet = self.normalize(effect(wet))
471
+ wet_labels.append(ALL_EFFECTS.index(type(effect)))
472
+
473
+ wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
474
+ dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
475
+
476
+ for label_idx in wet_labels:
477
+ wet_labels_tensor[label_idx] = 1.0
478
+
479
+ for label_idx in dry_labels:
480
+ dry_labels_tensor[label_idx] = 1.0
481
+
482
+ # Normalize
483
+ normalized_dry = self.normalize(dry)
484
+ normalized_wet = self.normalize(wet)
485
+
486
+ # Check STFT, pick different effects if necessary
487
+ stft = self.mrstft(normalized_wet, normalized_dry)
488
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
489
 
490