Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import json | |
| import numpy as np | |
| from text import text_to_sequence | |
| from text.text_token_collation import phoneIDCollation | |
| from models.tts.base.tts_dataset import ( | |
| TTSDataset, | |
| TTSCollator, | |
| TTSTestDataset, | |
| TTSTestCollator | |
| ) | |
class VITSDataset(TTSDataset):
    """VITS dataset built on ``TTSDataset``.

    Item access and length are inherited unchanged from the parent class;
    only the metadata loading is specialised, to drop utterances that are
    shorter than one segment.
    """

    def __init__(self, cfg, dataset, is_valid):
        """Forward all arguments to ``TTSDataset``."""
        super().__init__(cfg, dataset, is_valid=is_valid)

    def __getitem__(self, index):
        """Return the parent's feature for ``index`` without modification."""
        return super().__getitem__(index)

    def __len__(self):
        """Return the length reported by the parent dataset."""
        return super().__len__()

    def get_metadata(self):
        """Load the metadata JSON and keep only sufficiently long utterances.

        For each entry, a frame count is derived from its ``"Duration"``
        value as ``Duration * sample_rate // hop_size``. Entries whose frame
        count is below ``segment_size // hop_size`` are discarded.

        Returns:
            list: the retained metadata entries, in file order.
        """
        with open(self.metafile_path, "r", encoding="utf-8") as meta_file:
            all_entries = json.load(meta_file)

        prep = self.cfg.preprocess
        kept_entries = []
        for entry in all_entries:
            n_frames = entry['Duration'] * prep.sample_rate // prep.hop_size
            min_frames = prep.segment_size // prep.hop_size
            # Keep everything that is not strictly shorter than one segment.
            if not n_frames < min_frames:
                kept_entries.append(entry)
        return kept_entries
class VITSCollator(TTSCollator):
    """Collator for VITS batches; all work is delegated to ``TTSCollator``."""

    def __init__(self, cfg):
        """Forward ``cfg`` to the parent collator."""
        super().__init__(cfg)

    def __call__(self, batch):
        """Collate ``batch`` via the parent implementation and return its result."""
        return super().__call__(batch)
class VITSTestDataset(TTSTestDataset):
    """Inference-time dataset producing speaker ids and text/phone id sequences.

    Depending on the ``cfg.preprocess`` flags, the constructor loads:

    * ``use_spkid``: the ``spk2id`` JSON mapping and the tab-separated
      ``utt2spk`` file from the processed data directory;
    * ``use_text`` / ``use_phone``: one id sequence per utterance, keyed by
      ``"{Dataset}_{Uid}"``. ``use_text`` takes precedence over ``use_phone``
      when both are enabled.
    """

    def __init__(self, args, cfg):
        """Load speaker maps and per-utterance sequences as configured.

        Args:
            args: run arguments; ``args.dataset`` names the processed
                dataset sub-directory.
            cfg: configuration; the ``cfg.preprocess`` fields used here are
                ``processed_dir``, ``use_spkid``, ``spk2id``, ``utt2spk``,
                ``use_text``, ``text_cleaners``, ``use_phone`` and
                ``phone_dir``.

        Raises:
            ValueError: if a phone file does not contain exactly one line.
        """
        super().__init__(args, cfg)

        # Both the speaker-id branch and the phone branch need this path, so
        # it is resolved up front. It used to be assigned only when use_spkid
        # was enabled, which made use_phone-only configs fail with NameError.
        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, args.dataset)

        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            with open(spk2id_path, "r", encoding="utf-8") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = {}
            with open(utt2spk_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # Tolerate blank / trailing lines instead of failing
                        # on the tuple unpacking below.
                        continue
                    utt, spk = line.split('\t')
                    self.utt2spk[utt] = spk

        if cfg.preprocess.use_text or cfg.preprocess.use_phone:
            self.utt2seq = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.preprocess.use_text:
                    text = utt_info["Text"]
                    sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
                elif cfg.preprocess.use_phone:
                    # Load the phoneme sequence from the utterance's phone file.
                    phone_path = os.path.join(
                        processed_data_dir, cfg.preprocess.phone_dir, uid + '.phone'
                    )
                    with open(phone_path, 'r', encoding="utf-8") as fin:
                        phone_lines = fin.readlines()
                    # Explicit check rather than `assert`, which is stripped
                    # when Python runs with -O.
                    if len(phone_lines) != 1:
                        raise ValueError(
                            "Expected exactly one line in {}, found {}".format(
                                phone_path, len(phone_lines)
                            )
                        )
                    phones_seq = phone_lines[0].strip().split(' ')

                    # NOTE(review): a collator is built for every utterance.
                    # If its construction depends only on (cfg, dataset) it
                    # could be cached per dataset -- confirm before changing.
                    phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
                    sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)

                self.utt2seq[utt] = sequence

    def __getitem__(self, index):
        """Build the feature dict for the utterance at ``index``.

        Returns:
            dict: contains ``"spk_id"`` (one-element ``np.int32`` array) when
            ``use_spkid`` is set, and ``"phone_seq"`` (array of the stored
            sequence) plus ``"phone_len"`` (its length) when ``use_phone`` or
            ``use_text`` is set.
        """
        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = {}

        if self.cfg.preprocess.use_spkid:
            single_feature["spk_id"] = np.array(
                [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
            )

        if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
            sequence = self.utt2seq[utt]
            single_feature["phone_seq"] = np.array(sequence)
            single_feature["phone_len"] = len(sequence)

        return single_feature

    def get_metadata(self):
        """Return the parsed contents of the metadata JSON file, unfiltered."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        return metadata

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.metadata)
class VITSTestCollator(TTSTestCollator):
    """Collator for VITS inference batches; collation is done by ``TTSTestCollator``."""

    def __init__(self, cfg):
        """Store ``cfg`` on the instance.

        The parent initializer is not invoked here; only the ``cfg``
        attribute is set.
        """
        self.cfg = cfg

    def __call__(self, batch):
        """Collate ``batch`` using the parent implementation."""
        collated = super().__call__(batch)
        return collated