# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import random

import librosa
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer

from utils.data_utils import *
from models.base.base_dataset import (
    BaseCollator,
    BaseDataset,
    BaseTestDataset,
    BaseTestCollator,
)

class AudioLDMDataset(BaseDataset):
    def __init__(self, cfg, dataset, is_valid=False):
        BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)

        self.cfg = cfg

        # utt2melspec
        if cfg.preprocess.use_melspec:
            self.utt2melspec_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2melspec_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.melspec_dir,
                    uid + ".npy",
                )

        # utt2wav
        if cfg.preprocess.use_wav:
            self.utt2wav_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2wav_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.wav_dir,
                    uid + ".wav",
                )

        # utt2caption
        if cfg.preprocess.use_caption:
            self.utt2caption = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2caption[utt] = utt_info["Caption"]
    def __getitem__(self, index):
        # melspec: (n_mels, T)
        # wav: (T,)

        single_feature = BaseDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        if self.cfg.preprocess.use_melspec:
            single_feature["melspec"] = np.load(self.utt2melspec_path[utt])

        if self.cfg.preprocess.use_wav:
            wav, sr = librosa.load(
                self.utt2wav_path[utt], sr=16000
            )  # hard-coded to 16 kHz
            single_feature["wav"] = wav

        if self.cfg.preprocess.use_caption:
            # With probability cond_mask_prob (e.g. 0.1), replace the caption
            # with the empty string so the model also sees unconditional
            # examples, as required for classifier-free guidance at sampling.
            cond_mask = np.random.choice(
                [1, 0],
                p=[
                    self.cfg.preprocess.cond_mask_prob,
                    1 - self.cfg.preprocess.cond_mask_prob,
                ],
            )
            if cond_mask:
                single_feature["caption"] = ""
            else:
                single_feature["caption"] = self.utt2caption[utt]

        return single_feature

    def __len__(self):
        return len(self.metadata)

class AudioLDMCollator(BaseCollator):
    def __init__(self, cfg):
        BaseCollator.__init__(self, cfg)

        self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)

    def __call__(self, batch):
        # melspec: (B, n_mels, T)
        # wav (optional): (B, T)
        # text_input_ids: (B, L)
        # text_attention_mask: (B, L)

        packed_batch_features = dict()

        for key in batch[0].keys():
            if key == "melspec":
                # Crop every mel-spectrogram to a fixed 624 frames so the batch
                # stacks into a dense tensor; preprocessing is expected to
                # produce at least that many frames per clip.
                packed_batch_features["melspec"] = torch.from_numpy(
                    np.array([b["melspec"][:, :624] for b in batch])
                )

            if key == "wav":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

            if key == "caption":
                captions = [b[key] for b in batch]
                text_input = self.tokenizer(
                    captions, return_tensors="pt", truncation=True, padding="longest"
                )
                text_input_ids = text_input["input_ids"]
                text_attention_mask = text_input["attention_mask"]

                packed_batch_features["text_input_ids"] = text_input_ids
                packed_batch_features["text_attention_mask"] = text_attention_mask

        return packed_batch_features

class AudioLDMTestDataset(BaseTestDataset):
    ...


class AudioLDMTestCollator(BaseTestCollator):
    ...
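

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): wires the
    # dataset and collator into a standard PyTorch DataLoader. The config
    # loader import, the config path, and the "AudioCaps" dataset name are
    # assumptions -- adapt them to the recipe and preprocessed data you have.
    from torch.utils.data import DataLoader

    from utils.util import load_config  # assumed Amphion config helper

    cfg = load_config("egs/tta/audioldm/exp_config.json")  # hypothetical path
    train_set = AudioLDMDataset(cfg, dataset="AudioCaps", is_valid=False)
    loader = DataLoader(
        train_set,
        batch_size=8,
        shuffle=True,
        collate_fn=AudioLDMCollator(cfg),
    )

    batch = next(iter(loader))
    # Expected keys/shapes: melspec (B, n_mels, 624), wav (B, T_max),
    # text_input_ids and text_attention_mask (B, L).
    print({k: tuple(v.shape) for k, v in batch.items()})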