Spaces:
Build error
Build error
| import torch | |
| import utils | |
| from .diff.diffusion import GaussianDiffusion | |
| from .diff.net import DiffNet | |
| from tasks.tts.fs2 import FastSpeech2Task | |
| from utils.hparams import hparams | |
# Registry of denoiser ("decoder") factories for the diffusion model.
# Keys are the accepted values of hparams['diff_decoder_type']; each value is a
# callable taking the hparams mapping and returning a freshly built denoiser.
# Construction is deferred (the factory is only called when looked up and
# invoked), so registering an entry does not instantiate any network.
DIFF_DECODERS = dict(
    wavenet=lambda hp: DiffNet(hp['audio_num_mel_bins']),
)
class DiffFsTask(FastSpeech2Task):
    """FastSpeech2 task whose acoustic model is a ``GaussianDiffusion``.

    Duration/pitch/energy loss helpers, ``plot_mel`` and the training loop
    plumbing are inherited from ``FastSpeech2Task`` (not visible here); this
    class only swaps in the diffusion model and its loss bookkeeping.
    """

    def build_tts_model(self):
        """Build the diffusion model from ``hparams`` and store it on ``self.model``.

        Reads ``audio_num_mel_bins``, ``diff_decoder_type``, ``timesteps``,
        ``diff_loss_type``, ``spec_min`` and ``spec_max`` from ``hparams``.

        Raises:
            KeyError: if ``hparams['diff_decoder_type']`` is not a key of
                ``DIFF_DECODERS`` (or any required hparam is missing).
        """
        mel_bins = hparams['audio_num_mel_bins']
        self.model = GaussianDiffusion(
            phone_encoder=self.phone_encoder,
            out_dims=mel_bins,
            denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )

    def run_model(self, model, sample, return_output=False, infer=False):
        """Run ``model`` on one batch and collect the loss terms.

        Args:
            model: the acoustic model to call (normally ``self.model``).
            sample: batch dict. Always reads ``txt_tokens``, ``mels``,
                ``mel2ph``, ``f0``, ``uv`` and ``energy``. Reads ``spk_ids``
                when ``hparams['use_spk_id']`` is set, otherwise
                ``spk_embed`` (both via ``.get``, so they may be absent).
                When ``hparams['pitch_type'] == 'cwt'`` it also reads
                ``cwt_spec``, ``f0_mean`` and ``f0_std`` and WRITES the
                reconstructed f0 back into ``sample['f0_cwt']``.
            return_output: if True, also return the raw model output dict.
            infer: forwarded unchanged to ``model``.

        Returns:
            ``losses`` (dict of loss name -> value), or ``(losses, output)``
            when ``return_output`` is True.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        target = sample['mels']  # [B, T_s, 80]
        mel2ph = sample['mel2ph']  # [B, T_s]
        f0 = sample['f0']
        uv = sample['uv']
        energy = sample['energy']
        # Speaker conditioning: integer speaker ids or a precomputed embedding,
        # depending on configuration.
        if hparams['use_spk_id']:
            spk_embed = sample.get('spk_ids')
        else:
            spk_embed = sample.get('spk_embed')
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            # Replace f0 with the value reconstructed from the CWT spectrum and
            # stash it on the batch (presumably for the inherited pitch loss —
            # TODO confirm in FastSpeech2Task.add_pitch_loss).
            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)

        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)

        losses = {}
        # The diffusion loss is only reported when the model emits it
        # (it is checked for rather than assumed, e.g. it may be absent on
        # an infer=True pass — NOTE(review): confirm in GaussianDiffusion).
        if 'diff_loss' in output:
            losses['mel'] = output['diff_loss']
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output['energy_pred'], energy, losses)
        if not return_output:
            return losses
        return losses, output

    def _training_step(self, sample, batch_idx, _):
        """Compute the loss to back-propagate for one training batch.

        Returns:
            ``(total_loss, log_outputs)``. ``total_loss`` is the sum of the
            loss entries that are tensors requiring grad; ``log_outputs`` is
            the loss dict plus ``batch_size`` and the current ``lr``.
        """
        log_outputs = self.run_model(self.model, sample)
        total_loss = sum(
            v for v in log_outputs.values()
            if isinstance(v, torch.Tensor) and v.requires_grad
        )
        log_outputs['batch_size'] = sample['txt_tokens'].size()[0]
        # get_last_lr() is the lr the scheduler actually applied on its last
        # step(). get_lr() is the internal update rule: called directly it
        # emits a UserWarning every step, and for StepLR it reports one extra
        # gamma decay whenever last_epoch is a multiple of step_size.
        log_outputs['lr'] = self.scheduler.get_last_lr()[0]
        return total_loss, log_outputs

    def validation_step(self, sample, batch_idx):
        """Compute validation losses for one batch and optionally plot mels.

        Losses come from an ``infer=False`` pass. For the first
        ``hparams['num_valid_plots']`` batches an extra ``infer=True`` pass is
        run and its ``mel_out`` is plotted against ``sample['mels']``.

        Returns:
            Dict with ``losses``, ``total_loss`` and ``nsamples``, converted by
            ``utils.tensors_to_scalars``.
        """
        losses = self.run_model(self.model, sample, infer=False)
        outputs = {
            'losses': losses,
            'total_loss': sum(losses.values()),
            'nsamples': sample['nsamples'],
        }
        outputs = utils.tensors_to_scalars(outputs)
        if batch_idx < hparams['num_valid_plots']:
            _, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
        return outputs

    def build_scheduler(self, optimizer):
        """Return a StepLR that halves the lr every ``hparams['decay_steps']`` scheduler steps."""
        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Apply one optimizer update, clear gradients and advance the lr schedule.

        Does nothing when ``optimizer`` is None. The scheduler's counter is set
        explicitly to ``self.global_step // hparams['accumulate_grad_batches']``
        rather than incremented, so the schedule tracks the trainer's step count.
        """
        if optimizer is None:
            return
        optimizer.step()
        optimizer.zero_grad()
        if self.scheduler is not None:
            # NOTE(review): passing an explicit epoch to scheduler.step() is
            # deprecated in PyTorch (it emits a warning); kept because it pins
            # the StepLR closed-form lr to the global step.
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])