import numpy as np
import torch
import torch.distributions as dist
from torch import nn

from modules.commons.conv import ConditionalConvBlocks
from modules.commons.normalizing_flow.res_flow import ResFlow
from modules.commons.wavenet import WN


class FVAEEncoder(nn.Module):
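    """Posterior encoder q(z | x, cond).

    Downsamples the input by np.prod(strides) with strided convolutions,
    applies a conditional WaveNet/conv stack, and projects to the mean and
    log-std of a diagonal Gaussian over the latent.
    """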
    def __init__(self, c_in, hidden_size, c_latent, kernel_size,
                 n_layers, c_cond=0, p_dropout=0, strides=[4], nn_type='wn'):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_size
        if np.prod(strides) == 1:
            self.pre_net = nn.Conv1d(c_in, hidden_size, kernel_size=1)
        else:
            # Strided conv stack; total downsampling factor is np.prod(strides).
            self.pre_net = nn.Sequential(*[
                nn.Conv1d(c_in, hidden_size, kernel_size=s * 2, stride=s, padding=s // 2)
                if i == 0 else
                nn.Conv1d(hidden_size, hidden_size, kernel_size=s * 2, stride=s, padding=s // 2)
                for i, s in enumerate(strides)
            ])
        if nn_type == 'wn':
            self.nn = WN(hidden_size, kernel_size, 1, n_layers, c_cond, p_dropout)
        elif nn_type == 'conv':
            self.nn = ConditionalConvBlocks(
                hidden_size, c_cond, hidden_size, None, kernel_size,
                layers_in_block=2, is_BTC=False, num_layers=n_layers)
        else:
            raise ValueError(f"Unknown nn_type: {nn_type}")
        # Twice the latent size: mean and log-std of the posterior Gaussian.
        self.out_proj = nn.Conv1d(hidden_size, c_latent * 2, 1)
        self.latent_channels = c_latent

    def forward(self, x, nonpadding, cond):
        x = self.pre_net(x)
        # Downsample the mask to the strided frame rate and trim to the conv output length.
        nonpadding = nonpadding[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
        x = x * nonpadding
        x = self.nn(x, nonpadding=nonpadding, cond=cond) * nonpadding
        x = self.out_proj(x)
        m, logs = torch.split(x, self.latent_channels, dim=1)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        z = m + torch.randn_like(m) * torch.exp(logs)
        return z, m, logs, nonpadding


class FVAEDecoder(nn.Module):
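    """Decoder p(x | z, cond).

    Upsamples the latent back to the input frame rate with transposed
    convolutions, then applies a conditional WaveNet/conv stack.
    """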
    def __init__(self, c_latent, hidden_size, out_channels, kernel_size,
                 n_layers, c_cond=0, p_dropout=0, strides=[4], nn_type='wn'):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_size
        # Transposed-conv stack; total upsampling factor is np.prod(strides).
        self.pre_net = nn.Sequential(*[
            nn.ConvTranspose1d(c_latent, hidden_size, kernel_size=s, stride=s)
            if i == 0 else
            nn.ConvTranspose1d(hidden_size, hidden_size, kernel_size=s, stride=s)
            for i, s in enumerate(strides)
        ])
        if nn_type == 'wn':
            self.nn = WN(hidden_size, kernel_size, 1, n_layers, c_cond, p_dropout)
        elif nn_type == 'conv':
            self.nn = ConditionalConvBlocks(
                hidden_size, c_cond, hidden_size, [1] * n_layers, kernel_size,
                layers_in_block=2, is_BTC=False)
        else:
            raise ValueError(f"Unknown nn_type: {nn_type}")
        self.out_proj = nn.Conv1d(hidden_size, out_channels, 1)

    def forward(self, x, nonpadding, cond):
        x = self.pre_net(x)  # [B, C_latent, T'] -> [B, H, T' * prod(strides)]
        x = x * nonpadding
        x = self.nn(x, nonpadding=nonpadding, cond=cond) * nonpadding
        x = self.out_proj(x)
        return x


class FVAE(nn.Module):
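    """VAE with an optional normalizing-flow-enhanced prior.

    Latents live at a frame rate reduced by np.prod(strides). Note that
    forward() does not decode: it returns the latent (and the KL loss when
    training), and the caller is expected to run self.decoder.
    """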
    def __init__(self,
                 c_in_out, hidden_size, c_latent,
                 kernel_size, enc_n_layers, dec_n_layers, c_cond, strides,
                 use_prior_flow, flow_hidden=None, flow_kernel_size=None, flow_n_steps=None,
                 encoder_type='wn', decoder_type='wn'):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_size
        self.latent_size = c_latent
        self.use_prior_flow = use_prior_flow
        # Downsample the condition to the latent frame rate so it aligns with z.
        if np.prod(strides) == 1:
            self.g_pre_net = nn.Conv1d(c_cond, c_cond, kernel_size=1)
        else:
            self.g_pre_net = nn.Sequential(*[
                nn.Conv1d(c_cond, c_cond, kernel_size=s * 2, stride=s, padding=s // 2)
                for s in strides
            ])
        self.encoder = FVAEEncoder(c_in_out, hidden_size, c_latent, kernel_size,
                                   enc_n_layers, c_cond, strides=strides, nn_type=encoder_type)
        if use_prior_flow:
            self.prior_flow = ResFlow(
                c_latent, flow_hidden, flow_kernel_size, flow_n_steps, 4, c_cond=c_cond)
        self.decoder = FVAEDecoder(c_latent, hidden_size, c_in_out, kernel_size,
                                   dec_n_layers, c_cond, strides=strides, nn_type=decoder_type)
        self.prior_dist = dist.Normal(0, 1)

    def forward(self, x=None, nonpadding=None, cond=None, infer=False, noise_scale=1.0):
        """
        :param x: [B, C_in_out, T], required when infer=False
        :param nonpadding: [B, 1, T], 1 for valid frames, 0 for padding
        :param cond: [B, C_g, T]
        :param infer: if True, sample the latent from the prior instead of encoding x
        :param noise_scale: std multiplier for the prior sample at inference
        :return: (z_q, loss_kl, z_p, m_q, logs_q) when infer=False, else z_p
        """
        if nonpadding is None:
            nonpadding = 1  # scalar mask; only safe at inference, since the encoder slices the mask tensor
        cond_sqz = self.g_pre_net(cond)
        if not infer:
            z_q, m_q, logs_q, nonpadding_sqz = self.encoder(x, nonpadding, cond_sqz)
            q_dist = dist.Normal(m_q, logs_q.exp())
            if self.use_prior_flow:
                # Single-sample Monte-Carlo KL with a flow-enhanced prior:
                # KL ~= log q(z_q) - log p(f(z_q)). No log-det Jacobian term is
                # added, so this is exact only for a volume-preserving flow.
                logqx = q_dist.log_prob(z_q)
                z_p = self.prior_flow(z_q, nonpadding_sqz, cond_sqz)
                logpx = self.prior_dist.log_prob(z_p)
                loss_kl = ((logqx - logpx) * nonpadding_sqz).sum() / nonpadding_sqz.sum() / logqx.shape[1]
            else:
                # Closed-form KL between the diagonal-Gaussian posterior and N(0, I),
                # averaged over valid frames and latent channels.
                loss_kl = torch.distributions.kl_divergence(q_dist, self.prior_dist)
                loss_kl = (loss_kl * nonpadding_sqz).sum() / nonpadding_sqz.sum() / z_q.shape[1]
                z_p = None
            return z_q, loss_kl, z_p, m_q, logs_q
        else:
            # Sample z_p ~ N(0, noise_scale^2 I) at the squeezed frame rate, then
            # optionally invert the prior flow to map it into the latent space.
            latent_shape = [cond_sqz.shape[0], self.latent_size, cond_sqz.shape[2]]
            z_p = torch.randn(latent_shape).to(cond.device) * noise_scale
            if self.use_prior_flow:
                z_p = self.prior_flow(z_p, 1, cond_sqz, reverse=True)
            return z_p
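

if __name__ == '__main__':
    # Minimal smoke test, a sketch only: it assumes the repo's WN / ResFlow /
    # ConditionalConvBlocks modules are importable and that T is divisible by
    # np.prod(strides). All sizes below are arbitrary illustration values.
    B, C, C_cond, T = 2, 80, 64, 32
    fvae = FVAE(c_in_out=C, hidden_size=64, c_latent=16, kernel_size=3,
                enc_n_layers=4, dec_n_layers=4, c_cond=C_cond, strides=[4],
                use_prior_flow=True, flow_hidden=32, flow_kernel_size=3, flow_n_steps=4)
    x = torch.randn(B, C, T)
    cond = torch.randn(B, C_cond, T)
    nonpadding = torch.ones(B, 1, T)
    # Training path: posterior latent plus KL against the flow-enhanced prior.
    z_q, loss_kl, z_p, m_q, logs_q = fvae(x, nonpadding=nonpadding, cond=cond, infer=False)
    # forward() does not decode; the caller runs the decoder at the full frame rate.
    x_recon = fvae.decoder(z_q, nonpadding=nonpadding, cond=cond)
    # Inference path: sample from the prior and decode.
    z_gen = fvae(cond=cond, infer=True, noise_scale=0.8)
    x_gen = fvae.decoder(z_gen, nonpadding=nonpadding, cond=cond)
    print(z_q.shape, x_recon.shape, x_gen.shape, float(loss_kl))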