# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Pytorch neural network definitions."""

from typing import Sequence, Union

import torch
from torch import nn
import torch.nn.functional as F

from models.utils import Conv2dSamePadding


class ExtraConvBlock(nn.Module):
  """Additional convolution block."""

  def __init__(
      self,
      channel_dim,
      channel_multiplier,
  ):
    super().__init__()
    self.channel_dim = channel_dim
    self.channel_multiplier = channel_multiplier

    self.layer_norm = nn.LayerNorm(
        normalized_shape=channel_dim, elementwise_affine=True, bias=True
    )
    self.conv = nn.Conv2d(
        self.channel_dim,
        self.channel_dim * self.channel_multiplier,
        kernel_size=3,
        stride=1,
        padding=1,
    )
    self.conv_1 = nn.Conv2d(
        self.channel_dim * self.channel_multiplier,
        self.channel_dim,
        kernel_size=3,
        stride=1,
        padding=1,
    )

  def forward(self, x):
    x = self.layer_norm(x)
    x = x.permute(0, 3, 1, 2)
    res = self.conv(x)
    res = F.gelu(res, approximate='tanh')
    x = x + self.conv_1(res)
    x = x.permute(0, 2, 3, 1)
    return x


class ExtraConvs(nn.Module):
  """A stack of `num_layers` ExtraConvBlock modules applied sequentially."""

  def __init__(
      self,
      num_layers=5,
      channel_dim=256,
      channel_multiplier=4,
  ):
    super().__init__()
    self.num_layers = num_layers
    self.channel_dim = channel_dim
    self.channel_multiplier = channel_multiplier

    # All blocks share the same width; each keeps spatial size unchanged.
    self.blocks = nn.ModuleList(
        ExtraConvBlock(channel_dim, channel_multiplier)
        for _ in range(num_layers)
    )

  def forward(self, x):
    out = x
    for blk in self.blocks:
      out = blk(out)
    return out


class ConvChannelsMixer(nn.Module):
  """Per-position channel MLP (4x expand, tanh-GELU, project) for PIPs."""

  def __init__(self, in_channels):
    super().__init__()
    self.mlp2_up = nn.Linear(in_channels, in_channels * 4)
    self.mlp2_down = nn.Linear(in_channels * 4, in_channels)

  def forward(self, x):
    hidden = F.gelu(self.mlp2_up(x), approximate='tanh')
    return self.mlp2_down(hidden)


class PIPsConvBlock(nn.Module):
  """Convolutional block for PIPs's MLP Mixer.

  Applies two depthwise temporal convolutions with a 4x channel expansion,
  then a channel-mixing MLP, each path wrapped in a residual connection.
  Supports causal (left-padded) convolution and streaming inference via a
  dict of cached per-block temporal context.
  """

  def __init__(
      self, in_channels, kernel_shape=3, use_causal_conv=False, block_idx=None
  ):
    super().__init__()
    self.use_causal_conv = use_causal_conv
    # Namespaces this block's causal-context cache keys so blocks in a
    # stacked mixer don't collide.
    self.block_name = f'block_{block_idx}'
    self.kernel_shape = kernel_shape

    self.layer_norm = nn.LayerNorm(
        normalized_shape=in_channels, elementwise_affine=True, bias=False
    )
    # Depthwise conv over the time axis. Causal mode uses no built-in
    # padding; forward() left-pads by (kernel_shape - 1) instead. Non-causal
    # mode uses symmetric "same" padding (correct for odd kernel sizes).
    # Previously both paddings were hard-coded for kernel_shape == 3.
    self.mlp1_up = nn.Conv1d(
        in_channels,
        in_channels * 4,
        kernel_shape,
        stride=1,
        padding=0 if self.use_causal_conv else (kernel_shape - 1) // 2,
        groups=in_channels,
    )

    self.mlp1_up_1 = nn.Conv1d(
        in_channels * 4,
        in_channels * 4,
        kernel_shape,
        stride=1,
        padding=0 if self.use_causal_conv else (kernel_shape - 1) // 2,
        groups=in_channels * 4,
    )
    self.layer_norm_1 = nn.LayerNorm(
        normalized_shape=in_channels, elementwise_affine=True, bias=False
    )
    self.conv_channels_mixer = ConvChannelsMixer(in_channels)

  def forward(self, x, causal_context=None, get_causal_context=False):
    """Runs the block.

    Args:
      x: Input of shape (batch, num_frames, in_channels).
      causal_context: Optional dict of cached frames from a previous call,
        keyed by f'{self.block_name}_causal_1' / f'{self.block_name}_causal_2';
        when given, cached frames are prepended along the time axis and
        stripped from the output.
      get_causal_context: Accepted for API symmetry with the mixer. NOTE
        (review): the new context dict is populated only when causal_context
        is provided — confirm this matches callers' expectations.

    Returns:
      A tuple (output, new_causal_context); output has the same shape as x.
    """
    to_skip = x
    x = self.layer_norm(x)
    new_causal_context = {}
    num_extra = 0

    if causal_context is not None:
      name1 = self.block_name + '_causal_1'
      # Prepend cached frames; remember how many so they can be stripped.
      x = torch.cat([causal_context[name1], x], dim=-2)
      num_extra = causal_context[name1].shape[-2]
      # Cache the trailing (kernel_shape - 1) frames for the next call.
      new_causal_context[name1] = x[..., -(self.kernel_shape - 1) :, :]

    x = x.permute(0, 2, 1)
    if self.use_causal_conv:
      # Left-pad so each output frame depends only on current/past frames.
      x = F.pad(x, (self.kernel_shape - 1, 0))
    x = self.mlp1_up(x)

    x = F.gelu(x, approximate='tanh')

    if causal_context is not None:
      x = x.permute(0, 2, 1)
      name2 = self.block_name + '_causal_2'
      num_extra = causal_context[name2].shape[-2]
      # Replace the leading frames with the cached post-GELU activations.
      x = torch.cat([causal_context[name2], x[..., num_extra:, :]], dim=-2)
      new_causal_context[name2] = x[..., -(self.kernel_shape - 1) :, :]
      x = x.permute(0, 2, 1)

    if self.use_causal_conv:
      x = F.pad(x, (self.kernel_shape - 1, 0))
    x = self.mlp1_up_1(x)
    x = x.permute(0, 2, 1)

    if causal_context is not None:
      # Drop the prepended context frames from the output.
      x = x[..., num_extra:, :]

    # Collapse the 4x channel expansion by summing each group of four.
    x = x[..., 0::4] + x[..., 1::4] + x[..., 2::4] + x[..., 3::4]

    x = x + to_skip
    to_skip = x
    x = self.layer_norm_1(x)
    x = self.conv_channels_mixer(x)

    x = x + to_skip
    return x, new_causal_context


