# Copyright (c) 2022 Dominic Rampas MIT License
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer
from ...models.modeling_utils import ModelMixin
from ...models.vq_model import VQEncoderOutput
from ...utils.accelerate_utils import apply_forward_hook


class MixingResidualBlock(nn.Module):
    """
    Residual block with mixing used by Paella's VQ-VAE.
    """

    def __init__(self, inp_channels, embed_dim):
        super().__init__()
        # depthwise
        self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels)
        )

        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

    def forward(self, x):
        mods = self.gammas
        x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1]
        x = x + self.depthwise(x_temp) * mods[2]
        x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
        return x


class PaellaVQModel(ModelMixin, ConfigMixin):
    r"""VQ-VAE model from Paella model.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the model (such as downloading or saving, etc.)

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
        up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image.
        levels  (int, *optional*, defaults to 2): Number of levels in the model.
        bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model.
        embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model.
        latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model.
        num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE.
        scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_down_scale_factor: int = 2,
        levels: int = 2,
        bottleneck_blocks: int = 12,
        embed_dim: int = 384,
        latent_channels: int = 4,
        num_vq_embeddings: int = 8192,
        scale_factor: float = 0.3764,
    ):
        super().__init__()

        c_levels = [embed_dim // (2**i) for i in reversed(range(levels))]
        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(up_down_scale_factor),
            nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1),
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = MixingResidualBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(
            nn.Sequential(
                nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(latent_channels),  # then normalize them to have mean 0 and std 1
            )
        )
        self.down_blocks = nn.Sequential(*down_blocks)

        # Vector Quantizer
        self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25)

        # Decoder blocks
        up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))]
        for i in range(levels):
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(
                        c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1
                    )
                )
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1),
            nn.PixelShuffle(up_down_scale_factor),
        )

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
        h = self.in_block(x)
        h = self.down_blocks(h)

        if not return_dict:
            return (h,)

        return VQEncoderOutput(latents=h)

    @apply_forward_hook
    def decode(
        self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        if not force_not_quantize:
            quant, _, _ = self.vquantizer(h)
        else:
            quant = h

        x = self.up_blocks(quant)
        dec = self.out_block(x)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        h = self.encode(x).latents
        dec = self.decode(h).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)