#!/usr/bin/env python3

# Copyright 2020 Songxiang Liu
# Apache 2.0

from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from .utils.abs_model import AbsMelDecoder
from .rnn_decoder_mol import Decoder
from .utils.cnn_postnet import Postnet
from .utils.vc_utils import get_mask_from_lengths

from utils.load_yaml import HpsYaml

class MelDecoderMOLv2(AbsMelDecoder):
    """Convert bottleneck features (e.g. PPGs) plus pitch into mel-spectrograms.

    Pipeline:
      1. ``bnf_prenet`` (bottleneck features) and ``pitch_convs`` (2-channel
         pitch features) each apply a 1x1 conv followed by one strided conv per
         entry of ``encoder_downsample_rates``, downsampling time by
         (approximately) ``encoder_down_factor``. The two streams are summed.
      2. An L2-normalised speaker embedding is broadcast over time,
         concatenated, and projected back to ``encoder_dim`` by ``reduce_proj``.
      3. ``Decoder`` (from ``rnn_decoder_mol``) generates ``num_mels`` (80)
         mel bins.
      4. ``Postnet`` predicts a residual that is added to the decoder output.
    """

    def __init__(
        self,
        num_speakers: int,
        spk_embed_dim: int,
        bottle_neck_feature_dim: int,
        encoder_dim: int = 256,
        encoder_downsample_rates: Optional[List[int]] = None,
        attention_rnn_dim: int = 512,
        decoder_rnn_dim: int = 512,
        num_decoder_rnn_layer: int = 1,
        concat_context_to_last: bool = True,
        prenet_dims: Optional[List[int]] = None,
        num_mixtures: int = 5,
        frames_per_step: int = 2,
        mask_padding: bool = True,
    ):
        """Build the encoders, speaker fusion, decoder and postnet.

        Args:
            num_speakers: not used by this class; kept for config compatibility.
            spk_embed_dim: size of the speaker embedding passed as ``spembs``.
            bottle_neck_feature_dim: channel size of the bottleneck features.
            encoder_dim: channel size of both encoders and of ``reduce_proj``.
            encoder_downsample_rates: stride of each downsampling conv;
                defaults to ``[2, 2]``.
            prenet_dims: decoder prenet sizes; defaults to ``[256, 128]``.
            mask_padding: zero mel frames beyond ``speech_lengths`` in ``forward``.
            Remaining arguments are forwarded to ``Decoder``.
        """
        super().__init__()

        # None sentinels instead of shared mutable list defaults.
        if encoder_downsample_rates is None:
            encoder_downsample_rates = [2, 2]
        if prenet_dims is None:
            prenet_dims = [256, 128]

        self.mask_padding = mask_padding
        self.bottle_neck_feature_dim = bottle_neck_feature_dim
        self.num_mels = 80
        # Total temporal downsampling of the encoders (product of the strides).
        self.encoder_down_factor = int(np.prod(encoder_downsample_rates))
        self.frames_per_step = frames_per_step
        self.use_spk_dvec = True

        # Bottleneck-feature encoder and pitch encoder share one architecture.
        self.bnf_prenet = self._build_downsample_encoder(
            bottle_neck_feature_dim, encoder_dim, encoder_downsample_rates)
        self.pitch_convs = self._build_downsample_encoder(
            2, encoder_dim, encoder_downsample_rates)

        # Fuse the time-broadcast speaker embedding back down to encoder_dim.
        self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)

        self.decoder = Decoder(
            enc_dim=encoder_dim,
            num_mels=self.num_mels,
            frames_per_step=frames_per_step,
            attention_rnn_dim=attention_rnn_dim,
            decoder_rnn_dim=decoder_rnn_dim,
            num_decoder_rnn_layer=num_decoder_rnn_layer,
            prenet_dims=list(prenet_dims),  # copy: do not alias the caller's list
            num_mixtures=num_mixtures,
            use_stop_tokens=True,
            concat_context_to_last=concat_context_to_last,
            encoder_down_factor=self.encoder_down_factor,
        )

        # Mel-spectrogram postnet: residual CNN refinement.
        self.postnet = Postnet()

    @staticmethod
    def _build_downsample_encoder(in_dim, encoder_dim, downsample_rates):
        """Return a 1x1 projection followed by one strided conv per rate.

        Each stage is Conv1d -> LeakyReLU(0.1) -> InstanceNorm1d. For the
        default two rates the layer order, and therefore the ``state_dict``
        key indices, match the original hand-written stack, so existing
        checkpoints still load.
        """
        layers = [
            torch.nn.Conv1d(in_dim, encoder_dim, kernel_size=1, bias=False),
            torch.nn.LeakyReLU(0.1),
            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
        ]
        for rate in downsample_rates:
            layers += [
                torch.nn.Conv1d(
                    encoder_dim, encoder_dim,
                    kernel_size=2 * rate,
                    stride=rate,
                    padding=rate // 2,
                ),
                torch.nn.LeakyReLU(0.1),
                torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            ]
        return torch.nn.Sequential(*layers)

    def _encode_inputs(self, bottle_neck_features, logf0_uv, spembs):
        """Encode features + pitch and fuse the speaker embedding.

        Args:
            bottle_neck_features: (B, T, bottle_neck_feature_dim).
            logf0_uv: (B, T, 2) pitch features (presumably log-F0 and a
                voiced/unvoiced flag -- confirm with the feature extractor).
            spembs: (B, spk_embed_dim) speaker embedding.

        Returns:
            (B, T_enc, encoder_dim) tensor, where T_enc is the downsampled length.

        Raises:
            ValueError: if ``logf0_uv`` or ``spembs`` is None.
        """
        if logf0_uv is None:
            raise ValueError("logf0_uv (pitch features) is required")
        if spembs is None:
            raise ValueError("spembs (speaker embedding) is required")

        # Conv1d works on (B, C, T), hence the transposes around each encoder.
        encoded = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
        encoded = encoded + self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)

        spk_embeds = F.normalize(spembs).unsqueeze(1).expand(-1, encoded.size(1), -1)
        return self.reduce_proj(torch.cat([encoded, spk_embeds], dim=-1))

    def _apply_postnet(self, mel_outputs):
        """Add the postnet's residual prediction to (B, T, num_mels) mels."""
        residual = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
        return mel_outputs + residual

    def parse_output(self, outputs, output_lengths=None):
        """Zero padded frames of the two mel outputs in place.

        Only applies when ``mask_padding`` is set and ``output_lengths`` is
        given. ``outputs[0]`` and ``outputs[1]`` are (B, T, num_mels); other
        entries are left untouched. Returns ``outputs``.
        """
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
            mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
        return outputs

    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        output_att_ws: bool = False,
    ):
        """Teacher-forced decoding for training.

        Returns:
            ``[mel_outputs, mel_outputs_postnet, predicted_stop]``, plus
            ``alignments`` as a fourth element when ``output_att_ws`` is True.
            Mel frames beyond ``speech_lengths`` are zeroed (see
            ``parse_output``).
        """
        decoder_inputs = self._encode_inputs(bottle_neck_features, logf0_uv, spembs)

        # Per-utterance encoder lengths after temporal downsampling, passed to
        # the decoder alongside the encoder outputs.
        T_dec = torch.div(
            feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
        mel_outputs, predicted_stop, alignments = self.decoder(
            decoder_inputs, speech, T_dec)

        mel_outputs_postnet = self._apply_postnet(mel_outputs)
        outputs = [mel_outputs, mel_outputs_postnet, predicted_stop]
        if output_att_ws:
            outputs.append(alignments)
        return self.parse_output(outputs, speech_lengths)

    def inference(
        self,
        bottle_neck_features: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
    ):
        """Free-running decoding.

        Returns:
            ``(mel_outputs, mel_outputs_postnet, alignments)`` for the first
            utterance in the batch.
        """
        encoder_outputs = self._encode_inputs(bottle_neck_features, logf0_uv, spembs)

        if encoder_outputs.size(0) > 1:
            mel_outputs, alignments = self.decoder.inference_batched(encoder_outputs)
        else:
            mel_outputs, alignments = self.decoder.inference(encoder_outputs)

        mel_outputs_postnet = self._apply_postnet(mel_outputs)
        # NOTE(review): only the first batch item is returned, even when the
        # batched decoding path ran over several utterances.
        return mel_outputs[0], mel_outputs_postnet[0], alignments[0]

def load_model(model_file, device=None):
    """Build a ``MelDecoderMOLv2`` from a YAML config and load its checkpoint.

    The config is the first ``*.yaml`` file, in sorted path order, found
    recursively under the checkpoint's parent directory. Its ``"model"``
    section supplies the constructor keyword arguments.

    Args:
        model_file: path (``str`` or ``pathlib.Path``) to a checkpoint whose
            ``"model"`` entry is the model's state dict.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if no YAML config is found next to the checkpoint.
    """
    model_file = Path(model_file)
    # Sorted so the selected config does not depend on filesystem order.
    model_config_fpaths = sorted(model_file.parent.rglob("*.yaml"))
    if not model_config_fpaths:
        raise FileNotFoundError(
            f"No model yaml config found for convertor under {model_file.parent}"
        )
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = HpsYaml(model_config_fpaths[0])
    ppg2mel_model = MelDecoderMOLv2(**model_config["model"]).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(model_file, map_location=device)
    ppg2mel_model.load_state_dict(ckpt["model"])
    ppg2mel_model.eval()
    return ppg2mel_model