Spaces:
Sleeping
Sleeping
Commit
·
7173f20
1
Parent(s):
7fc4de1
Re-sample effects if STFT too low
Browse files
- remfx/datasets.py +49 -42
remfx/datasets.py
CHANGED
@@ -13,10 +13,10 @@ from typing import Any, List, Dict
|
|
13 |
from torch.utils.data import Dataset, DataLoader
|
14 |
from remfx.utils import select_random_chunk
|
15 |
import multiprocessing
|
|
|
16 |
|
17 |
|
18 |
-
|
19 |
-
|
20 |
ALL_EFFECTS = effect_lib.Pedalboard_Effects
|
21 |
# print(ALL_EFFECTS)
|
22 |
|
@@ -275,6 +275,7 @@ class EffectDataset(Dataset):
|
|
275 |
self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
|
276 |
self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
|
277 |
self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
|
|
278 |
self.effects = effect_modules
|
279 |
self.shuffle_kept_effects = shuffle_kept_effects
|
280 |
self.shuffle_removed_effects = shuffle_removed_effects
|
@@ -438,46 +439,52 @@ class EffectDataset(Dataset):
|
|
438 |
# Index in effect settings
|
439 |
effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
|
440 |
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
|
441 |
-
#
|
442 |
-
|
443 |
-
|
444 |
-
#
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
#
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
482 |
|
483 |
|
|
|
13 |
from torch.utils.data import Dataset, DataLoader
|
14 |
from remfx.utils import select_random_chunk
|
15 |
import multiprocessing
|
16 |
+
from auraloss.freq import MultiResolutionSTFTLoss
|
17 |
|
18 |
|
19 |
+
STFT_THRESH = 1e-3
|
|
|
20 |
ALL_EFFECTS = effect_lib.Pedalboard_Effects
|
21 |
# print(ALL_EFFECTS)
|
22 |
|
|
|
275 |
self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
|
276 |
self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
|
277 |
self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
278 |
+
self.mrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate)
|
279 |
self.effects = effect_modules
|
280 |
self.shuffle_kept_effects = shuffle_kept_effects
|
281 |
self.shuffle_removed_effects = shuffle_removed_effects
|
|
|
439 |
# Index in effect settings
|
440 |
effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
|
441 |
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
|
442 |
+
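# STFT comparison: keep applying effects while the multi-resolution STFT loss between wet and dry is below STFT_THRESH
# NOTE(review): `dry` and `effects_to_apply` are reassigned further down; if those lines sit inside this loop, a retry starts from already-processed audio and the removal-effect list — confirm this is intended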
|
443 |
+
stft = 0
|
444 |
+
while stft < STFT_THRESH:
|
445 |
+
# Apply
|
446 |
+
dry_labels = []
|
447 |
+
for effect in effects_to_apply:
|
448 |
+
# Normalize in-between effects
|
449 |
+
dry = self.normalize(effect(dry))
|
450 |
+
dry_labels.append(ALL_EFFECTS.index(type(effect)))
|
451 |
+
|
452 |
+
# Apply effects_to_remove
|
453 |
+
# Shuffle effects if specified
|
454 |
+
if self.shuffle_removed_effects:
|
455 |
+
effect_indices = torch.randperm(len(self.effects_to_remove))
|
456 |
+
else:
|
457 |
+
effect_indices = torch.arange(len(self.effects_to_remove))
|
458 |
+
wet = torch.clone(dry)
|
459 |
+
r1 = self.num_removed_effects[0]
|
460 |
+
r2 = self.num_removed_effects[1]
|
461 |
+
num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
|
462 |
+
effect_indices = effect_indices[:num_removed_effects]
|
463 |
+
# Index in effect settings
|
464 |
+
effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
|
465 |
+
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
|
466 |
+
# Apply
|
467 |
+
wet_labels = []
|
468 |
+
for effect in effects_to_apply:
|
469 |
+
# Normalize in-between effects
|
470 |
+
wet = self.normalize(effect(wet))
|
471 |
+
wet_labels.append(ALL_EFFECTS.index(type(effect)))
|
472 |
+
|
473 |
+
wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
|
474 |
+
dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
|
475 |
+
|
476 |
+
for label_idx in wet_labels:
|
477 |
+
wet_labels_tensor[label_idx] = 1.0
|
478 |
+
|
479 |
+
for label_idx in dry_labels:
|
480 |
+
dry_labels_tensor[label_idx] = 1.0
|
481 |
+
|
482 |
+
# Normalize
|
483 |
+
normalized_dry = self.normalize(dry)
|
484 |
+
normalized_wet = self.normalize(wet)
|
485 |
+
|
486 |
+
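# Check the multi-resolution STFT loss between normalized wet and dry; the loop condition repeats the effect application while it is below STFT_THRESH
# NOTE(review): wet starts as a clone of dry, so if zero removal effects are selected the loss may stay below the threshold — confirm the loop can always terminate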
|
487 |
+
stft = self.mrstft(normalized_wet, normalized_dry)
|
488 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
489 |
|
490 |
|