import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Download locations of the released checkpoints, keyed by the model name
# that ``_acoustic`` receives.
URLS = dict(
    [
        ("hubert-discrete", "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-discrete-d49e1c77.pt"),
        ("hubert-soft", "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-soft-0321fd7e.pt"),
    ]
)

class CustomLSTM(nn.Module):
    """Single-layer LSTM written with explicit per-time-step tensor ops.

    The gate blocks along the last dimension of ``W``, ``U`` and ``bias`` are
    ordered (input, forget, cell candidate, output), the same order
    ``nn.LSTM`` uses. However, the weights are stored transposed
    (``(in, 4 * hidden)``) and there is a single bias vector, so parameters
    are not directly interchangeable with an ``nn.LSTM`` state dict.

    Args:
        input_sz: number of input features.
        hidden_sz: number of hidden units.
    """

    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        # Uninitialised storage; filled by init_weights() below.
        self.W = nn.Parameter(torch.empty(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.empty(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.empty(hidden_sz * 4))
        self.init_weights()

    def init_weights(self):
        """Fill every parameter from U(-1/sqrt(hidden), 1/sqrt(hidden))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    def forward(self, x,
                init_states=None):
        """Run the LSTM over a batch-first sequence.

        Args:
            x: input of shape (batch, sequence, feature).
            init_states: optional ``(h_0, c_0)`` pair, each of shape
                (batch, hidden). Defaults to zeros with ``x``'s dtype/device.

        Returns:
            ``(hidden_seq, (h_t, c_t))``. ``hidden_seq`` has shape
            (batch, sequence, hidden); ``h_t``/``c_t`` are the final states.
            For an empty sequence, ``hidden_seq`` has shape (batch, 0, hidden)
            and the initial states are returned unchanged.
        """
        bs, seq_sz, _ = x.size()
        HS = self.hidden_size
        if init_states is None:
            h_t = x.new_zeros(bs, HS)
            c_t = x.new_zeros(bs, HS)
        else:
            h_t, c_t = init_states

        # The input projection does not depend on the recurrence, so do it for
        # the whole sequence in one matmul; only h_t @ U stays in the loop.
        x_proj = x @ self.W + self.bias

        hidden_seq = []
        for t in range(seq_sz):
            gates = x_proj[:, t, :] + h_t @ self.U
            i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)
            i_t = torch.sigmoid(i_t)  # input gate
            f_t = torch.sigmoid(f_t)  # forget gate
            g_t = torch.tanh(g_t)     # cell candidate
            o_t = torch.sigmoid(o_t)  # output gate
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)

        if hidden_seq:
            # Stacking on dim=1 yields (batch, sequence, hidden) directly.
            out = torch.stack(hidden_seq, dim=1)
        else:
            out = x.new_zeros(bs, 0, HS)
        return out, (h_t, c_t)

class AcousticModel(nn.Module):
    """Speaker-conditioned acoustic model mapping speech units to mel frames.

    The encoder output, of shape (batch, time, 512), is concatenated frame by
    frame with a per-utterance speaker embedding before decoding. The
    decoder's 1024-wide conditioning input implies a 512-dim speaker
    embedding.

    Args:
        discrete: use the embedding-based encoder for integer unit ids.
        upsample: enable the encoder's 2x transposed-conv upsampling.
        use_custom_lstm: build the decoder with ``CustomLSTM`` cells.
    """

    def __init__(self, discrete: bool = False, upsample: bool = True, use_custom_lstm: bool = False):
        super().__init__()
        self.encoder = Encoder(discrete, upsample)
        self.decoder = Decoder(use_custom_lstm=use_custom_lstm)

    def _condition(self, x: torch.Tensor, spk_embs: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and append the speaker embedding to every frame.

        ``spk_embs`` must be 2-D (batch, dim); it is broadcast along the time
        axis of the encoder output.
        """
        x = self.encoder(x)
        exp_spk_embs = spk_embs.unsqueeze(1).expand(-1, x.size(1), -1)
        return torch.cat([x, exp_spk_embs], dim=-1)

    def forward(self, x: torch.Tensor, spk_embs: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
        """Teacher-forced pass: predict mel frames given units, speaker and mels.

        Args:
            x: encoder input (unit ids or unit features; see ``Encoder``).
            spk_embs: speaker embeddings of shape (batch, dim).
            mels: mel frames fed to the decoder prenet, (batch, time, 128).

        Returns:
            Predicted mel frames of shape (batch, time, 128).
        """
        return self.decoder(self._condition(x, spk_embs), mels)

    def forward_test(self, x, spk_embs, mels):
        """Debug helper: print input and encoder-output shapes; returns None."""
        print('x shape', x.shape)
        print('se shape', spk_embs.shape)
        print('mels shape', mels.shape)
        x = self.encoder(x)
        print('x_enc shape', x.shape)
        return

    @torch.inference_mode()
    def generate(self, x: torch.Tensor, spk_embs: torch.Tensor) -> torch.Tensor:
        """Autoregressively synthesise mel frames for units ``x`` and a speaker.

        Returns:
            Mel frames of shape (batch, time, 128).
        """
        return self.decoder.generate(self._condition(x, spk_embs))


class Encoder(nn.Module):
    """Speech-unit encoder: optional unit embedding, PreNet, then a conv stack.

    With ``discrete`` the input is a (batch, time) tensor of integer ids,
    looked up in a 101-entry, 256-dim embedding. Otherwise it is
    (batch, time, 256) features. The output is (batch, time', 512), where
    time' == 2 * time when ``upsample`` is on (the stride-2 transposed conv)
    and time' == time otherwise.
    """

    def __init__(self, discrete: bool = False, upsample: bool = True):
        super().__init__()
        if discrete:
            self.embedding = nn.Embedding(100 + 1, 256)
        else:
            self.embedding = None
        self.prenet = PreNet(256, 256, 256)

        # Keep the layer order fixed: Sequential indices form state-dict keys.
        layers = [nn.Conv1d(256, 512, 5, 1, 2), nn.ReLU(), nn.InstanceNorm1d(512)]
        layers.append(nn.ConvTranspose1d(512, 512, 4, 2, 1) if upsample else nn.Identity())
        for _ in range(2):
            layers += [nn.Conv1d(512, 512, 5, 1, 2), nn.ReLU(), nn.InstanceNorm1d(512)]
        self.convs = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into (batch, time', 512) features."""
        if self.embedding is not None:
            x = self.embedding(x)
        # Conv1d wants channels before time, so swap in and swap back out.
        hidden = self.prenet(x).transpose(1, 2)
        hidden = self.convs(hidden)
        return hidden.transpose(1, 2)


class Decoder(nn.Module):
    """Autoregressive mel decoder conditioned on encoder + speaker features.

    At each step, a prenet-processed 128-bin mel frame (256 features) and a
    1024-wide conditioning frame pass through three single-layer, 768-unit
    LSTMs. Residual connections wrap the 2nd and 3rd LSTMs, and the result
    is projected to 128 mel bins.

    Args:
        use_custom_lstm: use ``CustomLSTM`` instead of ``nn.LSTM``. The two
            store parameters under different names/layouts, so a state dict
            saved in one mode does not load in the other.
    """

    def __init__(self, use_custom_lstm=False):
        super().__init__()
        self.use_custom_lstm = use_custom_lstm
        self.prenet = PreNet(128, 256, 256)
        if use_custom_lstm:
            self.lstm1 = CustomLSTM(1024 + 256, 768)
            self.lstm2 = CustomLSTM(768, 768)
            self.lstm3 = CustomLSTM(768, 768)
        else:
            # batch_first=True makes nn.LSTM read (batch, time, feature), as
            # CustomLSTM does. Without it, forward() would run the recurrence
            # over the batch axis and generate() would fail for batch > 1.
            self.lstm1 = nn.LSTM(1024 + 256, 768, batch_first=True)
            self.lstm2 = nn.LSTM(768, 768, batch_first=True)
            self.lstm3 = nn.LSTM(768, 768, batch_first=True)
        self.proj = nn.Linear(768, 128, bias=False)

    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
        """Teacher-forced decoding over a whole sequence.

        Args:
            x: conditioning frames of shape (batch, time, 1024).
            mels: mel frames of shape (batch, time, 128) used as the
                autoregressive input (presumably already shifted by one frame
                by the caller -- confirm against the training loop).

        Returns:
            Predicted mel frames of shape (batch, time, 128).
        """
        mels = self.prenet(mels)
        x, _ = self.lstm1(torch.cat((x, mels), dim=-1))
        res = x
        x, _ = self.lstm2(x)
        x = res + x
        res = x
        x, _ = self.lstm3(x)
        x = res + x
        return self.proj(x)

    @torch.inference_mode()
    def generate(self, xs: torch.Tensor) -> torch.Tensor:
        """Decode step by step, feeding each predicted frame back in.

        Args:
            xs: conditioning frames of shape (batch, time, 1024).

        Returns:
            Predicted mel frames of shape (batch, time, 128); (batch, 0, 128)
            when ``xs`` has no time steps.
        """
        batch = xs.size(0)
        if xs.size(1) == 0:
            return torch.zeros(batch, 0, 128, device=xs.device)

        m = torch.zeros(batch, 128, device=xs.device)
        # nn.LSTM states are (num_layers, batch, hidden); CustomLSTM's are
        # (batch, hidden).
        state_shape = (batch, 768) if self.use_custom_lstm else (1, batch, 768)
        h1, c1, h2, c2, h3, c3 = (
            torch.zeros(state_shape, device=xs.device) for _ in range(6)
        )

        mel = []
        for x in torch.unbind(xs, dim=1):
            m = self.prenet(m)
            # One time step: (batch, 1, 1024 + 256).
            x = torch.cat((x, m), dim=1).unsqueeze(1)
            x1, (h1, c1) = self.lstm1(x, (h1, c1))
            x2, (h2, c2) = self.lstm2(x1, (h2, c2))
            x = x1 + x2
            x3, (h3, c3) = self.lstm3(x, (h3, c3))
            x = x + x3
            m = self.proj(x).squeeze(1)
            mel.append(m)
        return torch.stack(mel, dim=1)


class PreNet(nn.Module):
    """Two-layer MLP: (Linear -> ReLU -> Dropout) applied twice.

    Args:
        input_size: features in the input's last dimension.
        hidden_size: width of the intermediate layer.
        output_size: features in the output's last dimension.
        dropout: drop probability used after each ReLU (``nn.Dropout``, so
            only active in training mode).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.5,
    ):
        super().__init__()
        layers = []
        for n_in, n_out in ((input_size, hidden_size), (hidden_size, output_size)):
            layers.extend((nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(dropout)))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP along the last dimension of ``x``."""
        out = self.net(x)
        return out


def _acoustic(
    name: str,
    discrete: bool,
    upsample: bool,
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    """Build an ``AcousticModel`` and optionally load released weights.

    Args:
        name: key into ``URLS`` selecting the checkpoint to download.
        discrete: build the discrete-unit (embedding) encoder variant.
        upsample: include the transposed-conv upsampling layer.
        pretrained: download the checkpoint, load its ``"acoustic-model"``
            state dict, and switch the model to eval mode.
        progress: show a progress bar while downloading.

    Raises:
        KeyError: if ``pretrained`` and ``name`` is not in ``URLS``, or the
            checkpoint has no ``"acoustic-model"`` entry.

    NOTE(review): the decoder here takes speaker-conditioned input
    (``lstm1`` input width 1024 + 256). Confirm the released checkpoints
    were trained with that shape; otherwise ``load_state_dict`` raises a
    size-mismatch error.
    """
    acoustic = AcousticModel(discrete, upsample)
    if pretrained:
        # The freshly built model lives on the CPU, so load the tensors there
        # directly. This also works on hosts lacking the device the checkpoint
        # was saved from.
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS[name], map_location="cpu", progress=progress
        )
        state_dict = checkpoint["acoustic-model"]
        # Strip a leading "module." from keys if present (modifies in place).
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        acoustic.load_state_dict(state_dict)
        acoustic.eval()
    return acoustic


def hubert_discrete(
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    r"""Build the HuBERT-Discrete acoustic model.

    From "A Comparison of Discrete and Soft Speech Units for Improved Voice
    Conversion". The encoder embeds integer unit ids and upsamples.

    Args:
        pretrained (bool): download and load the released weights.
        progress (bool): display a progress bar during the download.
    """
    variant = dict(discrete=True, upsample=True)
    return _acoustic(
        "hubert-discrete",
        pretrained=pretrained,
        progress=progress,
        **variant,
    )


def hubert_soft(
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    r"""Build the HuBERT-Soft acoustic model.

    From "A Comparison of Discrete and Soft Speech Units for Improved Voice
    Conversion". The encoder consumes continuous (soft) unit features and
    upsamples.

    Args:
        pretrained (bool): download and load the released weights.
        progress (bool): display a progress bar during the download.
    """
    variant = dict(discrete=False, upsample=True)
    return _acoustic(
        "hubert-soft",
        pretrained=pretrained,
        progress=progress,
        **variant,
    )