import torch
import torch.nn.functional as F
from modules.tts.fs2_orig import FastSpeech2Orig
from tasks.tts.dataset_utils import FastSpeechDataset
from tasks.tts.fs import FastSpeechTask
from utils.commons.dataset_utils import collate_1d, collate_2d
from utils.commons.hparams import hparams
from utils.plot.plot import spec_to_figure
import numpy as np


class FastSpeech2OrigDataset(FastSpeechDataset):
    """FastSpeech dataset that adds an energy track and, for CWT pitch, CWT targets.

    On top of whatever the parent dataset yields, every sample gains an
    ``energy`` entry computed from its mel. When pitch embedding is enabled
    and ``pitch_type == 'cwt'``, samples additionally carry ``cwt_spec``,
    ``f0_mean`` and ``f0_std`` read from the raw item.
    """

    def __init__(self, prefix, shuffle=False, items=None, data_dir=None):
        """Forward all arguments to the parent dataset and cache the pitch type."""
        super().__init__(prefix, shuffle, items, data_dir)
        # None when the option is absent from hparams.
        self.pitch_type = hparams.get('pitch_type')

    def _use_cwt_pitch(self):
        """Return True when samples and batches should carry the CWT pitch targets.

        Shared by ``__getitem__`` and ``collater`` so that the collater never
        asks for keys that ``__getitem__`` did not produce.
        """
        # NOTE(review): self.hparams is presumably set by the parent dataset -- confirm.
        return bool(self.hparams['use_pitch_embed']) and self.pitch_type == 'cwt'

    def __getitem__(self, index):
        """Return the parent's sample for ``index`` extended with energy / CWT fields.

        Added keys:
            energy: square root of the sum over the last mel axis of
                ``exp(mel) ** 2`` (one value per mel frame).
            cwt_spec, f0_mean, f0_std: only when CWT pitch is in use.
        """
        sample = super().__getitem__(index)
        mel = sample['mel']
        # L2 norm of exp(mel) over the last axis.
        # NOTE(review): presumably mel is a log-magnitude spectrogram -- confirm.
        sample['energy'] = (mel.exp() ** 2).sum(-1).sqrt()
        if self._use_cwt_pitch():
            # The raw item is only needed for the CWT fields, so fetch it lazily.
            item = self._get_item(index)
            n_frames = mel.shape[0]
            sample.update({
                # Truncate to the mel length so both share the first dimension.
                "cwt_spec": torch.Tensor(item['cwt_spec'])[:n_frames],
                # Fall back to the 'cwt_*' names when the 'f0_*' ones are absent.
                "f0_mean": item.get('f0_mean', item.get('cwt_mean')),
                "f0_std": item.get('f0_std', item.get('cwt_std')),
            })
        return sample

    def collater(self, samples):
        """Collate ``samples`` into a batch dict, adding energy and CWT tensors.

        Returns an empty dict for an empty sample list. ``batch['energy']`` is
        a zero-padded tensor when pitch or energy embedding is enabled and
        ``None`` otherwise. The CWT entries are added under the same condition
        that ``__getitem__`` uses to create them.
        """
        if len(samples) == 0:
            return {}
        batch = super().collater(samples)
        # Energy is needed by the energy loss as well, not only when pitch
        # embedding is on; gating on pitch alone left it None in that setup.
        if hparams['use_pitch_embed'] or hparams.get('use_energy_embed', False):
            energy = collate_1d([s['energy'] for s in samples], 0.0)
        else:
            energy = None
        batch.update({'energy': energy})
        if self._use_cwt_pitch():
            batch.update({
                'cwt_spec': collate_2d([s['cwt_spec'] for s in samples]),
                'f0_mean': torch.Tensor([s['f0_mean'] for s in samples]),
                'f0_std': torch.Tensor([s['f0_std'] for s in samples]),
            })
        return batch


