Christian J. Steinmetz committed · Commit a3e84f7 · 1 Parent(s): 99511d7

rework dataset generation for better logging of effects added

Files changed:
- remfx/datasets.py +77 -28
- remfx/effects.py +3 -3
remfx/datasets.py CHANGED

@@ -1,21 +1,49 @@
+import os
+import sys
+import glob
 import torch
-
-import torch.nn.functional as F
+import shutil
 import torchaudio
-from pathlib import Path
 import pytorch_lightning as pl
-import …
-
-from remfx import effects
+import torch.nn.functional as F
+
 from tqdm import tqdm
-from …
-import …
+from pathlib import Path
+from remfx import effects
 from ordered_set import OrderedSet
+from typing import Any, List, Dict
+from torch.utils.data import Dataset, DataLoader
+from remfx.utils import create_sequential_chunks


 # https://zenodo.org/record/1193957 -> VocalSet

 ALL_EFFECTS = effects.Pedalboard_Effects
+print(ALL_EFFECTS)
+
+
+singer_splits = {
+    "train": [
+        "male1",
+        "male2",
+        "male3",
+        "male4",
+        "male5",
+        "male6",
+        "male7",
+        "male8",
+        "male9",
+        "female1",
+        "female2",
+        "female3",
+        "female4",
+        "female5",
+        "female6",
+        "female7",
+    ],
+    "val": ["male10", "female8"],
+    "test": ["male11", "female9"],
+}


 class VocalSet(Dataset):
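The new module-level `singer_splits` dictionary replaces the old per-mode directory layout with an explicit singer-level partition of VocalSet's 20 singers (11 male, 9 female). A standalone sanity check, not part of the commit, confirms the splits are disjoint and exhaustive:

```python
# Rebuild the split compactly and verify it (illustrative check only).
singer_splits = {
    "train": [f"male{i}" for i in range(1, 10)] + [f"female{i}" for i in range(1, 8)],
    "val": ["male10", "female8"],
    "test": ["male11", "female9"],
}

all_singers = [s for split in singer_splits.values() for s in split]
assert len(all_singers) == len(set(all_singers)) == 20  # disjoint, all 20 singers
print({k: len(v) for k, v in singer_splits.items()})    # {'train': 16, 'val': 2, 'test': 2}
```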
@@ -43,8 +71,6 @@ class VocalSet(Dataset):
         self.chunk_size = chunk_size
         self.sample_rate = sample_rate
         self.mode = mode
-        mode_path = self.root / self.mode
-        self.files = sorted(list(mode_path.glob("./**/*.wav")))
         self.max_kept_effects = max_kept_effects
         self.max_removed_effects = max_removed_effects
         self.effects_to_use = effects_to_use
@@ -53,11 +79,20 @@ class VocalSet(Dataset):
         self.effects = effect_modules
         self.shuffle_kept_effects = shuffle_kept_effects
         self.shuffle_removed_effects = shuffle_removed_effects
-
         effects_string = "_".join(self.effects_to_use + ["_"] + self.effects_to_remove)
         self.effects_to_keep = self.validate_effect_input()
         self.proc_root = self.render_root / "processed" / effects_string / self.mode

+        # find all singer directories
+        singer_dirs = glob.glob(os.path.join(self.root, "data_by_singer", "*"))
+        singer_dirs = [
+            sd for sd in singer_dirs if os.path.basename(sd) in singer_splits[mode]
+        ]
+        self.files = []
+        for singer_dir in singer_dirs:
+            self.files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
+        self.files = sorted(self.files)
+
         if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
             print("Found processed files.")
         if render_files:
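Note that `glob.glob` only treats `**` as a recursive wildcard when `recursive=True` is passed; without it, each `**` matches like a single `*`. The pattern above therefore matches `.wav` files exactly two directory levels below each singer directory, which fits VocalSet's `singer/context/technique/*.wav` layout. A rough `pathlib` equivalent (a sketch, not from the commit):

```python
from pathlib import Path

def list_singer_wavs(singer_dir: str) -> list:
    # Matches .wav files exactly two levels deep, mirroring
    # glob.glob(os.path.join(singer_dir, "**", "**", "*.wav")) without recursive=True.
    return sorted(str(p) for p in Path(singer_dir).glob("*/*/*.wav"))
```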
@@ -86,12 +121,15 @@ class VocalSet(Dataset):
                     # Skip if chunk is too small
                     continue

-…
+                dry, wet, dry_effects, wet_effects = self.process_effects(
+                    resampled_chunk
+                )
                 output_dir = self.proc_root / str(self.num_chunks)
                 output_dir.mkdir(exist_ok=True)
-                torchaudio.save(output_dir / "input.wav", …
-                torchaudio.save(output_dir / "target.wav", …
-                torch.save(…
+                torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
+                torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
+                torch.save(dry_effects, output_dir / "dry_effects.pt")
+                torch.save(wet_effects, output_dir / "wet_effects.pt")
                 self.num_chunks += 1
         else:
             self.num_chunks = len(list(self.proc_root.iterdir()))
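Each rendered chunk directory now holds four files: the effected audio (`input.wav`), the clean audio (`target.wav`), and the two effect-label tensors. A sketch of reading one chunk back, with an illustrative directory path:

```python
from pathlib import Path
import torch
import torchaudio

chunk_dir = Path("processed") / "some_effects_string" / "train" / "0"  # illustrative path
wet, sr = torchaudio.load(chunk_dir / "input.wav")      # effected audio (model input)
dry, sr = torchaudio.load(chunk_dir / "target.wav")     # clean audio (model target)
dry_effects = torch.load(chunk_dir / "dry_effects.pt")  # labels of effects kept in the target
wet_effects = torch.load(chunk_dir / "wet_effects.pt")  # labels of effects to remove
```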
@@ -107,10 +145,11 @@ class VocalSet(Dataset):
     def __getitem__(self, idx):
         input_file = self.proc_root / str(idx) / "input.wav"
         target_file = self.proc_root / str(idx) / "target.wav"
-
+        dry_effect_names = torch.load(self.proc_root / str(idx) / "dry_effects.pt")
+        wet_effect_names = torch.load(self.proc_root / str(idx) / "wet_effects.pt")
         input, sr = torchaudio.load(input_file)
         target, sr = torchaudio.load(target_file)
-        return (input, target, …
+        return (input, target, dry_effect_names, wet_effect_names)

     def validate_effect_input(self):
         for effect in self.effects.values():
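Since `__getitem__` now returns fixed-size tensors in every position, the default collate function can batch all four fields directly. A usage sketch, assuming a constructed `VocalSet` instance named `dataset` (batch size illustrative):

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=16, shuffle=True)
for inputs, targets, dry_labels, wet_labels in loader:
    # inputs/targets: (16, channels, chunk_size)
    # dry_labels/wet_labels: (16, len(ALL_EFFECTS)) multi-hot tensors
    ...
```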
@@ -154,27 +193,29 @@ class VocalSet(Dataset):
         return kept_fx

     def process_effects(self, dry: torch.Tensor):
-        labels = []
-
         # Apply Kept Effects
         # Shuffle effects if specified
         if self.shuffle_kept_effects:
             effect_indices = torch.randperm(len(self.effects_to_keep))
         else:
             effect_indices = torch.arange(len(self.effects_to_keep))
+
         # Up to max_kept_effects
         if self.max_kept_effects != -1:
             num_kept_effects = int(torch.rand(1).item() * (self.max_kept_effects)) + 1
         else:
             num_kept_effects = len(self.effects_to_keep)
         effect_indices = effect_indices[:num_kept_effects]
+        print(effect_indices)
+
         # Index in effect settings
         effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
         # Apply
+        dry_labels = []
         for effect in effects_to_apply:
             dry = effect(dry)
-…
+            dry_labels.append(ALL_EFFECTS.index(type(effect)))

         # Apply effects_to_remove
         # Shuffle effects if specified
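The kept-effect selection is shuffle-then-truncate: `torch.randperm` yields a random ordering and the slice keeps the first `num_kept_effects` of it, so both membership and application order are randomized. A standalone illustration (effect names are placeholders):

```python
import torch

effects_to_keep = ["reverb", "chorus", "distortion"]
max_kept_effects = 3

perm = torch.randperm(len(effects_to_keep))           # random order, e.g. tensor([2, 0, 1])
n = int(torch.rand(1).item() * max_kept_effects) + 1  # uniform over {1, 2, 3}
chosen = [effects_to_keep[i] for i in perm[:n]]       # random ordered subset
print(chosen)
```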
@@ -185,9 +226,7 @@ class VocalSet(Dataset):
             effect_indices = torch.arange(len(self.effects_to_remove))
         # Up to max_removed_effects
         if self.max_removed_effects != -1:
-            num_kept_effects = (
-                int(torch.rand(1).item() * (self.max_removed_effects)) + 1
-            )
+            num_kept_effects = int(torch.rand(1).item() * (self.max_removed_effects))
         else:
             num_kept_effects = len(self.effects_to_remove)
         effect_indices = effect_indices[: self.max_removed_effects]
@@ -195,17 +234,27 @@ class VocalSet(Dataset):
         effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
         # Apply
+
+        wet_labels = []
         for effect in effects_to_apply:
             wet = effect(wet)
-…
+            wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+        for label_idx in wet_labels:
+            wet_labels_tensor[label_idx] = 1.0
+
+        for label_idx in dry_labels:
+            dry_labels_tensor[label_idx] = 1.0

-        # …
-
-        effects_present = torch.sum(one_hot, dim=0).float()
+        # effects_present = torch.sum(one_hot, dim=0).float()
+        print(dry_labels_tensor, wet_labels_tensor)
         # Normalize
         normalized_dry = self.normalize(dry)
         normalized_wet = self.normalize(wet)
-        return normalized_dry, normalized_wet, …
+        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor


 class VocalSetDatamodule(pl.LightningDataModule):
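`process_effects` now encodes the applied effects as multi-hot vectors over `ALL_EFFECTS`: index `i` is 1.0 exactly when effect class `i` was applied, and application order is not encoded. A minimal round-trip with toy stand-ins for the effect classes:

```python
import torch

class Reverb: ...
class Chorus: ...
class Distortion: ...

ALL_EFFECTS = [Reverb, Chorus, Distortion]  # toy list; the real one is effects.Pedalboard_Effects
applied = [Distortion(), Reverb()]

labels = torch.zeros(len(ALL_EFFECTS))
for fx in applied:
    labels[ALL_EFFECTS.index(type(fx))] = 1.0

print(labels)  # tensor([1., 0., 1.]) -- membership only, order is lost
```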
remfx/effects.py CHANGED

@@ -701,9 +701,9 @@ class RandomAudioEffectsChannel(torch.nn.Module):
 Pedalboard_Effects = [
     RandomPedalboardReverb,
     RandomPedalboardChorus,
-    RandomPedalboardDelay,
+    # RandomPedalboardDelay,
     RandomPedalboardDistortion,
     RandomPedalboardCompressor,
-    RandomPedalboardPhaser,
-    RandomPedalboardLimiter,
+    # RandomPedalboardPhaser,
+    # RandomPedalboardLimiter,
 ]
|