File size: 2,128 Bytes
14ae0ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from torch.utils.data import Dataset
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F
from pathlib import Path
from typing import List

# https://zenodo.org/record/7044411/

# Default clip length in samples used by GuitarFXDataset. The wall-clock
# duration depends on the target sample_rate: ~11.9 s at 22.05 kHz, but only
# ~5.5 s at ORIG_SR (48 kHz).
LENGTH = 2**18  # 262,144 samples
# Sample rate (Hz) that GuitarFXDataset's resampler converts *from*; source
# .wav files are expected to be recorded at this rate.
ORIG_SR = 48000


class GuitarFXDataset(Dataset):
    """Dataset of paired wet/dry .wav recordings with one effect label each.

    Directory layout read by this class::

        root/<effect>/<pickup>/*.wav        wet recordings
        root/Clean/<pickup>/**/*.wav        dry (clean) recordings

    Within each ``<effect>/<pickup>`` folder, wet files are paired with the
    dry files of the same ``<pickup>`` by position after sorting. Both folders
    must therefore hold the same number of files in matching sort order.

    Each item is ``(wet, dry, effect_label)``. Both audio tensors are
    resampled to ``sample_rate`` and zero-padded or cropped along the last
    axis to exactly ``length`` samples. ``effect_label`` is the index of the
    wet file's effect in ``effect_type``.

    Args:
        root: Dataset root directory.
        sample_rate: Target sample rate in Hz.
        length: Number of samples in every returned clip.
        effect_type: Effect folder names to load, in label order. Defaults
            to every sub-directory of ``root`` except ``Clean``, sorted by
            name so label indices are reproducible across machines.

    Raises:
        ValueError: If a pickup folder's wet and dry file counts differ,
            since positional pairing would then be wrong.
    """

    def __init__(
        self,
        root: str,
        sample_rate: int,
        length: int = LENGTH,
        effect_type: List[str] = None,
    ):
        self.length = length
        self.sample_rate = sample_rate
        self.wet_files = []
        self.dry_files = []
        self.labels = []
        self.root = Path(root)
        if effect_type is None:
            # Compare the directory *name*: a Path never equals a str, so a
            # plain `d != "Clean"` would let the dry folder through.
            effect_type = sorted(
                d.name
                for d in self.root.iterdir()
                if d.is_dir() and d.name != "Clean"
            )
        for label, effect in enumerate(effect_type):
            for pickup in sorted((self.root / effect).iterdir()):
                if not pickup.is_dir():
                    continue  # skip stray files such as OS metadata
                # Sort both lists: wet/dry are paired by index, and glob
                # order is filesystem-dependent.
                wet = sorted(pickup.glob("*.wav"))
                dry = sorted(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
                if len(wet) != len(dry):
                    raise ValueError(
                        f"Effect '{effect}', pickup '{pickup.name}': found "
                        f"{len(wet)} wet files but {len(dry)} dry files under "
                        f"Clean/{pickup.name}; cannot pair them by position."
                    )
                self.wet_files += wet
                self.dry_files += dry
                # One label per file added in *this* step, not the running total.
                self.labels += [label] * len(wet)
        print(
            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files"
        )
        # Precomputed resampling kernel for the expected source rate.
        self.resampler = T.Resample(ORIG_SR, sample_rate)

    def __len__(self):
        return len(self.dry_files)

    def _load_resampled(self, path: Path) -> torch.Tensor:
        """Load ``path`` and resample it from its own rate to ``sample_rate``."""
        audio, sr = torchaudio.load(path)
        if sr == ORIG_SR:
            return self.resampler(audio)
        # The cached resampler assumes ORIG_SR; honour the file's actual rate.
        return torchaudio.functional.resample(audio, sr, self.sample_rate)

    def _pad_or_crop(self, audio: torch.Tensor) -> torch.Tensor:
        """Zero-pad on the right or truncate the last axis to ``self.length``."""
        n_samples = audio.shape[-1]
        if n_samples < self.length:
            return F.pad(audio, (0, self.length - n_samples))
        return audio[..., : self.length]

    def __getitem__(self, idx):
        """Return ``(wet, dry, effect_label)`` for item ``idx``."""
        wet = self._pad_or_crop(self._load_resampled(self.wet_files[idx]))
        dry = self._pad_or_crop(self._load_resampled(self.dry_files[idx]))
        return (wet, dry, self.labels[idx])