"""
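Implements a TransFuser backbone variant that fuses a mixture of image experts (VoVNet, two EVA ViTs and a DAViT) with a TIMM LiDAR BEV encoder.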
"""

import math
import torch
from torch import nn
import torch.nn.functional as F
import timm
import copy
from torch.utils.checkpoint import checkpoint
from torchvision.transforms import Resize

from navsim.agents.backbones.eva import EVAViT
from navsim.agents.backbones.vov import VoVNet
from navsim.agents.transfuser.transfuser_backbone import GPT
from timm.models.vision_transformer import VisionTransformer

from navsim.agents.utils.vit import DAViT


class TransfuserBackboneMoeUlt32(nn.Module):
    """
    Multi-scale Fusion Transformer for image + LiDAR feature fusion.

    The image branch is a mixture of four experts (VoVNet, two EVA ViTs and a DAViT). Their
    outputs are pooled to the image anchor grid, concatenated along channels and projected
    back to ``vit_channels`` by ``moe_proj``. The LiDAR branch is a TIMM ``features_only``
    encoder. The two branches exchange information through four GPT fusion blocks.
    """

    # Image experts that consume the resized image, in execution order.
    _RESIZED_EXPERTS = ('sptrvit', 'mapvit', 'davit')

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Input size (H, W) for the ViT experts; EVAViT is given the short side as img_size.
        img_vit_size = (256, 256 * 4)
        # Channel width of the image side at every fusion stage (ViT-L embedding dim).
        vit_channels = 1024

        self.resize = Resize(img_vit_size)
        # Construction order (davit, sptrvit, mapvit, vov) is kept fixed so that random
        # initialisation and parameter ordering stay compatible with existing checkpoints.
        self.image_encoder = nn.ModuleDict({
            'davit': DAViT(ckpt=config.vit_ckpt),
            'sptrvit': self._build_eva_vit(img_vit_size[0]),
            'mapvit': self._build_eva_vit(img_vit_size[0]),
            'vov': VoVNet(
                spec_name='V-99-eSE',
                out_features=['stage4', 'stage5'],
                norm_eval=True,
                with_cp=True,
                init_cfg=dict(
                    type='Pretrained',
                    checkpoint=config.vov_ckpt,
                    prefix='img_backbone.'
                )
            )
        })
        self.image_encoder['sptrvit'].init_weights(config.sptr_ckpt)
        self.image_encoder['mapvit'].init_weights(config.map_ckpt)
        self.image_encoder['vov'].init_weights()
        # 1x1 conv mixing the concatenated expert features. It assumes every expert yields
        # vit_channels (1024) channels — TODO confirm for VoVNet stage5 and DAViT.
        self.moe_proj = nn.Sequential(
            nn.Conv2d(vit_channels * len(self.image_encoder), vit_channels, 1)
        )

        in_channels = 2 * config.lidar_seq_len if config.use_ground_plane else config.lidar_seq_len

        self.avgpool_img = nn.AdaptiveAvgPool2d(
            (self.config.img_vert_anchors, self.config.img_horz_anchors)
        )

        self.lidar_encoder = timm.create_model(
            config.lidar_architecture,
            pretrained=False,
            in_chans=in_channels,
            features_only=True,
        )
        self.global_pool_lidar = nn.AdaptiveAvgPool2d(output_size=1)
        self.avgpool_lidar = nn.AdaptiveAvgPool2d(
            (self.config.lidar_vert_anchors, self.config.lidar_horz_anchors)
        )
        lidar_time_frames = [1, 1, 1, 1]

        self.global_pool_img = nn.AdaptiveAvgPool2d(output_size=1)
        # Some networks expose their stem as an extra return layer; skip its feature info.
        start_index = 1 if len(self.lidar_encoder.return_layers) > 4 else 0
        # LiDAR encoder channel count at each of the four fusion stages.
        lidar_channels = [
            self.lidar_encoder.feature_info.info[start_index + i]["num_chs"] for i in range(4)
        ]

        self.transformers = nn.ModuleList(
            [
                GPT(
                    n_embd=vit_channels,
                    config=config,
                    lidar_time_frames=lidar_time_frames[i],
                )
                for i in range(4)
            ]
        )
        self.lidar_channel_to_img = nn.ModuleList(
            [nn.Conv2d(num_chs, vit_channels, kernel_size=1) for num_chs in lidar_channels]
        )
        self.img_channel_to_lidar = nn.ModuleList(
            [nn.Conv2d(vit_channels, num_chs, kernel_size=1) for num_chs in lidar_channels]
        )

        # NOTE(review): this is the width of the last LiDAR stage. It is not the width of the
        # pooled fused_features returned on the non-decoder-join path; confirm what callers expect.
        self.num_features = lidar_channels[3]

        # FPN fusion
        channel = self.config.bev_features_channels
        self.relu = nn.ReLU(inplace=True)
        if self.config.detect_boxes or self.config.use_bev_semantic:
            # top down
            self.upsample = nn.Upsample(
                scale_factor=self.config.bev_upsample_factor, mode="bilinear", align_corners=False
            )
            self.upsample2 = nn.Upsample(
                size=(
                    self.config.lidar_resolution_height // self.config.bev_down_sample_factor,
                    self.config.lidar_resolution_width // self.config.bev_down_sample_factor,
                ),
                mode="bilinear",
                align_corners=False,
            )

            self.up_conv5 = nn.Conv2d(channel, channel, (3, 3), padding=1)
            self.up_conv4 = nn.Conv2d(channel, channel, (3, 3), padding=1)

            # lateral
            self.c5_conv = nn.Conv2d(lidar_channels[3], channel, (1, 1))

        # forward() uses this projection only on the pooled path with add_features. Previously
        # it was never created, so that path crashed. It is created only under that condition,
        # and registered last, so checkpoints from the decoder-join path keep the same state_dict.
        if not self.config.transformer_decoder_join and self.config.add_features:
            self.lidar_to_img_features_end = nn.Linear(lidar_channels[3], vit_channels)

    @staticmethod
    def _eva_window_block_indexes(depth=24):
        """
        Return the EVAViT block indices that use windowed attention.

        Every third block (2, 5, 8, ...) is left out of the list, matching the original
        hand-written ranges for depth 24.
        """
        return [i for i in range(depth) if i % 3 != 2]

    @classmethod
    def _build_eva_vit(cls, img_size):
        """Build one EVA ViT-L expert; 'sptrvit' and 'mapvit' share this exact configuration."""
        return EVAViT(
            img_size=img_size,  # img_size for short side
            patch_size=16,
            window_size=16,
            # For a square image set global_window_size=0, otherwise img_size // 16.
            global_window_size=img_size // 16,
            in_chans=3,
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4 * 2 / 3,
            window_block_indexes=cls._eva_window_block_indexes(24),
            qkv_bias=True,
            drop_path_rate=0.3,
            with_cp=True,
            flash_attn=False,
            xformers_attn=True
        )

    def top_down(self, x):
        """
        Decode the final LiDAR feature map into a BEV feature map.

        :param x: Output of the last LiDAR fusion stage.
        :return: Feature map with ``config.bev_features_channels`` channels, upsampled by
            ``upsample`` and then resized by ``upsample2`` to the configured BEV resolution.
        """
        p5 = self.relu(self.c5_conv(x))
        p4 = self.relu(self.up_conv5(self.upsample(p5)))
        p3 = self.relu(self.up_conv4(self.upsample2(p4)))

        return p3

    def _encode_image(self, image):
        """
        Run all image experts, pool each one to the anchor grid and mix them with ``moe_proj``.

        VoVNet sees the full-resolution image; the ViT experts see a copy resized to
        ``img_vit_size`` (the original comment assumed a 512x2048 input; TODO confirm).
        """
        # Last returned VoVNet stage.
        vov_features = self.image_encoder['vov'](image)[-1]
        image_resized = self.resize(image)
        vit_features = [
            self.image_encoder[name](image_resized)[0] for name in self._RESIZED_EXPERTS
        ]
        pooled = [self.avgpool_img(feat) for feat in (vov_features, *vit_features)]
        return self.moe_proj(torch.cat(pooled, dim=1))

    def forward(self, image, lidar):
        """
        Image + LiDAR feature fusion using transformers.

        :param image: Camera image batch consumed by every image expert. It is presumably
            (B, 3, H, W); TODO confirm the expected resolution against the caller.
        :param lidar: LiDAR BEV raster consumed by the TIMM LiDAR encoder.
        :return: Tuple ``(features, fused_features, image_features)``:
            ``features`` is the top-down BEV map, or ``None`` unless ``config.detect_boxes``
            or ``config.use_bev_semantic`` is set. ``fused_features`` is the final LiDAR
            feature map when ``config.transformer_decoder_join`` is set; otherwise it is the
            sum (``add_features``) or concatenation of the pooled image and LiDAR vectors.
            ``image_features`` is the fused image map (decoder join) or its pooled vector.
        """
        lidar_features = lidar

        # Iterator over the LiDAR encoder layers; each forward_layer_block call advances it.
        lidar_layers = iter(self.lidar_encoder.items())

        # Stem layer. In some architectures the stem is an extra return layer that must be
        # consumed first so the four fusion stages line up with the four GPT blocks.
        if len(self.lidar_encoder.return_layers) > 4:
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )

        image_features = self._encode_image(image)

        for i in range(4):
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )
            image_features, lidar_features = self.fuse_features(image_features, lidar_features, i)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            features = self.top_down(lidar_features)
        else:
            features = None

        if self.config.transformer_decoder_join:
            fused_features = lidar_features
        else:
            image_features = torch.flatten(self.global_pool_img(image_features), 1)
            lidar_features = torch.flatten(self.global_pool_lidar(lidar_features), 1)

            if self.config.add_features:
                fused_features = image_features + self.lidar_to_img_features_end(lidar_features)
            else:
                fused_features = torch.cat((image_features, lidar_features), dim=1)

        return features, fused_features, image_features

    def forward_layer_block(self, layers, return_layers, features, if_ckpt=False):
        """
        Run one forward pass to a block of layers from a TIMM neural network and returns the result.
        Advances the whole network by just one block.

        :param layers: Iterator of ``(name, module)`` pairs starting at the current layer block.
        :param return_layers: TIMM mapping (or any container) of layer names at which
            intermediate features are returned; iteration stops after the first such layer.
        :param features: Input features.
        :param if_ckpt: When True, run each module through ``torch.utils.checkpoint``.
        :return: Processed features.
        """
        for name, module in layers:
            features = checkpoint(module, features) if if_ckpt else module(features)
            if name in return_layers:
                break
        return features

    def fuse_features(self, image_features, lidar_features, layer_idx):
        """
        Perform a TransFuser feature fusion block using a Transformer module.

        :param image_features: Features from the image branch (already on the image anchor grid).
        :param lidar_features: Features from the LiDAR branch.
        :param layer_idx: Transformer layer index.
        :return: image_features and lidar_features with added features from the other branch.
        """
        image_embd_layer = image_features
        lidar_embd_layer = self.avgpool_lidar(lidar_features)

        lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer)

        image_features_layer, lidar_features_layer = self.transformers[layer_idx](
            image_embd_layer, lidar_embd_layer
        )
        lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer)

        # Resize the transformer outputs back to each branch's spatial size before the residual add.
        image_features_layer = F.interpolate(
            image_features_layer,
            size=(image_features.shape[2], image_features.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        lidar_features_layer = F.interpolate(
            lidar_features_layer,
            size=(lidar_features.shape[2], lidar_features.shape[3]),
            mode="bilinear",
            align_corners=False,
        )

        image_features = image_features + image_features_layer
        lidar_features = lidar_features + lidar_features_layer

        return image_features, lidar_features