Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from modules.tts.fs2_orig import FastSpeech2Orig | |
| from tasks.tts.dataset_utils import FastSpeechDataset | |
| from tasks.tts.fs import FastSpeechTask | |
| from utils.commons.dataset_utils import collate_1d, collate_2d | |
| from utils.commons.hparams import hparams | |
| from utils.plot.plot import spec_to_figure | |
| import numpy as np | |
class FastSpeech2OrigDataset(FastSpeechDataset):
    """FastSpeech dataset that additionally yields frame energy and optional CWT pitch targets.

    On top of what ``FastSpeechDataset`` returns, every sample gets an ``'energy'``
    tensor computed from its mel spectrogram. When ``use_pitch_embed`` is enabled and
    ``pitch_type == 'cwt'``, it also gets ``'cwt_spec'``, ``'f0_mean'`` and ``'f0_std'``,
    read from the raw item.
    """

    def __init__(self, prefix, shuffle=False, items=None, data_dir=None):
        super().__init__(prefix, shuffle, items, data_dir)
        # Read once from the global config; selects whether CWT pitch targets are produced.
        self.pitch_type = hparams.get('pitch_type')

    def __getitem__(self, index):
        """Return the base sample extended with energy and, if enabled, CWT pitch targets."""
        sample = super().__getitem__(index)
        item = self._get_item(index)
        hp = self.hparams  # local alias; avoids shadowing the module-level ``hparams``
        mel = sample['mel']
        n_frames = mel.shape[0]
        # Per-frame energy: L2 norm over the last axis of exp(mel).
        # Assumes mel is a [T, n_mels] log-magnitude spectrogram -- TODO confirm.
        sample['energy'] = (mel.exp() ** 2).sum(-1).sqrt()
        if hp['use_pitch_embed'] and self.pitch_type == 'cwt':
            # Trim so the CWT target is never longer than the mel.
            cwt_spec = torch.Tensor(item['cwt_spec'])[:n_frames]
            # Fall back to the 'cwt_mean'/'cwt_std' keys (presumably an alternate naming
            # used by some preprocessed data -- verify against the binarizer).
            f0_mean = item.get('f0_mean', item.get('cwt_mean'))
            f0_std = item.get('f0_std', item.get('cwt_std'))
            sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std})
        return sample

    def collater(self, samples):
        """Collate samples into a batch, adding padded energy and CWT pitch tensors.

        Returns an empty dict for an empty ``samples`` list. ``batch['energy']`` is a
        zero-padded tensor when ``use_energy_embed`` is on, otherwise ``None``.
        """
        if len(samples) == 0:
            return {}
        batch = super().collater(samples)
        # Energy is needed exactly when the energy embedding/loss is enabled; gating on
        # use_pitch_embed left it None for energy-only configurations.
        if hparams['use_energy_embed']:
            batch['energy'] = collate_1d([s['energy'] for s in samples], 0.0)
        else:
            batch['energy'] = None
        # __getitem__ attaches CWT targets only when use_pitch_embed is also on, so
        # check that they are actually present instead of relying on pitch_type alone.
        if self.pitch_type == 'cwt' and 'cwt_spec' in samples[0]:
            cwt_spec = collate_2d([s['cwt_spec'] for s in samples])
            f0_mean = torch.Tensor([s['f0_mean'] for s in samples])
            f0_std = torch.Tensor([s['f0_std'] for s in samples])
            batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std})
        return batch
