# remfx/datasets.py
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F
from pathlib import Path
import pytorch_lightning as pl
from typing import Any, Dict, List, Optional, Tuple
from remfx import effects
from pedalboard import (
Pedalboard,
Chorus,
Reverb,
Compressor,
Phaser,
Delay,
Distortion,
Limiter,
)
from tqdm import tqdm
# https://zenodo.org/record/7044411/ -> GuitarFX
# https://zenodo.org/record/3371780 -> GuitarSet
# https://zenodo.org/record/1193957 -> VocalSet
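# Fixed Pedalboard chains (default parameters) used to render effects
# deterministically at eval time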
deterministic_effects = {
"Distortion": Pedalboard([Distortion()]),
"Compressor": Pedalboard([Compressor()]),
"Chorus": Pedalboard([Chorus()]),
"Phaser": Pedalboard([Phaser()]),
"Delay": Pedalboard([Delay()]),
"Reverb": Pedalboard([Reverb()]),
"Limiter": Pedalboard([Limiter()]),
}
class GuitarFXDataset(Dataset):
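    """Paired wet/dry guitar recordings (GuitarFX), split into sequential chunks."""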
def __init__(
self,
root: str,
sample_rate: int,
chunk_size_in_sec: int = 3,
        effect_types: Optional[List[str]] = None,
):
super().__init__()
self.wet_files = []
self.dry_files = []
self.chunks = []
self.labels = []
self.song_idx = []
self.root = Path(root)
self.chunk_size_in_sec = chunk_size_in_sec
self.sample_rate = sample_rate
        if effect_types is None:
            effect_types = [
                d.name for d in self.root.iterdir() if d.is_dir() and d.name != "Clean"
            ]
current_file = 0
for i, effect in enumerate(effect_types):
for pickup in Path(self.root / effect).iterdir():
wet_files = sorted(list(pickup.glob("*.wav")))
dry_files = sorted(
list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
)
self.wet_files += wet_files
self.dry_files += dry_files
self.labels += [i] * len(wet_files)
                for audio_file in wet_files:
                    chunks, orig_sr = create_sequential_chunks(
                        audio_file, self.chunk_size_in_sec
                    )
                    # create_sequential_chunks returns the chunk tensors; since
                    # chunks are sequential and non-overlapping, the k-th chunk
                    # starts at k * chunk_size_in_sec * orig_sr samples
                    chunk_starts = [
                        k * self.chunk_size_in_sec * orig_sr
                        for k in range(len(chunks))
                    ]
                    self.chunks += chunk_starts
                    self.song_idx += [current_file] * len(chunk_starts)
                    current_file += 1
print(
f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
f"Total chunks: {len(self.chunks)}"
)
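        # NOTE: assumes every file shares the sample rate of the last file
        # scanned above (orig_sr comes from the final loop iteration)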
self.resampler = T.Resample(orig_sr, sample_rate)
def __len__(self):
return len(self.chunks)
def __getitem__(self, idx):
# Load effected and "clean" audio
song_idx = self.song_idx[idx]
x, sr = torchaudio.load(self.wet_files[song_idx])
y, sr = torchaudio.load(self.dry_files[song_idx])
effect_label = self.labels[song_idx] # Effect label
chunk_start = self.chunks[idx]
chunk_size_in_samples = self.chunk_size_in_sec * sr
x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
y = y[:, chunk_start : chunk_start + chunk_size_in_samples]
resampled_x = self.resampler(x)
resampled_y = self.resampler(y)
# Reset chunk size to be new sample rate
chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
# Pad to chunk_size if needed
if resampled_x.shape[-1] < chunk_size_in_samples:
resampled_x = F.pad(
resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
)
if resampled_y.shape[-1] < chunk_size_in_samples:
resampled_y = F.pad(
resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
)
return (resampled_x, resampled_y, effect_label)
class GuitarSet(Dataset):
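    """Clean GuitarSet recordings chunked sequentially; effects are applied on the fly."""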
def __init__(
self,
root: str,
sample_rate: int,
chunk_size_in_sec: int = 3,
        effect_types: Optional[Dict[str, torch.nn.Module]] = None,
):
super().__init__()
self.chunks = []
self.song_idx = []
self.root = Path(root)
self.chunk_size_in_sec = chunk_size_in_sec
self.files = sorted(list(self.root.glob("./**/*.wav")))
self.sample_rate = sample_rate
        for i, audio_file in enumerate(self.files):
            chunks, orig_sr = create_sequential_chunks(
                audio_file, self.chunk_size_in_sec
            )
            # Recover start indices from the sequential, non-overlapping chunks
            chunk_starts = [
                k * self.chunk_size_in_sec * orig_sr for k in range(len(chunks))
            ]
            self.chunks += chunk_starts
            self.song_idx += [i] * len(chunk_starts)
print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
self.resampler = T.Resample(orig_sr, sample_rate)
self.effect_types = effect_types
self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
self.mode = "train"
def __len__(self):
return len(self.chunks)
def __getitem__(self, idx):
# Load and effect audio
song_idx = self.song_idx[idx]
x, sr = torchaudio.load(self.files[song_idx])
chunk_start = self.chunks[idx]
chunk_size_in_samples = self.chunk_size_in_sec * sr
x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
resampled_x = self.resampler(x)
# Reset chunk size to be new sample rate
chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
# Pad to chunk_size if needed
if resampled_x.shape[-1] < chunk_size_in_samples:
resampled_x = F.pad(
resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
)
        # Apply a randomly chosen effect when training
        if self.mode == "train":
            effect_idx = torch.randint(0, len(self.effect_types), (1,)).item()
            effect_name = list(self.effect_types.keys())[effect_idx]
            effect = self.effect_types[effect_name]
            effected_input = effect(resampled_x)
else:
            # Deterministic effect for eval, cycled by chunk index
            effect_idx = idx % len(self.effect_types)
effect_name = list(self.effect_types.keys())[effect_idx]
effect = deterministic_effects[effect_name]
effected_input = torch.from_numpy(
effect(resampled_x.numpy(), self.sample_rate)
)
normalized_input = self.normalize(effected_input)
normalized_target = self.normalize(resampled_x)
return (normalized_input, normalized_target, effect_name)
class VocalSet(Dataset):
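    """VocalSet recordings rendered at init into effected/clean chunk pairs on disk."""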
def __init__(
self,
root: str,
sample_rate: int,
chunk_size_in_sec: int = 3,
        effect_types: Optional[Dict[str, torch.nn.Module]] = None,
render_files: bool = True,
output_root: str = "processed",
mode: str = "train",
):
super().__init__()
self.chunks = []
self.song_idx = []
self.root = Path(root)
self.chunk_size_in_sec = chunk_size_in_sec
self.sample_rate = sample_rate
self.mode = mode
mode_path = self.root / self.mode
self.files = sorted(list(mode_path.glob("./**/*.wav")))
self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
self.effect_types = effect_types
        self.output_root = Path(output_root)
        output_mode_path = self.output_root / self.mode
        self.num_chunks = 0
        print("Total files:", len(self.files))
        print("Processing files...")
        if render_files:
            # Create the output directory tree in one call
            output_mode_path.mkdir(parents=True, exist_ok=True)
            for audio_file in tqdm(self.files):
chunks, orig_sr = create_sequential_chunks(
audio_file, self.chunk_size_in_sec
)
for chunk in chunks:
resampled_chunk = torchaudio.functional.resample(
chunk, orig_sr, sample_rate
)
chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
if resampled_chunk.shape[-1] < chunk_size_in_samples:
resampled_chunk = F.pad(
resampled_chunk,
(0, chunk_size_in_samples - resampled_chunk.shape[1]),
)
                    effect_idx = torch.randint(0, len(self.effect_types), (1,)).item()
                    effect_name = list(self.effect_types.keys())[effect_idx]
effect = self.effect_types[effect_name]
effected_input = effect(resampled_chunk)
normalized_input = self.normalize(effected_input)
normalized_target = self.normalize(resampled_chunk)
output_dir = output_mode_path / str(self.num_chunks)
output_dir.mkdir(exist_ok=True)
torchaudio.save(
output_dir / "input.wav", normalized_input, self.sample_rate
)
torchaudio.save(
output_dir / "target.wav", normalized_target, self.sample_rate
)
self.num_chunks += 1
        else:
            # Each rendered chunk lives in its own numbered directory
            # containing an input.wav/target.wav pair
            self.num_chunks = len(list(output_mode_path.iterdir()))
        print(
            f"Found {len(self.files)} {self.mode} files.\n"
            f"Total chunks: {self.num_chunks}"
        )
def __len__(self):
return self.num_chunks
def __getitem__(self, idx):
        # Load the pre-rendered effected/clean pair from the output directory
        input_file = self.output_root / self.mode / str(idx) / "input.wav"
        target_file = self.output_root / self.mode / str(idx) / "target.wav"
        x, sr = torchaudio.load(input_file)
        y, sr = torchaudio.load(target_file)
        return (x, y, "")
def create_random_chunks(
    audio_file: str, chunk_size: int, num_chunks: int
) -> Tuple[List[int], int]:
    """Sample num_chunks random chunk start positions of size chunk_size (seconds)
    from an audio file.
    Return the start sample index of each chunk and the original sample rate.
    """
    audio, sr = torchaudio.load(audio_file)
    chunk_size_in_samples = chunk_size * sr
    if chunk_size_in_samples >= audio.shape[-1]:
        chunk_size_in_samples = audio.shape[-1] - 1
    chunks = []
    for _ in range(num_chunks):
        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
        chunks.append(start)
    return chunks, sr
def create_sequential_chunks(
    audio_file: str, chunk_size: int
) -> Tuple[List[torch.Tensor], int]:
    """Create sequential, non-overlapping chunks of size chunk_size (seconds)
    from an audio file.
    Return the chunk tensors and the original sample rate.
    """
chunks = []
audio, sr = torchaudio.load(audio_file)
chunk_size_in_samples = chunk_size * sr
chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
for start in chunk_starts:
if start + chunk_size_in_samples > audio.shape[-1]:
break
chunks.append(audio[:, start : start + chunk_size_in_samples])
return chunks, sr
class Datamodule(pl.LightningDataModule):
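    """Wraps a single dataset and random-splits it into train/val dataloaders."""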
def __init__(
self,
dataset,
*,
val_split: float,
batch_size: int,
num_workers: int,
pin_memory: bool = False,
**kwargs: int,
) -> None:
super().__init__()
self.dataset = dataset
self.val_split = val_split
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory
self.data_train: Any = None
self.data_val: Any = None
def setup(self, stage: Any = None) -> None:
        train_size = round((1.0 - self.val_split) * len(self.dataset))
        # Take the remainder so the split sizes always sum to len(dataset)
        val_size = len(self.dataset) - train_size
        self.data_train, self.data_val = random_split(
            self.dataset, [train_size, val_size]
        )
        # NOTE: both subsets share the same underlying dataset object, so
        # setting mode here also changes how training items are served
        self.data_val.dataset.mode = "val"
def train_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.data_train,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=True,
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.data_val,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)
class VocalSetDatamodule(pl.LightningDataModule):
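    """Builds dataloaders from datasets that are already split into train/val."""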
def __init__(
self,
train_dataset,
val_dataset,
*,
batch_size: int,
num_workers: int,
pin_memory: bool = False,
**kwargs: int,
) -> None:
super().__init__()
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory
def setup(self, stage: Any = None) -> None:
pass
def train_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=True,
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)
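# Minimal usage sketch (not part of the original module). The dataset root,
# output directory, and the Identity stand-in effect below are illustrative
# assumptions; real runs would pass effect callables from remfx.effects.
if __name__ == "__main__":
    dummy_effects = {"Identity": torch.nn.Identity()}  # hypothetical stand-in effect
    train_dataset = VocalSet(
        root="./data/VocalSet",  # assumed layout: <root>/<mode>/**/*.wav
        sample_rate=48000,
        chunk_size_in_sec=3,
        effect_types=dummy_effects,
        render_files=True,  # render effected chunks to disk once, up front
        output_root="./processed",
        mode="train",
    )
    val_dataset = VocalSet(
        root="./data/VocalSet",
        sample_rate=48000,
        chunk_size_in_sec=3,
        effect_types=dummy_effects,
        render_files=True,
        output_root="./processed",
        mode="val",
    )
    datamodule = VocalSetDatamodule(
        train_dataset, val_dataset, batch_size=16, num_workers=4
    )
    x, y, label = train_dataset[0]
    print(x.shape, y.shape, label)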