""" The code is based on https://github.com/apple/ml-gsn/ with adaption. """

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.torch_utils.ops.native_ops import (
    FusedLeakyReLU,
    fused_leaky_relu,
    upfirdn2d,
)


class DiscriminatorHead(nn.Module):
    def __init__(self, in_channel, disc_stddev=False):
        super().__init__()

        self.disc_stddev = disc_stddev
        stddev_dim = 1 if disc_stddev else 0

        self.conv_stddev = ConvLayer2d(
            in_channel=in_channel + stddev_dim,
            out_channel=in_channel,
            kernel_size=3,
            activate=True
        )

        self.final_linear = nn.Sequential(
            nn.Flatten(),
            EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True),
            EqualLinear(in_channel=in_channel, out_channel=1),
        )

    def cat_stddev(self, x, stddev_group=4, stddev_feat=1):
        perm = torch.randperm(len(x))
        inv_perm = torch.argsort(perm)

        batch, channel, height, width = x.shape
        x = x[perm
             ]    # shuffle inputs so that all views in a single trajectory don't get put together

        group = min(batch, stddev_group)
        stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)

        stddev = stddev[inv_perm]    # reorder inputs
        x = x[inv_perm]

        out = torch.cat([x, stddev], 1)
        return out

    def forward(self, x):
        if self.disc_stddev:
            x = self.cat_stddev(x)
        x = self.conv_stddev(x)
        out = self.final_linear(x)
        return out


class ConvDecoder(nn.Module):
    def __init__(self, in_channel, out_channel, in_res, out_res):
        super().__init__()

        log_size_in = int(math.log(in_res, 2))
        log_size_out = int(math.log(out_res, 2))

        self.layers = []
        in_ch = in_channel
        for i in range(log_size_in, log_size_out):
            out_ch = in_ch // 2
            self.layers.append(
                ConvLayer2d(
                    in_channel=in_ch,
                    out_channel=out_ch,
                    kernel_size=3,
                    upsample=True,
                    bias=True,
                    activate=True
                )
            )
            in_ch = out_ch

        self.layers.append(
            ConvLayer2d(
                in_channel=in_ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False
            )
        )
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.layers(x)


class StyleDiscriminator(nn.Module):
    def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
        super().__init__()

        log_size_in = int(math.log(in_res, 2))
        log_size_out = int(math.log(4, 2))

        self.conv_in = ConvLayer2d(in_channel=in_channel, out_channel=ch_mul, kernel_size=3)

        # each resblock will half the resolution and double the number of features (until a maximum of ch_max)
        self.layers = []
        in_channels = ch_mul
        for i in range(log_size_in, log_size_out, -1):
            out_channels = int(min(in_channels * 2, ch_max))
            self.layers.append(
                ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True)
            )
            in_channels = out_channels
        self.layers = nn.Sequential(*self.layers)

        self.disc_out = DiscriminatorHead(in_channel=in_channels, disc_stddev=True)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.layers(x)
        out = self.disc_out(x)
        return out


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


class Blur(nn.Module):
    """Blur layer.

    Applies a blur kernel to input image using finite impulse response filter. Blurring feature maps after
    convolutional upsampling or before convolutional downsampling helps produces models that are more robust to
    shifting inputs (https://richzhang.github.io/antialiased-cnns/). In the context of GANs, this can provide
    cleaner gradients, and therefore more stable training.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For exmaple: [1, 3, 3, 1].
    pad: tuple, int
        A tuple of integers representing the number of rows/columns of padding to be added to the top/left and
        the bottom/right respectively.
    upsample_factor: int
        Upsample factor.

    """
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor**2)

        self.register_buffer("kernel", kernel)
        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)
        return out


class Upsample(nn.Module):
    """Upsampling layer.

    Perform upsampling using a blur kernel.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For exmaple: [1, 3, 3, 1].
    factor: int
        Upsampling factor.

    """
    def __init__(self, kernel=[1, 3, 3, 1], factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor**2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
        return out


class Downsample(nn.Module):
    """Downsampling layer.

    Perform downsampling using a blur kernel.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For exmaple: [1, 3, 3, 1].
    factor: int
        Downsampling factor.

    """
    def __init__(self, kernel=[1, 3, 3, 1], factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
        return out


class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
    prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
    activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    bias: bool
        Use bias term.
    bias_init: float
        Initial value for the bias.
    lr_mul: float
        Learning rate multiplier. By scaling weights and the bias we can proportionally scale the magnitude of
        the gradients, effectively increasing/decreasing the learning rate for this layer.
    activate: bool
        Apply leakyReLU activation.

    """
    def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))
        else:
            self.bias = None

        self.activate = activate
        self.scale = (1 / math.sqrt(in_channel)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activate:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
        return out

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"


class EqualConv2d(nn.Module):
    """2D convolution layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
    prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
    activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    kernel_size: int
        Kernel size.
    stride: int
        Stride of convolutional kernel across the input.
    padding: int
        Amount of zero padding applied to both sides of the input.
    bias: bool
        Use bias term.

    """
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding
        )
        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )


class EqualConvTranspose2d(nn.Module):
    """2D transpose convolution layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
    prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
    activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    kernel_size: int
        Kernel size.
    stride: int
        Stride of convolutional kernel across the input.
    padding: int
        Amount of zero padding applied to both sides of the input.
    output_padding: int
        Extra padding added to input to achieve the desired output size.
    bias: bool
        Use bias term.

    """
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv_transpose2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
        )
        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class ConvLayer2d(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        layers = []

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            layers.append(
                EqualConvTranspose2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=0,
                    stride=2,
                    bias=bias and not activate
                )
            )
            layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=0,
                    stride=2,
                    bias=bias and not activate
                )
            )

        if (not downsample) and (not upsample):
            padding = kernel_size // 2

            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=padding,
                    stride=1,
                    bias=bias and not activate
                )
            )

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)


class ConvResBlock2d(nn.Module):
    """2D convolutional residual block with equalized learning rate.

    Residual block composed of 3x3 convolutions and leaky ReLUs.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    upsample: bool
        Apply upsampling via strided convolution in the first conv.
    downsample: bool
        Apply downsampling via strided convolution in the second conv.

    """
    def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
        super().__init__()

        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        mid_ch = in_channel if downsample else out_channel

        self.conv1 = ConvLayer2d(in_channel, mid_ch, upsample=upsample, kernel_size=3)
        self.conv2 = ConvLayer2d(mid_ch, out_channel, downsample=downsample, kernel_size=3)

        if (in_channel != out_channel) or upsample or downsample:
            self.skip = ConvLayer2d(
                in_channel,
                out_channel,
                upsample=upsample,
                downsample=downsample,
                kernel_size=1,
                activate=False,
                bias=False,
            )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        if hasattr(self, 'skip'):
            skip = self.skip(input)
            out = (out + skip) / math.sqrt(2)
        else:
            out = (out + input) / math.sqrt(2)
        return out