Christian J. Steinmetz committed
Commit 279d167 · 1 Parent(s): 1b6bb59

adding support for parallel processing in dataset generation

Browse files: remfx/datasets.py (+144 -22)
remfx/datasets.py
CHANGED
@@ -8,15 +8,16 @@ import pytorch_lightning as pl
 import random
 from tqdm import tqdm
 from pathlib import Path
-from remfx import effects
+from remfx import effects as effect_lib
 from typing import Any, List, Dict
 from torch.utils.data import Dataset, DataLoader
 from remfx.utils import select_random_chunk
+import multiprocessing


 # https://zenodo.org/record/1193957 -> VocalSet

-ALL_EFFECTS = effects.Pedalboard_Effects
+ALL_EFFECTS = effect_lib.Pedalboard_Effects
 # print(ALL_EFFECTS)

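A note on the import change: the module previously imported as `effects` is rebound to `effect_lib`, presumably because the new worker function below takes a parameter named `effects`, which would shadow the module inside the function body. A minimal sketch of the hazard, using `math` as a stand-in module (none of these names are from the commit):

    import math as math_lib  # stand-in for `from remfx import effects as effect_lib`

    def mean(math: list) -> float:
        # the parameter `math` shadows any module imported as plain `math`;
        # the alias keeps the module reachable inside the function body
        return math_lib.fsum(math) / len(math)

    print(mean([1.0, 2.0, 3.0]))  # 2.0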
@@ -146,6 +147,101 @@ def locate_files(root: str, mode: str):
     return file_list


+def parallel_process_effects(
+    chunk_idx: int,
+    proc_root: str,
+    files: list,
+    chunk_size: int,
+    effects: list,
+    effects_to_keep: list,
+    num_kept_effects: tuple,
+    shuffle_kept_effects: bool,
+    effects_to_remove: list,
+    num_removed_effects: tuple,
+    shuffle_removed_effects: bool,
+    sample_rate: int,
+    target_lufs_db: float,
+):
+    chunk = None
+    random_dataset_choice = random.choice(files)
+    while chunk is None:
+        random_file_choice = random.choice(random_dataset_choice)
+        chunk = select_random_chunk(random_file_choice, chunk_size, sample_rate)
+
+    # Sum to mono
+    if chunk.shape[0] > 1:
+        chunk = chunk.sum(0, keepdim=True)
+
+    dry = chunk
+
+    # loudness normalization
+    normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db)
+
+    # Apply Kept Effects
+    # Shuffle effects if specified
+    if shuffle_kept_effects:
+        effect_indices = torch.randperm(len(effects_to_keep))
+    else:
+        effect_indices = torch.arange(len(effects_to_keep))
+
+    r1 = num_kept_effects[0]
+    r2 = num_kept_effects[1]
+    num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+    effect_indices = effect_indices[:num_kept_effects]
+    # Index in effect settings
+    effect_names_to_apply = [effects_to_keep[i] for i in effect_indices]
+    effects_to_apply = [effects[i] for i in effect_names_to_apply]
+    # Apply
+    dry_labels = []
+    for effect in effects_to_apply:
+        # Normalize in-between effects
+        dry = normalize(effect(dry))
+        dry_labels.append(ALL_EFFECTS.index(type(effect)))
+
+    # Apply effects_to_remove
+    # Shuffle effects if specified
+    if shuffle_removed_effects:
+        effect_indices = torch.randperm(len(effects_to_remove))
+    else:
+        effect_indices = torch.arange(len(effects_to_remove))
+    wet = torch.clone(dry)
+    r1 = num_removed_effects[0]
+    r2 = num_removed_effects[1]
+    num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+    effect_indices = effect_indices[:num_removed_effects]
+    # Index in effect settings
+    effect_names_to_apply = [effects_to_remove[i] for i in effect_indices]
+    effects_to_apply = [effects[i] for i in effect_names_to_apply]
+    # Apply
+    wet_labels = []
+    for effect in effects_to_apply:
+        # Normalize in-between effects
+        wet = normalize(effect(wet))
+        wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+    wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+    dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+    for label_idx in wet_labels:
+        wet_labels_tensor[label_idx] = 1.0
+
+    for label_idx in dry_labels:
+        dry_labels_tensor[label_idx] = 1.0
+
+    # Normalize
+    normalized_dry = normalize(dry)
+    normalized_wet = normalize(wet)
+
+    output_dir = proc_root / str(chunk_idx)
+    output_dir.mkdir(exist_ok=True)
+    torchaudio.save(output_dir / "input.wav", normalized_wet, sample_rate)
+    torchaudio.save(output_dir / "target.wav", normalized_dry, sample_rate)
+    torch.save(dry_labels_tensor, output_dir / "dry_effects.pt")
+    torch.save(wet_labels_tensor, output_dir / "wet_effects.pt")
+
+    # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
+
+
 class EffectDataset(Dataset):
     def __init__(
         self,
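Two details of the new helper are worth spelling out. First, the number of effects to apply is drawn with torch.round((r1 - r2) * torch.rand(1) + r2), which can produce every integer between the two bounds, but after rounding the endpoint counts receive only half the probability mass of the interior counts; torch.randint would be uniform over the range if that were the intent. Second, the labels written to disk are multi-hot vectors indexed by position in ALL_EFFECTS. A standalone sketch of both behaviours (the bounds and effect list here are illustrative, not from the commit):

    import torch

    # draw an effect count the way the helper does, with bounds (r1, r2) = (0, 5)
    r1, r2 = 0, 5
    counts = torch.round((r1 - r2) * torch.rand(100_000) + r2).long()
    print(torch.bincount(counts))  # endpoints 0 and 5 appear roughly half as often

    # a uniform alternative over the same inclusive range
    print(torch.bincount(torch.randint(r1, r2 + 1, (100_000,))))

    # multi-hot encoding over a hypothetical effect list
    all_effects = ["distortion", "reverb", "chorus", "delay"]
    applied = [0, 2]  # indices of the applied effects
    labels = torch.zeros(len(all_effects))
    labels[applied] = 1.0
    print(labels)  # tensor([1., 0., 1., 0.])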
@@ -163,6 +259,7 @@ class EffectDataset(Dataset):
         render_files: bool = True,
         render_root: str = None,
         mode: str = "train",
+        parallel: bool = True,
     ):
         super().__init__()
         self.chunks = []
@@ -177,7 +274,7 @@
         self.num_removed_effects = num_removed_effects
         self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
         self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
-        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+        self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
         self.effects = effect_modules
         self.shuffle_kept_effects = shuffle_kept_effects
         self.shuffle_removed_effects = shuffle_removed_effects
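Both rendering paths normalize to -20 LUFS: the dataset keeps a LoudnessNormalize module for the serial path, while each parallel worker constructs its own from the target_lufs_db argument instead of sharing the dataset's instance across processes. For intuition, loudness normalization of this kind measures integrated loudness and applies a constant gain to hit the target; a rough standalone equivalent using the pyloudnorm package (an assumption for illustration — remfx's implementation is not shown in this diff):

    import numpy as np
    import pyloudnorm as pyln

    sample_rate = 48000
    audio = np.random.uniform(-0.1, 0.1, size=sample_rate * 3)  # 3 s of noise

    meter = pyln.Meter(sample_rate)              # ITU-R BS.1770 meter
    loudness = meter.integrated_loudness(audio)  # measured loudness in LUFS
    normalized = pyln.normalize.loudness(audio, loudness, -20.0)  # gain to -20 LUFS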
@@ -192,6 +289,7 @@
         )
         self.validate_effect_input()
         self.proc_root = self.render_root / "processed" / effects_string / self.mode
+        self.parallel = parallel

         self.files = locate_files(self.root, self.mode)

@@ -212,26 +310,50 @@
         if render_files:
             # Split audio file into chunks, resample, then apply random effects
             self.proc_root.mkdir(parents=True, exist_ok=True)
-            for num_chunk in tqdm(range(self.total_chunks)):
-                chunk = None
-                random_dataset_choice = random.choice(self.files)
-                while chunk is None:
-                    random_file_choice = random.choice(random_dataset_choice)
-                    chunk = select_random_chunk(
-                        random_file_choice, self.chunk_size, self.sample_rate
-                    )

-                # Sum to mono
-                if chunk.shape[0] > 1:
-                    chunk = chunk.sum(0, keepdim=True)
-
-                dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
-                output_dir = self.proc_root / str(num_chunk)
-                output_dir.mkdir(exist_ok=True)
-                torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
-                torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
-                torch.save(dry_effects, output_dir / "dry_effects.pt")
-                torch.save(wet_effects, output_dir / "wet_effects.pt")
+            if self.parallel:
+                items = [
+                    (
+                        chunk_idx,
+                        self.proc_root,
+                        self.files,
+                        self.chunk_size,
+                        self.effects,
+                        self.effects_to_keep,
+                        self.num_kept_effects,
+                        self.shuffle_kept_effects,
+                        self.effects_to_remove,
+                        self.num_removed_effects,
+                        self.shuffle_removed_effects,
+                        self.sample_rate,
+                        -20.0,
+                    )
+                    for chunk_idx in range(self.total_chunks)
+                ]
+                with multiprocessing.Pool(processes=32) as pool:
+                    pool.starmap(parallel_process_effects, items)
+                print(f"Done processing {self.total_chunks}", flush=True)
+            else:
+                for num_chunk in tqdm(range(self.total_chunks)):
+                    chunk = None
+                    random_dataset_choice = random.choice(self.files)
+                    while chunk is None:
+                        random_file_choice = random.choice(random_dataset_choice)
+                        chunk = select_random_chunk(
+                            random_file_choice, self.chunk_size, self.sample_rate
+                        )
+
+                    # Sum to mono
+                    if chunk.shape[0] > 1:
+                        chunk = chunk.sum(0, keepdim=True)
+
+                    dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
+                    output_dir = self.proc_root / str(num_chunk)
+                    output_dir.mkdir(exist_ok=True)
+                    torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
+                    torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
+                    torch.save(dry_effects, output_dir / "dry_effects.pt")
+                    torch.save(wet_effects, output_dir / "wet_effects.pt")

             print("Finished rendering")
         else:
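Finally, the shape of the parallel path: parallel_process_effects is a module-level function so the pool can pickle a reference to it by name without dragging the whole dataset object into each worker, and pool.starmap unpacks each argument tuple into the function's positional parameters. Note that this path hardcodes processes=32 and gives up the tqdm progress bar of the serial path. A minimal sketch of the same pattern with a placeholder worker (the names and the cpu_count cap are illustrative, not from the commit):

    import multiprocessing
    import os

    def render_chunk(chunk_idx: int, out_root: str) -> None:
        # placeholder for parallel_process_effects; it must be importable
        # at module level so pool workers can unpickle a reference to it
        print(f"chunk {chunk_idx} -> {out_root}")

    if __name__ == "__main__":
        items = [(i, "processed") for i in range(8)]
        # capping at cpu_count avoids oversubscribing machines with
        # fewer cores than the 32 assumed by the commit
        with multiprocessing.Pool(processes=min(32, os.cpu_count() or 1)) as pool:
            pool.starmap(render_chunk, items)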