import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F
from pathlib import Path
import pytorch_lightning as pl
from typing import Any, Dict, List, Optional, Tuple
from remfx import effects
from pedalboard import (
    Pedalboard,
    Chorus,
    Reverb,
    Compressor,
    Phaser,
    Delay,
    Distortion,
    Limiter,
)
from tqdm import tqdm
# https://zenodo.org/record/7044411/ -> GuitarFX
# https://zenodo.org/record/3371780 -> GuitarSet
# https://zenodo.org/record/1193957 -> VocalSet
deterministic_effects = {
    "Distortion": Pedalboard([Distortion()]),
    "Compressor": Pedalboard([Compressor()]),
    "Chorus": Pedalboard([Chorus()]),
    "Phaser": Pedalboard([Phaser()]),
    "Delay": Pedalboard([Delay()]),
    "Reverb": Pedalboard([Reverb()]),
    "Limiter": Pedalboard([Limiter()]),
}
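# Each Pedalboard above wraps a single effect with pedalboard's default
# parameters. A board is applied by calling it on a float numpy array plus a
# sample rate, e.g. (sketch; `audio` and 48000 are illustrative):
#   wet = deterministic_effects["Reverb"](audio.numpy(), 48000)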
class GuitarFXDataset(Dataset):
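    """Paired wet/dry guitar clips from the GuitarFX dataset (Zenodo link above),
    split into fixed-length chunks. Expects one sub-directory per effect plus a
    "Clean" directory that mirrors the per-pickup folder layout.
    """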
    def __init__(
        self,
        root: str,
        sample_rate: int,
        chunk_size_in_sec: int = 3,
        effect_types: Optional[List[str]] = None,
    ):
        super().__init__()
        self.wet_files = []
        self.dry_files = []
        self.chunks = []
        self.labels = []
        self.song_idx = []
        self.root = Path(root)
        self.chunk_size_in_sec = chunk_size_in_sec
        self.sample_rate = sample_rate
        if effect_types is None:
            effect_types = [
                d.name for d in self.root.iterdir() if d.is_dir() and d.name != "Clean"
            ]
        current_file = 0
        for i, effect in enumerate(effect_types):
            for pickup in Path(self.root / effect).iterdir():
                wet_files = sorted(list(pickup.glob("*.wav")))
                dry_files = sorted(
                    list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
                )
                self.wet_files += wet_files
                self.dry_files += dry_files
                self.labels += [i] * len(wet_files)
                for audio_file in wet_files:
                    chunk_starts, orig_sr = create_sequential_chunks(
                        audio_file, self.chunk_size_in_sec
                    )
                    self.chunks += chunk_starts
                    self.song_idx += [current_file] * len(chunk_starts)
                    current_file += 1
        print(
            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
            f"Total chunks: {len(self.chunks)}"
        )
        # Single resampler built from the last file's sample rate; assumes all
        # files in the dataset share the same original sample rate.
        self.resampler = T.Resample(orig_sr, sample_rate)
    def __len__(self):
        return len(self.chunks)
    def __getitem__(self, idx):
        # Load effected and "clean" audio
        song_idx = self.song_idx[idx]
        x, sr = torchaudio.load(self.wet_files[song_idx])
        y, sr = torchaudio.load(self.dry_files[song_idx])
        effect_label = self.labels[song_idx]  # Effect label
        chunk_start = self.chunks[idx]
        chunk_size_in_samples = self.chunk_size_in_sec * sr
        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
        y = y[:, chunk_start : chunk_start + chunk_size_in_samples]
        resampled_x = self.resampler(x)
        resampled_y = self.resampler(y)
        # Reset chunk size to be new sample rate
        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
        # Pad to chunk_size if needed
        if resampled_x.shape[-1] < chunk_size_in_samples:
            resampled_x = F.pad(
                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
            )
        if resampled_y.shape[-1] < chunk_size_in_samples:
            resampled_y = F.pad(
                resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
            )
        return (resampled_x, resampled_y, effect_label)
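# Usage sketch (the root path is illustrative):
#   dataset = GuitarFXDataset(root="data/GuitarFX", sample_rate=48000)
#   wet, dry, effect_label = dataset[0]  # chunked wet/dry tensors + effect index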
class GuitarSet(Dataset):
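    """GuitarSet recordings split into fixed-length chunks. Effects are applied
    on the fly: a random effect from `effect_types` in "train" mode, otherwise a
    deterministic pedalboard effect chosen by chunk index.
    """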
    def __init__(
        self,
        root: str,
        sample_rate: int,
        chunk_size_in_sec: int = 3,
        effect_types: Optional[Dict[str, torch.nn.Module]] = None,
    ):
        super().__init__()
        self.chunks = []
        self.song_idx = []
        self.root = Path(root)
        self.chunk_size_in_sec = chunk_size_in_sec
        self.files = sorted(list(self.root.glob("./**/*.wav")))
        self.sample_rate = sample_rate
        for i, audio_file in enumerate(self.files):
            chunk_starts, orig_sr = create_sequential_chunks(
                audio_file, self.chunk_size_in_sec
            )
            self.chunks += chunk_starts
            self.song_idx += [i] * len(chunk_starts)
        print(f"Found {len(self.files)} files.\n" f"Total chunks: {len(self.chunks)}")
        self.resampler = T.Resample(orig_sr, sample_rate)
        self.effect_types = effect_types
        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
        self.mode = "train"
    def __len__(self):
        return len(self.chunks)
    def __getitem__(self, idx):
        # Load and effect audio
        song_idx = self.song_idx[idx]
        x, sr = torchaudio.load(self.files[song_idx])
        chunk_start = self.chunks[idx]
        chunk_size_in_samples = self.chunk_size_in_sec * sr
        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
        resampled_x = self.resampler(x)
        # Reset chunk size to be new sample rate
        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
        # Pad to chunk_size if needed
        if resampled_x.shape[-1] < chunk_size_in_samples:
            resampled_x = F.pad(
                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
            )
        # Add random effect if train
        if self.mode == "train":
            random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
            effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
            effect = self.effect_types[effect_name]
            effected_input = effect(resampled_x)
        else:
            # deterministic static effect for eval
            effect_idx = idx % len(self.effect_types.keys())
            effect_name = list(self.effect_types.keys())[effect_idx]
            effect = deterministic_effects[effect_name]
            effected_input = torch.from_numpy(
                effect(resampled_x.numpy(), self.sample_rate)
            )
        normalized_input = self.normalize(effected_input)
        normalized_target = self.normalize(resampled_x)
        return (normalized_input, normalized_target, effect_name)
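# Usage sketch (paths and `my_reverb_module` are illustrative; `effect_types`
# maps effect names to callables, and in eval mode each name must also be a key
# of `deterministic_effects` above):
#   dataset = GuitarSet(root="data/GuitarSet", sample_rate=48000,
#                       effect_types={"Reverb": my_reverb_module})
#   wet, dry, effect_name = dataset[0]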
class VocalSet(Dataset):
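    """VocalSet recordings, chunked and rendered to disk. With `render_files`
    enabled, each chunk is resampled, passed through a random effect from
    `effect_types`, loudness-normalized, and saved as an input/target wav pair
    under `output_root`/<mode>/<chunk index>/.
    """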
    def __init__(
        self,
        root: str,
        sample_rate: int,
        chunk_size_in_sec: int = 3,
        effect_types: Optional[Dict[str, torch.nn.Module]] = None,
        render_files: bool = True,
        output_root: str = "processed",
        mode: str = "train",
    ):
        super().__init__()
        self.chunks = []
        self.song_idx = []
        self.root = Path(root)
        self.chunk_size_in_sec = chunk_size_in_sec
        self.sample_rate = sample_rate
        self.mode = mode
        mode_path = self.root / self.mode
        self.files = sorted(list(mode_path.glob("./**/*.wav")))
        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
        self.effect_types = effect_types
        self.output_root = Path(output_root)
        output_mode_path = self.output_root / self.mode
        self.num_chunks = 0
        print("Total files:", len(self.files))
        print("Processing files...")
        if render_files:
            if not self.output_root.exists():
                self.output_root.mkdir()
            if not output_mode_path.exists():
                output_mode_path.mkdir()
            for i, audio_file in tqdm(enumerate(self.files)):
                chunk_starts, orig_sr = create_sequential_chunks(
                    audio_file, self.chunk_size_in_sec
                )
                audio, _ = torchaudio.load(audio_file)
                orig_chunk_size = self.chunk_size_in_sec * orig_sr
                for start in chunk_starts:
                    chunk = audio[:, start : start + orig_chunk_size]
                    resampled_chunk = torchaudio.functional.resample(
                        chunk, orig_sr, sample_rate
                    )
                    chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
                    if resampled_chunk.shape[-1] < chunk_size_in_samples:
                        resampled_chunk = F.pad(
                            resampled_chunk,
                            (0, chunk_size_in_samples - resampled_chunk.shape[1]),
                        )
                    effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
                    effect_name = list(self.effect_types.keys())[int(effect_idx)]
                    effect = self.effect_types[effect_name]
                    effected_input = effect(resampled_chunk)
                    normalized_input = self.normalize(effected_input)
                    normalized_target = self.normalize(resampled_chunk)
                    output_dir = output_mode_path / str(self.num_chunks)
                    output_dir.mkdir(exist_ok=True)
                    torchaudio.save(
                        output_dir / "input.wav", normalized_input, self.sample_rate
                    )
                    torchaudio.save(
                        output_dir / "target.wav", normalized_target, self.sample_rate
                    )
                    self.num_chunks += 1
        else:
            # Count previously rendered chunks (one input.wav per chunk directory)
            self.num_chunks = len(list(output_mode_path.glob("./**/input.wav")))
        print(
            f"Found {len(self.files)} {self.mode} files.\n"
            f"Total chunks: {self.num_chunks}"
        )
    def __len__(self):
        return self.num_chunks
    def __getitem__(self, idx):
        # Load pre-rendered audio
        input_file = self.output_root / self.mode / str(idx) / "input.wav"
        target_file = self.output_root / self.mode / str(idx) / "target.wav"
        input, sr = torchaudio.load(input_file)
        target, sr = torchaudio.load(target_file)
        return (input, target, "")
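# Usage sketch (paths and `my_reverb_module` are illustrative; chunks are
# rendered to disk once in __init__ and read back in __getitem__):
#   train_set = VocalSet(root="data/VocalSet", sample_rate=48000,
#                        effect_types={"Reverb": my_reverb_module}, mode="train")
#   wet, dry, _ = train_set[0]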
def create_random_chunks(
    audio_file: str, chunk_size: int, num_chunks: int
) -> Tuple[List[int], int]:
    """Create num_chunks random chunks of size chunk_size (seconds)
    from an audio file.
    Return the sample index at which each chunk starts and the original sr.
    """
    audio, sr = torchaudio.load(audio_file)
    chunk_size_in_samples = chunk_size * sr
    if chunk_size_in_samples >= audio.shape[-1]:
        chunk_size_in_samples = audio.shape[-1] - 1
    chunks = []
    for i in range(num_chunks):
        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
        chunks.append(start)
    return chunks, sr
def create_sequential_chunks(
    audio_file: str, chunk_size: int
) -> Tuple[List[int], int]:
    """Create sequential chunks of size chunk_size (seconds) from an audio file.
    Return the sample index at which each chunk starts and the original sr.
    """
    chunk_starts = []
    audio, sr = torchaudio.load(audio_file)
    chunk_size_in_samples = chunk_size * sr
    for start in range(0, audio.shape[-1], chunk_size_in_samples):
        # Drop the trailing partial chunk
        if start + chunk_size_in_samples > audio.shape[-1]:
            break
        chunk_starts.append(start)
    return chunk_starts, sr
class Datamodule(pl.LightningDataModule):
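    """LightningDataModule that wraps a single dataset and draws a random
    train/validation split from it.
    """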
    def __init__(
        self,
        dataset,
        *,
        val_split: float,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = False,
        **kwargs: int,
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.val_split = val_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.data_train: Any = None
        self.data_val: Any = None
    def setup(self, stage: Any = None) -> None:
        train_size = round((1.0 - self.val_split) * len(self.dataset))
        # Take the remainder so the two sizes always sum to len(dataset)
        val_size = len(self.dataset) - train_size
        self.data_train, self.data_val = random_split(
            self.dataset, [train_size, val_size]
        )
        # Note: data_train and data_val share the same underlying dataset object
        self.data_val.dataset.mode = "val"
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
        )
    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )
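# Usage sketch (the model, split, and batch size are illustrative):
#   datamodule = Datamodule(dataset, val_split=0.2, batch_size=16, num_workers=4)
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, datamodule=datamodule)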
class VocalSetDatamodule(pl.LightningDataModule):
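    """LightningDataModule for datasets with pre-defined train and validation
    splits (e.g. VocalSet), so no random split is performed.
    """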
    def __init__(
        self,
        train_dataset,
        val_dataset,
        *,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = False,
        **kwargs: int,
    ) -> None:
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
    def setup(self, stage: Any = None) -> None:
        pass
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
        )
    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )