from pathlib import Path
from typing import Any, List, Optional

import pytorch_lightning as pl
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import DataLoader, Dataset, random_split

# Source dataset: https://zenodo.org/record/7044411/
LENGTH = 2**18  # 262144 samples, ~12 s at a 22.05 kHz target rate
ORIG_SR = 48000  # native sample rate of the source recordings
class GuitarFXDataset(Dataset):
    """Paired wet/dry guitar recordings, labeled by effect type.

    Expects the layout implied by the globs below:
    <root>/<Effect>/<pickup>/*.wav for effected (wet) takes and
    <root>/Clean/<pickup>/**/*.wav for the matching dry takes.
    """

    def __init__(
        self,
        root: str,
        sample_rate: int,
        length: int = LENGTH,
        effect_types: Optional[List[str]] = None,
    ):
        self.length = length
        self.wet_files = []
        self.dry_files = []
        self.labels = []
        self.root = Path(root)
        if effect_types is None:
            # Every subdirectory except the dry "Clean" takes is an effect.
            effect_types = [
                d.name for d in self.root.iterdir() if d.is_dir() and d.name != "Clean"
            ]
        for i, effect in enumerate(effect_types):
            for pickup in (self.root / effect).iterdir():
                wet = sorted(pickup.glob("*.wav"))
                dry = sorted(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
                self.wet_files += wet
                self.dry_files += dry
                self.labels += [i] * len(wet)  # one label per pair added here
        print(
            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files"
        )
        self.resampler = T.Resample(ORIG_SR, sample_rate)
    def __len__(self):
        return len(self.dry_files)

    def __getitem__(self, idx):
        x, _ = torchaudio.load(self.wet_files[idx])  # effected (wet) input
        y, _ = torchaudio.load(self.dry_files[idx])  # clean (dry) target
        effect_label = self.labels[idx]
        resampled_x = self.resampler(x)
        resampled_y = self.resampler(y)
        # Pad or crop to a fixed length so examples batch cleanly
        if resampled_x.shape[-1] < self.length:
            resampled_x = F.pad(resampled_x, (0, self.length - resampled_x.shape[-1]))
        elif resampled_x.shape[-1] > self.length:
            resampled_x = resampled_x[:, : self.length]
        if resampled_y.shape[-1] < self.length:
            resampled_y = F.pad(resampled_y, (0, self.length - resampled_y.shape[-1]))
        elif resampled_y.shape[-1] > self.length:
            resampled_y = resampled_y[:, : self.length]
        return (resampled_x, resampled_y, effect_label)
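
# A minimal usage sketch, assuming the dataset linked above has been extracted
# to "./egfxset" (the path and effect name here are illustrative assumptions):
#
#   dataset = GuitarFXDataset(root="./egfxset", sample_rate=22050,
#                             effect_types=["BluesDriver"])
#   wet, dry, label = dataset[0]  # wet/dry are (channels, LENGTH) tensors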
class Datamodule(pl.LightningDataModule):
    def __init__(
        self,
        dataset,
        *,
        val_split: float,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.val_split = val_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.data_train: Any = None
        self.data_val: Any = None

    def setup(self, stage: Any = None) -> None:
        # Derive the validation size by subtraction so the two splits always
        # sum to len(dataset) even when the fractions round down.
        train_size = int((1.0 - self.val_split) * len(self.dataset))
        val_size = len(self.dataset) - train_size
        self.data_train, self.data_val = random_split(
            self.dataset, [train_size, val_size]
        )
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )
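
if __name__ == "__main__":
    # Smoke-test sketch: wire the dataset into the datamodule and pull one
    # batch. Root path, sample rate, and batch size are illustrative
    # assumptions, not values fixed by this module.
    dataset = GuitarFXDataset(root="./egfxset", sample_rate=22050)
    dm = Datamodule(dataset, val_split=0.2, batch_size=4, num_workers=0)
    dm.setup()
    wet, dry, labels = next(iter(dm.train_dataloader()))
    print(wet.shape, dry.shape, labels)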