class FastSpeech2OrigTask(FastSpeechTask):
    """Training/inference task wiring ``FastSpeech2Orig`` to ``FastSpeech2OrigDataset``.

    Adds CWT-based pitch losses, an energy loss and a CWT validation plot on
    top of the parent task.
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = FastSpeech2OrigDataset

    def build_tts_model(self):
        """Instantiate ``FastSpeech2Orig`` sized to the token encoder and store it on ``self.model``."""
        dict_size = len(self.token_encoder)
        self.model = FastSpeech2Orig(dict_size, hparams)

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run the model on a batch.

        Args:
            sample: batch dict; reads ``txt_tokens`` and, depending on the mode,
                ``mels``, ``mel2ph``, ``f0``, ``uv``, ``energy``, ``spk_embed``
                and ``spk_ids``.
            infer: when False, feed ground-truth duration/pitch/energy and
                compute losses; when True, feed only the ground-truth inputs
                that are explicitly enabled.
            **kwargs: ``infer_use_gt_dur`` / ``infer_use_gt_f0`` /
                ``infer_use_gt_energy`` override the matching ``use_gt_*``
                hparams in inference mode.

        Returns:
            ``(losses, output)`` when ``infer`` is False, otherwise ``output``.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')

        if infer:
            # Per-call kwargs take precedence over the configured defaults.
            use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
            use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
            use_gt_energy = kwargs.get('infer_use_gt_energy', hparams['use_gt_energy'])
            mel2ph = sample['mel2ph'] if use_gt_dur else None
            f0 = sample['f0'] if use_gt_f0 else None
            uv = sample['uv'] if use_gt_f0 else None
            energy = sample['energy'] if use_gt_energy else None
            return self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                              f0=f0, uv=uv, energy=energy, infer=True)

        target = sample['mels']  # [B, T_s, 80]
        mel2ph = sample['mel2ph']  # [B, T_s]
        output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                            f0=sample.get('f0'), uv=sample.get('uv'),
                            energy=sample.get('energy'), infer=False)
        losses = {}
        self.add_mel_loss(output['mel_out'], target, losses)
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output, sample, losses)
        return losses, output

    def save_valid_result(self, sample, batch_idx, model_out):
        """Save the parent's validation artefacts and, when available, a CWT plot."""
        super().save_valid_result(sample, batch_idx, model_out)
        # The model only emits 'cwt' for CWT pitch; skip the plot otherwise
        # instead of failing on a missing key.
        cwt_out = model_out.get('cwt')
        if cwt_out is not None:
            self.plot_cwt(batch_idx, cwt_out, sample.get('cwt_spec'))

    def plot_cwt(self, batch_idx, cwt_out, cwt_gt=None):
        """Log a figure of the predicted CWT, concatenated with the target if given.

        Args:
            batch_idx: used to build the figure name ``cwt_val_{batch_idx}``.
            cwt_out: predicted CWT; a 3-D input is reduced to its first element.
            cwt_gt: optional target CWT, treated the same way and appended
                along the last axis.
        """
        if len(cwt_out.shape) == 3:
            cwt_out = cwt_out[0]
        if isinstance(cwt_out, torch.Tensor):
            # detach() so tensors that still track gradients can be converted.
            cwt_out = cwt_out.detach().cpu().numpy()
        if cwt_gt is not None:
            if len(cwt_gt.shape) == 3:
                cwt_gt = cwt_gt[0]
            if isinstance(cwt_gt, torch.Tensor):
                cwt_gt = cwt_gt.detach().cpu().numpy()
            cwt_out = np.concatenate([cwt_out, cwt_gt], -1)
        name = f'cwt_val_{batch_idx}'
        self.logger.add_figure(name, spec_to_figure(cwt_out), self.global_step)

    def add_pitch_loss(self, output, sample, losses):
        """Add pitch losses to ``losses`` in place.

        For ``pitch_type == 'cwt'`` this writes ``C`` (L1 on the CWT
        spectrogram), optionally ``uv`` (masked BCE on the last output
        channel), and ``f0_mean`` / ``f0_std`` (L1). Any other pitch type is
        delegated to the parent implementation.
        """
        if hparams['pitch_type'] != 'cwt':
            super().add_pitch_loss(output, sample, losses)
            return

        cwt_spec = sample['cwt_spec']
        cwt_out = output['cwt']
        # Number of CWT channels comes from the target rather than a
        # hard-coded 10; any channel after those is the uv logit.
        n_scales = cwt_spec.shape[-1]
        losses['C'] = F.l1_loss(cwt_out[:, :, :n_scales], cwt_spec) * hparams['lambda_f0']
        if hparams['use_uv']:
            assert cwt_out.shape[-1] == n_scales + 1
            uv_pred = cwt_out[:, :, -1]
            # mel2ph == 0 marks padded frames; they are excluded from the mean.
            nonpadding = (sample['mel2ph'] != 0).float()
            uv_loss = F.binary_cross_entropy_with_logits(uv_pred, sample['uv'], reduction='none')
            # clamp keeps a fully padded batch at loss 0 instead of NaN.
            losses['uv'] = ((uv_loss * nonpadding).sum() / nonpadding.sum().clamp(min=1)
                            * hparams['lambda_uv'])
        losses['f0_mean'] = F.l1_loss(output['f0_mean'], sample['f0_mean']) * hparams['lambda_f0']
        losses['f0_std'] = F.l1_loss(output['f0_std'], sample['f0_std']) * hparams['lambda_f0']

    def add_energy_loss(self, output, sample, losses):
        """Add the masked MSE energy loss to ``losses`` under key ``'e'``.

        Positions where the target energy equals 0 are treated as padding and
        do not contribute to the mean.
        """
        energy_pred, energy = output['energy_pred'], sample['energy']
        nonpadding = (energy != 0).float()
        squared_err = F.mse_loss(energy_pred, energy, reduction='none')
        # clamp keeps an all-zero target from producing a 0/0 NaN.
        loss = (squared_err * nonpadding).sum() / nonpadding.sum().clamp(min=1)
        losses['e'] = loss * hparams['lambda_energy']