class PIPSMLPMixer(nn.Module):
  """Depthwise-conv version of PIPs's MLP Mixer."""

  def __init__(
      self,
      input_channels: int,
      output_channels: int,
      hidden_dim: int = 512,
      num_blocks: int = 12,
      kernel_shape: int = 3,
      use_causal_conv: bool = False,
  ):
    """Inits Mixer module.

    A depthwise-convolutional version of an MLP Mixer for processing images.

    Args:
        input_channels (int): The number of input channels.
        output_channels (int): The number of output channels.
        hidden_dim (int, optional): The dimension of the hidden layer. Defaults
          to 512.
        num_blocks (int, optional): The number of convolution blocks in the
          mixer. Defaults to 12.
        kernel_shape (int, optional): The size of the kernel in the convolution
          blocks. Defaults to 3.
        use_causal_conv (bool, optional): Whether to use causal convolutions.
          Defaults to False.
    """

    super().__init__()
    self.hidden_dim = hidden_dim
    self.num_blocks = num_blocks
    self.use_causal_conv = use_causal_conv
    # Project inputs into the mixer width, then back out after the blocks.
    self.linear = nn.Linear(input_channels, self.hidden_dim)
    self.layer_norm = nn.LayerNorm(
        normalized_shape=hidden_dim, elementwise_affine=True, bias=False
    )
    self.linear_1 = nn.Linear(hidden_dim, output_channels)
    # block_idx namespaces each block's causal-context cache keys.
    self.blocks = nn.ModuleList([
        PIPsConvBlock(
            hidden_dim, kernel_shape, self.use_causal_conv, block_idx=idx
        )
        for idx in range(num_blocks)
    ])

  def forward(self, x, causal_context=None, get_causal_context=False):
    x = self.linear(x)
    all_causal_context = {}
    for blk in self.blocks:
      x, block_context = blk(x, causal_context, get_causal_context)
      if get_causal_context:
        all_causal_context.update(block_context)
    out = self.linear_1(self.layer_norm(x))
    return out, all_causal_context


