# Spaces: Sleeping  (appears to be a status banner captured from the hosting page; not part of the module)
import os | |
import glob | |
import torch | |
import random | |
import numpy as np | |
from torch.utils.data import Dataset, DataLoader | |
from utils.utils import read_wav_np | |
def create_dataloader(hp, args, train):
    """Build a ``DataLoader`` over a ``MelFromDisk`` dataset.

    Args:
        hp: Hyper-parameter object; ``hp.train.batch_size`` (training only)
            and ``hp.train.num_workers`` are read here.
        args: Passed through unchanged to ``MelFromDisk``.
        train: Truthy for the training loader, falsy for the validation one.

    Returns:
        A ``DataLoader``. In training mode it shuffles, uses the configured
        batch size and drops the last incomplete batch; otherwise it yields
        one item at a time, in order, and keeps every item.
    """
    mel_dataset = MelFromDisk(hp, args, train)
    is_training = bool(train)

    # Only the batching/shuffling policy differs between the two modes.
    loader_options = dict(
        batch_size=hp.train.batch_size if is_training else 1,
        shuffle=is_training,
        num_workers=hp.train.num_workers,
        pin_memory=True,
        drop_last=is_training,
    )
    return DataLoader(dataset=mel_dataset, **loader_options)
class MelFromDisk(Dataset):
    """Dataset pairing each ``.wav`` file with a precomputed ``.mel`` file.

    Every ``*.wav`` found recursively under the configured data directory is
    one item. The matching mel tensor is loaded with ``torch.load`` from the
    same path with the extension swapped to ``.mel``.

    In training mode ``__getitem__`` returns two samples (the requested one
    and the one it is currently mapped to via ``self.mapping``), each randomly
    cropped to a fixed length. In validation mode it returns a single,
    uncropped sample.
    """

    def __init__(self, hp, args, train):
        """Index the wav files and precompute the crop length.

        Args:
            hp: Hyper-parameter object. Reads ``hp.data.train`` or
                ``hp.data.validation`` (root directory) and
                ``hp.audio.segment_length`` / ``hp.audio.hop_length``.
            args: Stored on the instance; not otherwise used by this class.
            train: Truthy selects the training directory and cropping.
        """
        self.hp = hp
        self.args = args
        self.train = train
        self.path = hp.data.train if train else hp.data.validation
        # glob gives no ordering guarantee; sort so that index -> file is
        # stable across runs and filesystems.
        self.wav_list = sorted(
            glob.glob(os.path.join(self.path, '**', '*.wav'), recursive=True)
        )
        # Number of mel frames per training crop.
        # NOTE(review): the extra 2 frames presumably account for framing
        # padding at the segment edges - confirm against the mel extraction.
        self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length + 2
        # Identity pairing until shuffle_mapping() is called.
        self.mapping = list(range(len(self.wav_list)))

    def __len__(self):
        """Return the number of wav files found."""
        return len(self.wav_list)

    def __getitem__(self, idx):
        """Return the sample(s) for ``idx``.

        Training: a pair ``(sample(idx), sample(self.mapping[idx]))``.
        Validation: a single ``sample(idx)``.
        Each sample is the ``(mel, audio)`` tuple from ``my_getitem``.
        """
        if not self.train:
            return self.my_getitem(idx)
        paired_idx = self.mapping[idx]
        return self.my_getitem(idx), self.my_getitem(paired_idx)

    def shuffle_mapping(self):
        """Randomly re-pair items by shuffling ``self.mapping`` in place."""
        random.shuffle(self.mapping)

    @staticmethod
    def _mel_path(wavpath):
        """Return ``wavpath`` with its final extension replaced by ``.mel``.

        Only the file extension is touched, so a ``.wav`` substring elsewhere
        in the path (for example in a directory name) is left alone.
        """
        return os.path.splitext(wavpath)[0] + '.mel'

    def my_getitem(self, idx):
        """Load one ``(mel, audio)`` pair.

        Args:
            idx: Index into ``self.wav_list``.

        Returns:
            ``(mel, audio)`` tensors. ``audio`` has a leading dimension of
            size 1 added. In training mode ``mel`` is cropped to
            ``self.mel_segment_length`` along axis 1 and ``audio`` to the
            matching ``hp.audio.segment_length`` window along axis 1.

        Raises:
            ValueError: In training mode, if the stored mel has fewer than
                ``self.mel_segment_length`` frames.
        """
        wavpath = self.wav_list[idx]
        melpath = self._mel_path(wavpath)
        _sr, audio = read_wav_np(wavpath)  # sample rate is not used here

        # Zero-pad short clips at the end up to the configured minimum length.
        min_samples = self.hp.audio.segment_length + self.hp.audio.pad_short
        if len(audio) < min_samples:
            audio = np.pad(audio, (0, min_samples - len(audio)),
                           mode='constant', constant_values=0.0)
        audio = torch.from_numpy(audio).unsqueeze(0)

        # NOTE: torch.load unpickles the file; only point this dataset at
        # mel files you generated yourself or otherwise trust.
        mel = torch.load(melpath).squeeze(0)

        if self.train:
            max_mel_start = mel.size(1) - self.mel_segment_length
            if max_mel_start < 0:
                # random.randint would raise an opaque "empty range" error.
                raise ValueError(
                    'mel %s has %d frames, fewer than the required %d'
                    % (melpath, mel.size(1), self.mel_segment_length)
                )
            mel_start = random.randint(0, max_mel_start)
            mel_end = mel_start + self.mel_segment_length
            mel = mel[:, mel_start:mel_end]

            # Keep the waveform window aligned with the chosen mel frames.
            audio_start = mel_start * self.hp.audio.hop_length
            audio = audio[:, audio_start:audio_start + self.hp.audio.segment_length]

        # Add low-amplitude Gaussian noise (scale 1/32768) to the waveform.
        audio = audio + (1/32768) * torch.randn_like(audio)
        return mel, audio