from torch.utils.data import DataLoader from vits.data_utils import DistributedBucketSampler from vits.data_utils import TextAudioSpeakerCollate from vits.data_utils import TextAudioSpeakerSet def create_dataloader_train(hps, n_gpus, rank): collate_fn = TextAudioSpeakerCollate() train_dataset = TextAudioSpeakerSet(hps.data.training_files, hps.data) train_sampler = DistributedBucketSampler( train_dataset, hps.train.batch_size, [150, 300, 450], num_replicas=n_gpus, rank=rank, shuffle=True) train_loader = DataLoader( train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler) return train_loader def create_dataloader_eval(hps): collate_fn = TextAudioSpeakerCollate() eval_dataset = TextAudioSpeakerSet(hps.data.validation_files, hps.data) eval_loader = DataLoader( eval_dataset, num_workers=2, shuffle=False, batch_size=hps.train.batch_size, pin_memory=True, drop_last=False, collate_fn=collate_fn) return eval_loader