class BlockV2(nn.Module):
  """ResNet V2 (pre-activation) block: norm/ReLU before each convolution."""

  def __init__(
      self,
      channels_in: int,
      channels_out: int,
      stride: Union[int, Sequence[int]],
      use_projection: bool,
  ):
    super().__init__()
    # Mirror JAX/LAX padding="SAME", which is asymmetric when stride == 2.
    if stride == 1:
      self.padding = (1, 1, 1, 1)
    elif stride == 2:
      self.padding = (0, 2, 0, 2)
    else:
      raise ValueError(
          'Check correct padding using padtype_to_padsin jax._src.lax.lax'
      )

    self.use_projection = use_projection
    if self.use_projection:
      # 1x1 projection so the shortcut matches the main path's
      # channels/stride.
      self.proj_conv = Conv2dSamePadding(
          in_channels=channels_in,
          out_channels=channels_out,
          kernel_size=1,
          stride=stride,
          padding=0,
          bias=False,
      )

    self.bn_0 = nn.InstanceNorm2d(
        num_features=channels_in,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=False,
    )
    self.conv_0 = Conv2dSamePadding(
        in_channels=channels_in,
        out_channels=channels_out,
        kernel_size=3,
        stride=stride,
        padding=0,
        bias=False,
    )

    self.conv_1 = Conv2dSamePadding(
        in_channels=channels_out,
        out_channels=channels_out,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    self.bn_1 = nn.InstanceNorm2d(
        num_features=channels_out,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=False,
    )

  def forward(self, inputs):
    # Pre-activation: the projection shortcut also consumes the
    # normalized/activated tensor, per ResNet V2.
    preact = torch.relu(self.bn_0(inputs))
    shortcut = self.proj_conv(preact) if self.use_projection else inputs

    out = self.conv_0(preact)
    # Stride is always 1 here, so padding is symmetric and unproblematic.
    out = self.conv_1(torch.relu(self.bn_1(out)))

    return out + shortcut


class BlockGroup(nn.Module):
  """One ResNet stage: a sequence of BlockV2 residual blocks."""

  def __init__(
      self,
      channels_in: int,
      channels_out: int,
      num_blocks: int,
      stride: Union[int, Sequence[int]],
      use_projection: bool,
  ):
    super().__init__()
    # Only the first block may change resolution/width; the remaining
    # blocks are stride-1 identity-shortcut blocks at channels_out.
    self.blocks = nn.ModuleList([
        BlockV2(
            channels_in=channels_out if idx else channels_in,
            channels_out=channels_out,
            stride=stride if idx == 0 else 1,
            use_projection=use_projection and idx == 0,
        )
        for idx in range(num_blocks)
    ])

  def forward(self, inputs):
    out = inputs
    for blk in self.blocks:
      out = blk(out)
    return out


class ResNet(nn.Module):
  """ResNet model."""

  def __init__(
      self,
      blocks_per_group: Sequence[int],
      channels_per_group: Sequence[int] = (64, 128, 256, 512),
      use_projection: Sequence[bool] = (True, True, True, True),
      strides: Sequence[int] = (1, 2, 2, 2),
  ):
    """Initializes a ResNet model with customizable layers and configurations.

    This constructor allows defining the architecture of a ResNet model by
    setting the number of blocks, channels, projection usage, and strides for
    each group of blocks within the network. It provides flexibility in
    creating various ResNet configurations.

    Args:
      blocks_per_group: A sequence of 4 integers, each indicating the number
        of residual blocks in each group.
      channels_per_group: A sequence of 4 integers, each specifying the number
        of output channels for the blocks in each group. Defaults to (64, 128,
        256, 512).
      use_projection: A sequence of 4 booleans, each indicating whether to use
        a projection shortcut (True) or an identity shortcut (False) in each
        group. Defaults to (True, True, True, True).
      strides: A sequence of 4 integers, each specifying the stride size for
        the convolutions in each group. Defaults to (1, 2, 2, 2).

    The ResNet model created will have 4 groups, with each group's
    architecture defined by the corresponding elements in these sequences.
    """
    super().__init__()

    # Stem: 7x7/2 conv over RGB input into the first group's width.
    self.initial_conv = Conv2dSamePadding(
        in_channels=3,
        out_channels=channels_per_group[0],
        kernel_size=(7, 7),
        stride=2,
        padding=0,
        bias=False,
    )

    block_groups = []
    for i, _ in enumerate(strides):
      block_groups.append(
          BlockGroup(
              # The first group consumes the stem's output width. This was
              # previously hard-coded to 64, which broke any non-default
              # channels_per_group configuration.
              channels_in=(
                  channels_per_group[i - 1] if i > 0 else channels_per_group[0]
              ),
              channels_out=channels_per_group[i],
              num_blocks=blocks_per_group[i],
              stride=strides[i],
              use_projection=use_projection[i],
          )
      )
    self.block_groups = nn.ModuleList(block_groups)

  def forward(self, inputs):
    """Returns a dict of intermediate feature maps keyed by layer name."""
    result = {}
    out = inputs
    out = self.initial_conv(out)
    result['initial_conv'] = out

    for block_id, block_group in enumerate(self.block_groups):
      out = block_group(out)
      result[f'resnet_unit_{block_id}'] = out

    return result