class FastSpeech2OrigTask(FastSpeechTask):
    """FastSpeech task driving the ``FastSpeech2Orig`` model.

    Extends ``FastSpeechTask`` with:
    - f0/uv/energy conditioning of the model;
    - a CWT-based pitch loss when ``hparams['pitch_type'] == 'cwt'``;
    - an energy loss;
    - CWT plots during validation.
    """

    def __init__(self):
        super(FastSpeech2OrigTask, self).__init__()
        self.dataset_cls = FastSpeech2OrigDataset

    def build_tts_model(self):
        """Instantiate ``FastSpeech2Orig`` sized to the token vocabulary as ``self.model``."""
        dict_size = len(self.token_encoder)
        self.model = FastSpeech2Orig(dict_size, hparams)

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run the model on a collated batch.

        Args:
            sample: batch dict produced by ``FastSpeech2OrigDataset.collater``.
            infer: if False, condition on ground-truth mel2ph/f0/uv/energy and compute
                losses. If True, run in inference mode, using ground truth only where
                requested.
            **kwargs: in inference mode, ``infer_use_gt_dur``, ``infer_use_gt_f0`` and
                ``infer_use_gt_energy`` override the matching ``use_gt_*`` hparams.

        Returns:
            ``(losses, output)`` when ``infer`` is False, otherwise the model ``output``.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')
        if not infer:
            target = sample['mels']  # [B, T_s, 80]
            mel2ph = sample['mel2ph']  # [B, T_s]
            f0 = sample.get('f0')
            uv = sample.get('uv')
            energy = sample.get('energy')
            output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                                f0=f0, uv=uv, energy=energy, infer=False)
            losses = {}
            self.add_mel_loss(output['mel_out'], target, losses)
            self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
            if hparams['use_pitch_embed']:
                self.add_pitch_loss(output, sample, losses)
            if hparams['use_energy_embed']:
                self.add_energy_loss(output, sample, losses)
            return losses, output

        # Inference: any quantity not taken from ground truth is passed as None
        # (presumably predicted by the model instead).
        mel2ph, uv, f0, energy = None, None, None, None
        use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
        use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
        use_gt_energy = kwargs.get('infer_use_gt_energy', hparams['use_gt_energy'])
        if use_gt_dur:
            mel2ph = sample['mel2ph']
        if use_gt_f0:
            f0 = sample['f0']
            uv = sample['uv']
        if use_gt_energy:
            energy = sample['energy']
        output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                            f0=f0, uv=uv, energy=energy, infer=True)
        return output

    def save_valid_result(self, sample, batch_idx, model_out):
        """Save the parent's validation results, plus a CWT plot when the model predicted one.

        The CWT plot is skipped when ``model_out`` has no ``'cwt'`` entry (e.g. when
        ``pitch_type != 'cwt'``); indexing it unconditionally raised KeyError there.
        """
        super(FastSpeech2OrigTask, self).save_valid_result(sample, batch_idx, model_out)
        if 'cwt' in model_out:
            self.plot_cwt(batch_idx, model_out['cwt'], sample.get('cwt_spec'))

    def plot_cwt(self, batch_idx, cwt_out, cwt_gt=None):
        """Log a CWT spectrogram figure named ``cwt_val_{batch_idx}``.

        For batched (3-D) inputs only the first item is plotted. Tensors are moved to
        CPU NumPy arrays. If ``cwt_gt`` is given, it is concatenated with the prediction
        along the last axis into a single figure.
        """
        if len(cwt_out.shape) == 3:
            cwt_out = cwt_out[0]
        if isinstance(cwt_out, torch.Tensor):
            cwt_out = cwt_out.cpu().numpy()
        if cwt_gt is not None:
            if len(cwt_gt.shape) == 3:
                cwt_gt = cwt_gt[0]
            if isinstance(cwt_gt, torch.Tensor):
                cwt_gt = cwt_gt.cpu().numpy()
            cwt_out = np.concatenate([cwt_out, cwt_gt], -1)
        name = f'cwt_val_{batch_idx}'
        self.logger.add_figure(name, spec_to_figure(cwt_out), self.global_step)

    def add_pitch_loss(self, output, sample, losses):
        """Add pitch losses to ``losses``.

        For ``pitch_type == 'cwt'`` this adds:
        - L1 losses on the CWT spectrogram (``'C'``), ``'f0_mean'`` and ``'f0_std'``;
        - a masked BCE uv loss (``'uv'``) when ``use_uv`` is on.

        Other pitch types defer to the parent implementation.
        """
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            uv = sample['uv']
            mel2ph = sample['mel2ph']
            f0_std = sample['f0_std']
            # The first 10 output channels are the CWT prediction; with use_uv an 11th
            # channel carries uv logits (asserted below).
            cwt_pred = output['cwt'][:, :, :10]
            f0_mean_pred = output['f0_mean']
            f0_std_pred = output['f0_std']
            nonpadding = (mel2ph != 0).float()
            # NOTE(review): unlike the uv loss, this L1 loss is averaged over all frames,
            # padding included -- confirm this is intended.
            losses['C'] = F.l1_loss(cwt_pred, cwt_spec) * hparams['lambda_f0']
            if hparams['use_uv']:
                assert output['cwt'].shape[-1] == 11
                uv_pred = output['cwt'][:, :, -1]
                losses['uv'] = (F.binary_cross_entropy_with_logits(uv_pred, uv, reduction='none')
                                * nonpadding).sum() / nonpadding.sum() * hparams['lambda_uv']
            losses['f0_mean'] = F.l1_loss(f0_mean_pred, f0_mean) * hparams['lambda_f0']
            losses['f0_std'] = F.l1_loss(f0_std_pred, f0_std) * hparams['lambda_f0']
        else:
            super(FastSpeech2OrigTask, self).add_pitch_loss(output, sample, losses)

    def add_energy_loss(self, output, sample, losses):
        """Add the masked MSE energy loss, scaled by ``lambda_energy``, as ``losses['e']``."""
        energy_pred, energy = output['energy_pred'], sample['energy']
        # Energy is zero-padded by the collater; real frames are positive in practice,
        # because they are computed from exp(mel). Zeros therefore mark padding.
        nonpadding = (energy != 0).float()
        loss = (F.mse_loss(energy_pred, energy, reduction='none') * nonpadding).sum() / nonpadding.sum()
        loss = loss * hparams['lambda_energy']
        losses['e'] = loss