# Uploaded by AlexK-PL ("Upload 72 files", commit c61c48a; file size 2.38 kB).
# (Header text carried over from the hosting page's file viewer: raw / history / blame.)
import os
import glob
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from utils.utils import read_wav_np
def create_dataloader(hp, args, train):
    """Wrap a ``MelFromDisk`` dataset for the requested split in a DataLoader.

    Args:
        hp: hyper-parameter object; ``hp.train.batch_size`` (training only)
            and ``hp.train.num_workers`` are read here, the rest is read by
            ``MelFromDisk``.
        args: passed through unchanged to ``MelFromDisk``.
        train: True for the shuffled training loader, False for validation.

    Returns:
        A ``DataLoader``. Training: batches of ``hp.train.batch_size``,
        shuffled, incomplete last batch dropped. Validation: batch size 1,
        fixed order, nothing dropped. Both splits use pinned memory and take
        their worker count from ``hp.train.num_workers``.
    """
    dataset = MelFromDisk(hp, args, train)
    if not train:
        split_opts = {'batch_size': 1, 'shuffle': False, 'drop_last': False}
    else:
        split_opts = {'batch_size': hp.train.batch_size, 'shuffle': True, 'drop_last': True}
    shared_opts = {'num_workers': hp.train.num_workers, 'pin_memory': True}
    return DataLoader(dataset=dataset, **split_opts, **shared_opts)
class MelFromDisk(Dataset):
    """Dataset pairing ``*.wav`` waveforms with precomputed ``*.mel`` tensors.

    Every ``.wav`` file found recursively under ``hp.data.train`` (when
    ``train`` is true) or ``hp.data.validation`` is one item. Its mel tensor
    is loaded with ``torch.load`` from the file that has the same stem and a
    ``.mel`` extension.

    In training mode each item is a pair of examples,
    ``(self.my_getitem(idx), self.my_getitem(self.mapping[idx]))``.
    ``self.mapping`` is the identity permutation until ``shuffle_mapping``
    is called.
    """

    def __init__(self, hp, args, train):
        self.hp = hp
        self.args = args
        self.train = train
        self.path = hp.data.train if train else hp.data.validation
        # Sorted so that item order does not depend on the filesystem's
        # directory-listing order (glob returns files in arbitrary order).
        self.wav_list = sorted(
            glob.glob(os.path.join(self.path, '**', '*.wav'), recursive=True))
        # Number of mel frames cropped per training example. The "+ 2" is
        # kept as-is from the original implementation.
        self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length + 2
        # Partner index for each item in training mode (identity until shuffled).
        self.mapping = list(range(len(self.wav_list)))

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        """Return one example, or in training mode a pair of examples.

        Returns:
            Validation: ``(mel, audio)`` for ``idx``.
            Training: ``((mel, audio), (mel, audio))`` for ``idx`` and
            ``self.mapping[idx]``.
        """
        if self.train:
            idx1 = idx
            idx2 = self.mapping[idx1]
            return self.my_getitem(idx1), self.my_getitem(idx2)
        return self.my_getitem(idx)

    def shuffle_mapping(self):
        """Randomly permute, in place, the partner indices used in training mode."""
        random.shuffle(self.mapping)

    @staticmethod
    def _mel_path(wavpath):
        """Return the ``.mel`` path that accompanies ``wavpath``.

        Only the file extension is swapped. Directory names that happen to
        contain ``.wav`` are left untouched, which ``str.replace`` would not do.
        """
        return os.path.splitext(wavpath)[0] + '.mel'

    def _pad_audio(self, audio):
        """Zero-pad a 1-D numpy ``audio`` array at the end to a minimum length.

        The minimum length is ``hp.audio.segment_length + hp.audio.pad_short``
        samples. Arrays that are already long enough are returned unchanged
        (the same object).
        """
        min_len = self.hp.audio.segment_length + self.hp.audio.pad_short
        if len(audio) < min_len:
            audio = np.pad(audio, (0, min_len - len(audio)),
                           mode='constant', constant_values=0.0)
        return audio

    def _random_crop(self, mel, audio, path=''):
        """Crop a random, time-aligned training window from ``mel`` and ``audio``.

        Takes ``mel_segment_length`` frames along dim 1 of ``mel``, and
        ``hp.audio.segment_length`` samples along dim 1 of ``audio``, starting
        at ``mel_start * hp.audio.hop_length``.

        Args:
            mel: 2-D tensor sliced as ``mel[:, frames]``; presumably
                (n_mel_channels, frames) — TODO confirm against preprocessing.
            audio: tensor of shape (1, samples).
            path: source file name, used only in the error message.

        Raises:
            ValueError: if ``mel`` has fewer than ``mel_segment_length`` frames.
        """
        max_mel_start = mel.size(1) - self.mel_segment_length
        if max_mel_start < 0:
            raise ValueError(
                f'{path}: mel has {mel.size(1)} frames, but training crops '
                f'need at least {self.mel_segment_length}')
        mel_start = random.randint(0, max_mel_start)
        mel = mel[:, mel_start:mel_start + self.mel_segment_length]
        audio_start = mel_start * self.hp.audio.hop_length
        audio = audio[:, audio_start:audio_start + self.hp.audio.segment_length]
        return mel, audio

    def my_getitem(self, idx):
        """Load the ``(mel, audio)`` example for ``self.wav_list[idx]``.

        The waveform is zero-padded to a minimum length and turned into a
        (1, samples) tensor. In training mode both tensors are randomly
        cropped to an aligned window. Small Gaussian noise is added to the
        audio in both modes.
        """
        wavpath = self.wav_list[idx]
        # First return value (presumably the sample rate) is not used here.
        sr, audio = read_wav_np(wavpath)
        audio = torch.from_numpy(self._pad_audio(audio)).unsqueeze(0)
        # map_location='cpu': a mel saved from a CUDA tensor must not be
        # materialised on the GPU inside DataLoader workers, and pin_memory
        # only accepts CPU tensors.
        mel = torch.load(self._mel_path(wavpath), map_location='cpu').squeeze(0)
        if self.train:
            mel, audio = self._random_crop(mel, audio, wavpath)
        # Gaussian noise with std 1/32768 — presumably dithering at one 16-bit
        # sample step, assuming audio is normalised to [-1, 1]; TODO confirm.
        audio = audio + (1 / 32768) * torch.randn_like(audio)
        return mel, audio