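"""Datasets for RemFX: locate source audio across several corpora, render
random effect chains onto fixed-size chunks, and serve (wet, dry) pairs
through a PyTorch Lightning datamodule."""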
import os
import sys
import glob
import torch
import shutil
import torchaudio
import pytorch_lightning as pl
import random
from tqdm import tqdm
from pathlib import Path
from remfx import effects
from typing import Any, List, Dict
from torch.utils.data import Dataset, DataLoader
from remfx.utils import select_random_chunk
# https://zenodo.org/record/1193957 -> VocalSet
ALL_EFFECTS = effects.Pedalboard_Effects
vocalset_splits = {
"train": [
"male1",
"male2",
"male3",
"male4",
"male5",
"male6",
"male7",
"male8",
"male9",
"female1",
"female2",
"female3",
"female4",
"female5",
"female6",
"female7",
],
"val": ["male10", "female8"],
"test": ["male11", "female9"],
}
guitarset_splits = {"train": ["00", "01", "02", "03"], "val": ["04"], "test": ["05"]}
idmt_guitar_splits = {
"train": ["classical", "country_folk", "jazz", "latin", "metal", "pop"],
"val": ["reggae", "ska"],
"test": ["rock", "blues"],
}
idmt_bass_splits = {
"train": ["BE", "BEQ"],
"val": ["VIF"],
"test": ["VIS"],
}
dsd_100_splits = {
"train": ["train"],
"val": ["val"],
"test": ["test"],
}
idmt_drums_splits = {
"train": ["WaveDrum02", "TechnoDrum01"],
"val": ["RealDrum01"],
"test": ["TechnoDrum02", "WaveDrum01"],
}


def locate_files(root: str, mode: str):
    """Return one sorted list of audio file paths per dataset found under
    ``root``, filtered to the requested train/val/test split."""
    file_list = []
# ------------------------- VocalSet -------------------------
vocalset_dir = os.path.join(root, "VocalSet1-2")
if os.path.isdir(vocalset_dir):
# find all singer directories
singer_dirs = glob.glob(os.path.join(vocalset_dir, "data_by_singer", "*"))
singer_dirs = [
sd for sd in singer_dirs if os.path.basename(sd) in vocalset_splits[mode]
]
files = []
for singer_dir in singer_dirs:
files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
print(f"Found {len(files)} files in VocalSet {mode}.")
file_list.append(sorted(files))
# ------------------------- GuitarSet -------------------------
guitarset_dir = os.path.join(root, "audio_mono-mic")
if os.path.isdir(guitarset_dir):
files = glob.glob(os.path.join(guitarset_dir, "*.wav"))
files = [
f
for f in files
if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
]
print(f"Found {len(files)} files in GuitarSet {mode}.")
file_list.append(sorted(files))
    # ------------------------- IDMT-SMT-GUITAR -------------------------
# idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
# if os.path.isdir(idmt_smt_guitar_dir):
# files = glob.glob(
# os.path.join(
# idmt_smt_guitar_dir, "IDMT-SMT-GUITAR_V2", "dataset4", "**", "*.wav"
# ),
# recursive=True,
# )
# files = [
# f
# for f in files
# if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
# ]
# file_list.append(sorted(files))
# print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
# ------------------------- IDMT-SMT-BASS -------------------------
# idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
# if os.path.isdir(idmt_smt_bass_dir):
# files = glob.glob(
# os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
# recursive=True,
# )
# files = [
# f
# for f in files
# if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
# ]
# file_list.append(sorted(files))
# print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
# ------------------------- DSD100 ---------------------------------
dsd_100_dir = os.path.join(root, "DSD100")
if os.path.isdir(dsd_100_dir):
files = glob.glob(
os.path.join(dsd_100_dir, mode, "**", "*.wav"),
recursive=True,
)
file_list.append(sorted(files))
print(f"Found {len(files)} files in DSD100 {mode}.")
# ------------------------- IDMT-SMT-DRUMS -------------------------
idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
if os.path.isdir(idmt_smt_drums_dir):
files = glob.glob(os.path.join(idmt_smt_drums_dir, "audio", "*.wav"))
files = [
f
for f in files
if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
]
file_list.append(sorted(files))
print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")
return file_list
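
# Usage sketch (the "./data" layout is an assumption; each sub-list in the
# result corresponds to one dataset that was found on disk):
#
#     train_files = locate_files("./data", "train")
#     val_files = locate_files("./data", "val")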


class EffectDataset(Dataset):
    """Renders and serves (wet, dry) audio chunk pairs.

    Each chunk directory under ``proc_root`` contains ``input.wav`` (wet),
    ``target.wav`` (dry), and multi-hot effect label tensors saved as
    ``dry_effects.pt`` and ``wet_effects.pt``.
    """

    def __init__(
self,
root: str,
sample_rate: int,
chunk_size: int = 262144,
total_chunks: int = 1000,
effect_modules: List[Dict[str, torch.nn.Module]] = None,
effects_to_keep: List[str] = None,
effects_to_remove: List[str] = None,
num_kept_effects: List[int] = [1, 5],
num_removed_effects: List[int] = [1, 5],
shuffle_kept_effects: bool = True,
shuffle_removed_effects: bool = False,
render_files: bool = True,
render_root: str = None,
mode: str = "train",
):
super().__init__()
self.chunks = []
self.song_idx = []
self.root = Path(root)
self.render_root = Path(render_root)
self.chunk_size = chunk_size
self.total_chunks = total_chunks
self.sample_rate = sample_rate
self.mode = mode
self.num_kept_effects = num_kept_effects
self.num_removed_effects = num_removed_effects
self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
self.effects = effect_modules
self.shuffle_kept_effects = shuffle_kept_effects
self.shuffle_removed_effects = shuffle_removed_effects
effects_string = "_".join(
self.effects_to_keep
+ ["_"]
+ self.effects_to_remove
+ ["_"]
+ [str(x) for x in num_kept_effects]
+ ["_"]
+ [str(x) for x in num_removed_effects]
)
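        # The string above encodes the effect configuration; e.g.
        # effects_to_keep=["compressor"], effects_to_remove=["distortion"],
        # num_kept_effects=[1, 5], num_removed_effects=[1, 5] gives
        # "compressor___distortion___1_5___1_5" (effect names illustrative).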
self.validate_effect_input()
self.proc_root = self.render_root / "processed" / effects_string / self.mode
self.files = locate_files(self.root, self.mode)
if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
print("Found processed files.")
if render_files:
                re_render = input(
                    "WARNING: Processed files found; re-rendering will delete "
                    "and replace them.\n"
                    "Set render_files=False to reuse the existing files.\n"
                    "Are you sure you want to re-render? (y/n): "
                )
if re_render != "y":
sys.exit()
shutil.rmtree(self.proc_root)
print("Total datasets:", len(self.files))
print("Processing files...")
if render_files:
# Split audio file into chunks, resample, then apply random effects
self.proc_root.mkdir(parents=True, exist_ok=True)
for num_chunk in tqdm(range(self.total_chunks)):
chunk = None
random_dataset_choice = random.choice(self.files)
                # Retry until select_random_chunk yields a chunk (it may
                # return None for a file it cannot use)
                while chunk is None:
random_file_choice = random.choice(random_dataset_choice)
chunk = select_random_chunk(
random_file_choice, self.chunk_size, self.sample_rate
)
# Sum to mono
if chunk.shape[0] > 1:
chunk = chunk.sum(0, keepdim=True)
dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
output_dir = self.proc_root / str(num_chunk)
output_dir.mkdir(exist_ok=True)
torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
torch.save(dry_effects, output_dir / "dry_effects.pt")
torch.save(wet_effects, output_dir / "wet_effects.pt")
print("Finished rendering")
else:
self.total_chunks = len(list(self.proc_root.iterdir()))
print("Total chunks:", self.total_chunks)

    def __len__(self):
        return self.total_chunks

    def __getitem__(self, idx):
        chunk_dir = self.proc_root / str(idx)
        dry_effects = torch.load(chunk_dir / "dry_effects.pt")
        wet_effects = torch.load(chunk_dir / "wet_effects.pt")
        # input.wav holds the wet (effected) signal, target.wav the dry one
        input_audio, _ = torchaudio.load(chunk_dir / "input.wav")
        target_audio, _ = torchaudio.load(chunk_dir / "target.wav")
        return (input_audio, target_audio, dry_effects, wet_effects)

    def validate_effect_input(self):
        """Check that every configured effect exists and print a summary of
        the kept/removed effect chains."""
        for effect in self.effects.values():
if type(effect) not in ALL_EFFECTS:
raise ValueError(
f"Effect {effect} not found in ALL_EFFECTS. "
f"Please choose from {ALL_EFFECTS}"
)
        for effect in self.effects_to_keep + self.effects_to_remove:
            if effect not in self.effects.keys():
                raise ValueError(
                    f"Effect {effect} not found in self.effects. "
                    f"Please choose from {self.effects.keys()}"
                )
kept_str = "randomly" if self.shuffle_kept_effects else "in order"
rem_str = "randomly" if self.shuffle_removed_effects else "in order"
if self.num_kept_effects[0] > self.num_kept_effects[1]:
raise ValueError(
f"num_kept_effects must be a tuple of (min, max). "
f"Got {self.num_kept_effects}"
)
if self.num_kept_effects[0] == self.num_kept_effects[1]:
num_kept_str = f"{self.num_kept_effects[0]}"
else:
num_kept_str = (
f"Between {self.num_kept_effects[0]}-{self.num_kept_effects[1]}"
)
if self.num_removed_effects[0] > self.num_removed_effects[1]:
raise ValueError(
f"num_removed_effects must be a tuple of (min, max). "
f"Got {self.num_removed_effects}"
)
if self.num_removed_effects[0] == self.num_removed_effects[1]:
num_rem_str = f"{self.num_removed_effects[0]}"
else:
num_rem_str = (
f"Between {self.num_removed_effects[0]}-{self.num_removed_effects[1]}"
)
rem_fx = self.effects_to_remove
kept_fx = self.effects_to_keep
print(
f"Effect Summary: \n"
f"Apply kept effects: {kept_fx} ({num_kept_str}, chosen {kept_str}) -> Dry\n"
f"Apply remove effects: {rem_fx} ({num_rem_str}, chosen {rem_str}) -> Wet\n"
)

    def process_effects(self, dry: torch.Tensor):
        """Apply the kept-effect chain to produce the dry signal, then the
        removed-effect chain on top of it to produce the wet signal. Returns
        loudness-normalized dry/wet audio and multi-hot effect label tensors."""
        # Apply kept effects, shuffled if requested
if self.shuffle_kept_effects:
effect_indices = torch.randperm(len(self.effects_to_keep))
else:
effect_indices = torch.arange(len(self.effects_to_keep))
        # Sample how many kept effects to apply, within [min, max]
        r1 = self.num_kept_effects[0]  # min
        r2 = self.num_kept_effects[1]  # max
        num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
effect_indices = effect_indices[:num_kept_effects]
        # Map the chosen indices to effect names, then to effect modules
        effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
# Apply
dry_labels = []
for effect in effects_to_apply:
# Normalize in-between effects
dry = self.normalize(effect(dry))
dry_labels.append(ALL_EFFECTS.index(type(effect)))
# Apply effects_to_remove
# Shuffle effects if specified
if self.shuffle_removed_effects:
effect_indices = torch.randperm(len(self.effects_to_remove))
else:
effect_indices = torch.arange(len(self.effects_to_remove))
wet = torch.clone(dry)
        # Sample how many removed effects to apply, within [min, max]
        r1 = self.num_removed_effects[0]  # min
        r2 = self.num_removed_effects[1]  # max
        num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
effect_indices = effect_indices[:num_removed_effects]
        # Map the chosen indices to effect names, then to effect modules
        effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
# Apply
wet_labels = []
for effect in effects_to_apply:
# Normalize in-between effects
wet = self.normalize(effect(wet))
wet_labels.append(ALL_EFFECTS.index(type(effect)))
        # Build multi-hot label vectors over ALL_EFFECTS
        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
for label_idx in wet_labels:
wet_labels_tensor[label_idx] = 1.0
for label_idx in dry_labels:
dry_labels_tensor[label_idx] = 1.0
# Normalize
normalized_dry = self.normalize(dry)
normalized_wet = self.normalize(wet)
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
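
# Construction sketch for EffectDataset (paths, sample rate, and the effect
# entry are assumptions for illustration; effect modules must be instances of
# classes listed in remfx.effects.Pedalboard_Effects):
#
#     dataset = EffectDataset(
#         root="./data",
#         sample_rate=48000,
#         effect_modules={"distortion": some_distortion_effect},
#         effects_to_keep=[],
#         effects_to_remove=["distortion"],
#         render_root="./render",
#         mode="train",
#     )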


class EffectDatamodule(pl.LightningDataModule):
    """LightningDataModule wrapping pre-built train/val/test EffectDatasets."""

    def __init__(
self,
train_dataset,
val_dataset,
test_dataset,
*,
batch_size: int,
num_workers: int,
pin_memory: bool = False,
        **kwargs: Any,
) -> None:
super().__init__()
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory

    def setup(self, stage: Any = None) -> None:
pass

    def train_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=True,
)

    def val_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)

    def test_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)
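
# Hypothetical wiring (batch size and worker count are illustrative; in the
# project these values normally come from configuration):
#
#     datamodule = EffectDatamodule(
#         train_dataset=train_dataset,
#         val_dataset=val_dataset,
#         test_dataset=test_dataset,
#         batch_size=16,
#         num_workers=4,
#     )
#     loader = datamodule.train_dataloader()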