Spaces:
Sleeping
Sleeping
Christian J. Steinmetz
commited on
Commit
·
507048e
1
Parent(s):
29f23c3
adding a dynamic dataset for on the fly generation of training data
Browse files- remfx/datasets.py +137 -6
remfx/datasets.py
CHANGED
@@ -162,6 +162,7 @@ def parallel_process_effects(
|
|
162 |
sample_rate: int,
|
163 |
target_lufs_db: float,
|
164 |
):
|
|
|
165 |
chunk = None
|
166 |
random_dataset_choice = random.choice(files)
|
167 |
while chunk is None:
|
@@ -242,6 +243,134 @@ def parallel_process_effects(
|
|
242 |
# return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
243 |
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
class EffectDataset(Dataset):
|
246 |
def __init__(
|
247 |
self,
|
@@ -259,7 +388,7 @@ class EffectDataset(Dataset):
|
|
259 |
render_files: bool = True,
|
260 |
render_root: str = None,
|
261 |
mode: str = "train",
|
262 |
-
parallel: bool =
|
263 |
):
|
264 |
super().__init__()
|
265 |
self.chunks = []
|
@@ -524,7 +653,8 @@ class EffectDatamodule(pl.LightningDataModule):
|
|
524 |
val_dataset,
|
525 |
test_dataset,
|
526 |
*,
|
527 |
-
|
|
|
528 |
num_workers: int,
|
529 |
pin_memory: bool = False,
|
530 |
**kwargs: int,
|
@@ -533,7 +663,8 @@ class EffectDatamodule(pl.LightningDataModule):
|
|
533 |
self.train_dataset = train_dataset
|
534 |
self.val_dataset = val_dataset
|
535 |
self.test_dataset = test_dataset
|
536 |
-
self.
|
|
|
537 |
self.num_workers = num_workers
|
538 |
self.pin_memory = pin_memory
|
539 |
|
@@ -543,7 +674,7 @@ class EffectDatamodule(pl.LightningDataModule):
|
|
543 |
def train_dataloader(self) -> DataLoader:
|
544 |
return DataLoader(
|
545 |
dataset=self.train_dataset,
|
546 |
-
batch_size=self.
|
547 |
num_workers=self.num_workers,
|
548 |
pin_memory=self.pin_memory,
|
549 |
shuffle=True,
|
@@ -552,7 +683,7 @@ class EffectDatamodule(pl.LightningDataModule):
|
|
552 |
def val_dataloader(self) -> DataLoader:
|
553 |
return DataLoader(
|
554 |
dataset=self.val_dataset,
|
555 |
-
batch_size=self.
|
556 |
num_workers=self.num_workers,
|
557 |
pin_memory=self.pin_memory,
|
558 |
shuffle=False,
|
@@ -561,7 +692,7 @@ class EffectDatamodule(pl.LightningDataModule):
|
|
561 |
def test_dataloader(self) -> DataLoader:
|
562 |
return DataLoader(
|
563 |
dataset=self.test_dataset,
|
564 |
-
batch_size=
|
565 |
num_workers=self.num_workers,
|
566 |
pin_memory=self.pin_memory,
|
567 |
shuffle=False,
|
|
|
162 |
sample_rate: int,
|
163 |
target_lufs_db: float,
|
164 |
):
|
165 |
+
"""Note: This function has an issue with random seed. It may not fully randomize the effects."""
|
166 |
chunk = None
|
167 |
random_dataset_choice = random.choice(files)
|
168 |
while chunk is None:
|
|
|
243 |
# return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
244 |
|
245 |
|
246 |
+
class DynamicEffectDataset(Dataset):
|
247 |
+
def __init__(
|
248 |
+
self,
|
249 |
+
root: str,
|
250 |
+
sample_rate: int,
|
251 |
+
chunk_size: int = 262144,
|
252 |
+
total_chunks: int = 1000,
|
253 |
+
effect_modules: List[Dict[str, torch.nn.Module]] = None,
|
254 |
+
effects_to_keep: List[str] = None,
|
255 |
+
effects_to_remove: List[str] = None,
|
256 |
+
num_kept_effects: List[int] = [1, 5],
|
257 |
+
num_removed_effects: List[int] = [1, 5],
|
258 |
+
shuffle_kept_effects: bool = True,
|
259 |
+
shuffle_removed_effects: bool = False,
|
260 |
+
render_files: bool = True,
|
261 |
+
render_root: str = None,
|
262 |
+
mode: str = "train",
|
263 |
+
parallel: bool = False,
|
264 |
+
) -> None:
|
265 |
+
super().__init__()
|
266 |
+
self.chunks = []
|
267 |
+
self.song_idx = []
|
268 |
+
self.root = Path(root)
|
269 |
+
self.render_root = Path(render_root)
|
270 |
+
self.chunk_size = chunk_size
|
271 |
+
self.total_chunks = total_chunks
|
272 |
+
self.sample_rate = sample_rate
|
273 |
+
self.mode = mode
|
274 |
+
self.num_kept_effects = num_kept_effects
|
275 |
+
self.num_removed_effects = num_removed_effects
|
276 |
+
self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
|
277 |
+
self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
|
278 |
+
self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
279 |
+
self.effects = effect_modules
|
280 |
+
self.shuffle_kept_effects = shuffle_kept_effects
|
281 |
+
self.shuffle_removed_effects = shuffle_removed_effects
|
282 |
+
effects_string = "_".join(
|
283 |
+
self.effects_to_keep
|
284 |
+
+ ["_"]
|
285 |
+
+ self.effects_to_remove
|
286 |
+
+ ["_"]
|
287 |
+
+ [str(x) for x in num_kept_effects]
|
288 |
+
+ ["_"]
|
289 |
+
+ [str(x) for x in num_removed_effects]
|
290 |
+
)
|
291 |
+
# self.validate_effect_input()
|
292 |
+
# self.proc_root = self.render_root / "processed" / effects_string / self.mode
|
293 |
+
self.parallel = parallel
|
294 |
+
self.files = locate_files(self.root, self.mode)
|
295 |
+
|
296 |
+
def process_effects(self, dry: torch.Tensor):
|
297 |
+
# Apply Kept Effects
|
298 |
+
# Shuffle effects if specified
|
299 |
+
if self.shuffle_kept_effects:
|
300 |
+
effect_indices = torch.randperm(len(self.effects_to_keep))
|
301 |
+
else:
|
302 |
+
effect_indices = torch.arange(len(self.effects_to_keep))
|
303 |
+
|
304 |
+
r1 = self.num_kept_effects[0]
|
305 |
+
r2 = self.num_kept_effects[1]
|
306 |
+
num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
|
307 |
+
effect_indices = effect_indices[:num_kept_effects]
|
308 |
+
# Index in effect settings
|
309 |
+
effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
|
310 |
+
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
|
311 |
+
# Apply
|
312 |
+
dry_labels = []
|
313 |
+
for effect in effects_to_apply:
|
314 |
+
# Normalize in-between effects
|
315 |
+
dry = self.normalize(effect(dry))
|
316 |
+
dry_labels.append(ALL_EFFECTS.index(type(effect)))
|
317 |
+
|
318 |
+
# Apply effects_to_remove
|
319 |
+
# Shuffle effects if specified
|
320 |
+
if self.shuffle_removed_effects:
|
321 |
+
effect_indices = torch.randperm(len(self.effects_to_remove))
|
322 |
+
else:
|
323 |
+
effect_indices = torch.arange(len(self.effects_to_remove))
|
324 |
+
wet = torch.clone(dry)
|
325 |
+
r1 = self.num_removed_effects[0]
|
326 |
+
r2 = self.num_removed_effects[1]
|
327 |
+
num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
|
328 |
+
effect_indices = effect_indices[:num_removed_effects]
|
329 |
+
# Index in effect settings
|
330 |
+
effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
|
331 |
+
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
|
332 |
+
# Apply
|
333 |
+
wet_labels = []
|
334 |
+
for effect in effects_to_apply:
|
335 |
+
# Normalize in-between effects
|
336 |
+
wet = self.normalize(effect(wet))
|
337 |
+
wet_labels.append(ALL_EFFECTS.index(type(effect)))
|
338 |
+
|
339 |
+
wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
|
340 |
+
dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
|
341 |
+
|
342 |
+
for label_idx in wet_labels:
|
343 |
+
wet_labels_tensor[label_idx] = 1.0
|
344 |
+
|
345 |
+
for label_idx in dry_labels:
|
346 |
+
dry_labels_tensor[label_idx] = 1.0
|
347 |
+
|
348 |
+
# Normalize
|
349 |
+
normalized_dry = self.normalize(dry)
|
350 |
+
normalized_wet = self.normalize(wet)
|
351 |
+
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
352 |
+
|
353 |
+
def __len__(self):
|
354 |
+
return self.total_chunks
|
355 |
+
|
356 |
+
def __getitem__(self, _: int):
|
357 |
+
chunk = None
|
358 |
+
random_dataset_choice = random.choice(self.files)
|
359 |
+
while chunk is None:
|
360 |
+
random_file_choice = random.choice(random_dataset_choice)
|
361 |
+
chunk = select_random_chunk(
|
362 |
+
random_file_choice, self.chunk_size, self.sample_rate
|
363 |
+
)
|
364 |
+
|
365 |
+
# Sum to mono
|
366 |
+
if chunk.shape[0] > 1:
|
367 |
+
chunk = chunk.sum(0, keepdim=True)
|
368 |
+
|
369 |
+
dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
|
370 |
+
|
371 |
+
return wet, dry, dry_effects, wet_effects
|
372 |
+
|
373 |
+
|
374 |
class EffectDataset(Dataset):
|
375 |
def __init__(
|
376 |
self,
|
|
|
388 |
render_files: bool = True,
|
389 |
render_root: str = None,
|
390 |
mode: str = "train",
|
391 |
+
parallel: bool = False,
|
392 |
):
|
393 |
super().__init__()
|
394 |
self.chunks = []
|
|
|
653 |
val_dataset,
|
654 |
test_dataset,
|
655 |
*,
|
656 |
+
train_batch_size: int,
|
657 |
+
test_batch_size: int,
|
658 |
num_workers: int,
|
659 |
pin_memory: bool = False,
|
660 |
**kwargs: int,
|
|
|
663 |
self.train_dataset = train_dataset
|
664 |
self.val_dataset = val_dataset
|
665 |
self.test_dataset = test_dataset
|
666 |
+
self.train_batch_size = train_batch_size
|
667 |
+
self.test_batch_size = test_batch_size
|
668 |
self.num_workers = num_workers
|
669 |
self.pin_memory = pin_memory
|
670 |
|
|
|
674 |
def train_dataloader(self) -> DataLoader:
|
675 |
return DataLoader(
|
676 |
dataset=self.train_dataset,
|
677 |
+
batch_size=self.train_batch_size,
|
678 |
num_workers=self.num_workers,
|
679 |
pin_memory=self.pin_memory,
|
680 |
shuffle=True,
|
|
|
683 |
def val_dataloader(self) -> DataLoader:
|
684 |
return DataLoader(
|
685 |
dataset=self.val_dataset,
|
686 |
+
batch_size=self.train_batch_size,
|
687 |
num_workers=self.num_workers,
|
688 |
pin_memory=self.pin_memory,
|
689 |
shuffle=False,
|
|
|
692 |
def test_dataloader(self) -> DataLoader:
|
693 |
return DataLoader(
|
694 |
dataset=self.test_dataset,
|
695 |
+
batch_size=self.test_batch_size, # Use small, consistent batch size for testing
|
696 |
num_workers=self.num_workers,
|
697 |
pin_memory=self.pin_memory,
|
698 |
shuffle=False,
|