# remfx/datasets.py
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torch.nn.functional as F
from pathlib import Path
import pytorch_lightning as pl
from typing import Any, Dict
from remfx import effects
from tqdm import tqdm
from remfx.utils import create_sequential_chunks


# https://zenodo.org/record/1193957 -> VocalSet
class VocalSet(Dataset):
def __init__(
self,
root: str,
sample_rate: int,
chunk_size_in_sec: int = 3,
        effect_types: Dict[str, torch.nn.Module] = None,
render_files: bool = True,
mode: str = "train",
):
super().__init__()
self.chunks = []
self.song_idx = []
self.root = Path(root)
self.chunk_size_in_sec = chunk_size_in_sec
self.sample_rate = sample_rate
self.mode = mode
mode_path = self.root / self.mode
        self.files = sorted(mode_path.glob("**/*.wav"))
self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
self.effect_types = effect_types
self.processed_root = self.root / "processed" / self.mode
self.num_chunks = 0
print("Total files:", len(self.files))
print("Processing files...")
if render_files:
# Split audio file into chunks, resample, then apply random effects
self.processed_root.mkdir(parents=True, exist_ok=True)
for audio_file in tqdm(self.files, total=len(self.files)):
chunks, orig_sr = create_sequential_chunks(
audio_file, self.chunk_size_in_sec
)
for chunk in chunks:
resampled_chunk = torchaudio.functional.resample(
chunk, orig_sr, sample_rate
)
                chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
                if resampled_chunk.shape[-1] < chunk_size_in_samples:
                    # Zero-pad short final chunks so every example has the same length
                    resampled_chunk = F.pad(
                        resampled_chunk,
                        (0, chunk_size_in_samples - resampled_chunk.shape[-1]),
                    )
                # Apply a randomly chosen effect
                effect_idx = torch.randint(len(self.effect_types), (1,)).item()
                effect_name = list(self.effect_types.keys())[effect_idx]
                effect = self.effect_types[effect_name]
effected_input = effect(resampled_chunk)
                # Loudness-normalize input and target to the same level (-20 LUFS)
normalized_input = self.normalize(effected_input)
normalized_target = self.normalize(resampled_chunk)
output_dir = self.processed_root / str(self.num_chunks)
output_dir.mkdir(exist_ok=True)
torchaudio.save(
output_dir / "input.wav", normalized_input, self.sample_rate
)
torchaudio.save(
output_dir / "target.wav", normalized_target, self.sample_rate
)
torch.save(effect_name, output_dir / "effect_name.pt")
self.num_chunks += 1
else:
self.num_chunks = len(list(self.processed_root.iterdir()))
        print(
            f"Found {len(self.files)} {self.mode} files.\n"
            f"Total chunks: {self.num_chunks}"
        )

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
input_file = self.processed_root / str(idx) / "input.wav"
target_file = self.processed_root / str(idx) / "target.wav"
effect_name = torch.load(self.processed_root / str(idx) / "effect_name.pt")
input, sr = torchaudio.load(input_file)
target, sr = torchaudio.load(target_file)
return (input, target, effect_name)
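
# A minimal usage sketch of VocalSet. The effect class names below are
# illustrative assumptions (RemFX's pedalboard-based effects follow this
# naming pattern, but check remfx.effects for the real classes):
#
#     effect_types = {
#         "reverb": effects.RandomPedalboardReverb(sample_rate=48000),
#         "distortion": effects.RandomPedalboardDistortion(sample_rate=48000),
#     }
#     dataset = VocalSet(
#         root="./data/VocalSet",
#         sample_rate=48000,
#         chunk_size_in_sec=3,
#         effect_types=effect_types,
#         render_files=True,  # False reuses chunks already rendered to disk
#         mode="train",
#     )
#     x, y, effect_name = dataset[0]  # (effected input, clean target, label)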


class VocalSetDatamodule(pl.LightningDataModule):
def __init__(
self,
train_dataset,
val_dataset,
test_dataset,
*,
batch_size: int,
num_workers: int,
pin_memory: bool = False,
        **kwargs: Any,
) -> None:
super().__init__()
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory

    def setup(self, stage: Any = None) -> None:
        # Datasets are constructed up front, so no per-stage setup is needed
        pass

    def train_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=True,
)

    def val_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)

    def test_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)
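
# A minimal wiring sketch for the datamodule (hypothetical values; the real
# project supplies these from its training configs):
#
#     datamodule = VocalSetDatamodule(
#         train_dataset=train_ds,
#         val_dataset=val_ds,
#         test_dataset=test_ds,
#         batch_size=16,
#         num_workers=4,
#         pin_memory=True,
#     )
#     trainer = pl.Trainer()
#     trainer.fit(model, datamodule=datamodule)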