from typing import List, Tuple, Union

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F

from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
from TTS.tts.layers.delightful_tts.networks import STL


def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
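    """Build a boolean padding mask from sequence lengths; ``True`` marks padded positions.

    Example (illustrative):
        >>> mask = get_mask_from_lengths(torch.tensor([3, 1]))  # [[False, False, False], [False, True, True]]
    """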
    batch_size = lengths.shape[0]
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
    return mask


def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
    return torch.ceil(lens / stride).int()


class ReferenceEncoder(nn.Module):
    """
    Reference encoder for the utterance-level and phoneme-level prosody encoders. It is
    made up of convolutional and RNN layers.

    Args:
        num_mels (int): Number of mel channels in the input spectrogram.
        ref_enc_filters (list[int]): List of channel sizes for encoder layers.
        ref_enc_size (int): Size of the kernel for the conv layers.
        ref_enc_strides (List[int]): List of strides to use for conv layers.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing the reference mel-spectrogram.
        - **lengths** (batch): Tensor containing the mel lengths.
    Returns:
        - **outputs** (batch, time, gru_size): Encoded sequence produced by the reference encoder.
        - **memory** (1, batch, gru_size): Final GRU hidden state.
        - **mask** (batch, time): Boolean mask of padded frames after downsampling.
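
    Example (illustrative configuration; the values shown are not taken from a released config,
    and the strides include the two stride-2 layers that ``forward`` assumes):
        >>> encoder = ReferenceEncoder(
        ...     num_mels=80,
        ...     ref_enc_filters=[32, 32, 64, 64, 128, 128],
        ...     ref_enc_size=3,
        ...     ref_enc_strides=[1, 2, 1, 2, 1],
        ...     ref_enc_gru_size=32,
        ... )
        >>> mels = torch.randn(2, 80, 120)
        >>> mel_lens = torch.tensor([120, 96])
        >>> outputs, memory, masks = encoder(mels, mel_lens)  # [2, 30, 32], [1, 2, 32], [2, 30]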
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
    ):
        super().__init__()

        self.n_mel_channels = num_mels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels] + ref_enc_filters
        strides = [1] + ref_enc_strides
        # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            )
        ]
        convs2 = [
            nn.Conv1d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=ref_enc_size,
                stride=strides[i],
                padding=ref_enc_size // 2,
            )
            for i in range(1, K)
        ]
        convs.extend(convs2)
        self.convs = nn.ModuleList(convs)

        self.norms = nn.ModuleList([nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)])

        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        inputs --- [N,  n_mels, timesteps]
        outputs --- [N, E//2]
        """

        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
        x = x.masked_fill(mel_masks, 0)
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = F.leaky_relu(x, 0.3)  # [N, C_out, T_out]
            x = norm(x)

        # Update the lengths for the two stride-2 convolutions expected in the reference encoder configuration.
        for _ in range(2):
            mel_lens = stride_lens(mel_lens)

        mel_masks = get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))
        x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False)

        self.gru.flatten_parameters()
        x, memory = self.gru(x)  # x --- packed GRU outputs, memory --- [1, N, gru_size]
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x, memory, mel_masks

    def calculate_channels(  # pylint: disable=no-self-use
        self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
    ) -> int:
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


class UtteranceLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        bottleneck_size_u: int,
        token_num: int,
    ):
        """
        Encoder to extract utterance-level prosody from the reference mel-spectrogram. It is made up of a
        reference encoder, linear projection and bottleneck layers, a style token layer, and dropout.

        Args:
            num_mels (int): Number of mel channels in the reference spectrogram.
            ref_enc_filters (list[int]): List of channel sizes for ref encoder layers.
            ref_enc_size (int): Size of the kernel for the ref encoder conv layers.
            ref_enc_strides (List[int]): List of strides to use for the ref encoder conv layers.
            ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
            dropout (float): Probability of dropout.
            n_hidden (int): Size of hidden layers.
            bottleneck_size_u (int): Size of the bottleneck layer.
            token_num (int): Number of style tokens in the style token layer.

        Inputs: inputs, lengths
            - **inputs** (batch, dim, time): Tensor containing the reference mel-spectrogram.
            - **lengths** (batch): Tensor containing the mel lengths.
        Returns:
            - **outputs** (batch, 1, bottleneck_size_u): Tensor produced by the utterance-level prosody encoder.
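
        Example (illustrative configuration; the values shown are not taken from a released config):
            >>> prosody_encoder = UtteranceLevelProsodyEncoder(
            ...     num_mels=80,
            ...     ref_enc_filters=[32, 32, 64, 64, 128, 128],
            ...     ref_enc_size=3,
            ...     ref_enc_strides=[1, 2, 1, 2, 1],
            ...     ref_enc_gru_size=32,
            ...     dropout=0.1,
            ...     n_hidden=128,
            ...     bottleneck_size_u=192,
            ...     token_num=32,
            ... )
            >>> mels = torch.randn(2, 80, 120)
            >>> mel_lens = torch.tensor([120, 96])
            >>> out = prosody_encoder(mels, mel_lens)  # [2, 1, 192]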
        """
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_u

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2)
        self.stl = STL(n_hidden=n_hidden, token_num=token_num)
        self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            mels: :math: `[B, C, T]`
            mel_lens: :math: `[B]`

        out --- [N, 1, bottleneck_size_u]
        """
        _, embedded_prosody, _ = self.encoder(mels, mel_lens)

        # Project the GRU state to the style-token query dimension (E//2)
        embedded_prosody = self.encoder_prj(embedded_prosody)

        # Style Token
        out = self.encoder_bottleneck(self.stl(embedded_prosody))
        out = self.dropout(out)

        # The style token layer returns a 4-D tensor; flatten it to [N, 1, bottleneck_size].
        out = out.view((-1, 1, out.shape[3]))
        return out


class PhonemeLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        n_heads: int,
        bottleneck_size_p: int,
    ):
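        """
        Encoder to extract phoneme-level prosody from the reference mel-spectrogram. It is made up of a
        reference encoder, a linear projection, multi-headed attention from the phoneme sequence over the
        reference frames, and a bottleneck layer.

        Args:
            num_mels (int): Number of mel channels in the reference spectrogram.
            ref_enc_filters (List[int]): List of channel sizes for the ref encoder conv layers.
            ref_enc_size (int): Size of the kernel for the ref encoder conv layers.
            ref_enc_strides (List[int]): List of strides to use for the ref encoder conv layers.
            ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
            dropout (float): Probability of dropout in the attention layer.
            n_hidden (int): Size of hidden layers.
            n_heads (int): Number of attention heads.
            bottleneck_size_p (int): Size of the bottleneck layer.
        """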
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_p

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden)
        self.attention = ConformerMultiHeadedSelfAttention(
            d_model=n_hidden,
            num_heads=n_heads,
            dropout_p=dropout,
        )
        self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        x --- [N, seq_len, encoder_embedding_dim]
        mels --- [N, Ty/r, n_mels*r], r=1
        out --- [N, seq_len, bottleneck_size]
        attn --- [N, seq_len, ref_len], Ty/r = ref_len
        """
        embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens)

        # Project the reference encoding to the attention dimension (n_hidden)
        embedded_prosody = self.encoder_prj(embedded_prosody)

        # Expand the mel padding mask to [N, 1, 1, T_ref] so it broadcasts over heads and query positions.
        attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1))
        x, _ = self.attention(
            query=x,
            key=embedded_prosody,
            value=embedded_prosody,
            mask=attn_mask,
            encoding=encoding,
        )
        x = self.encoder_bottleneck(x)
        x = x.masked_fill(src_mask.unsqueeze(-1), 0.0)
        return x