from typing import Dict

import numpy as np
import torch
import torch.nn as nn

from navsim.agents.hydra.hydra_backbone_pe import HydraBackbonePE
from navsim.agents.hydra.hydra_config import HydraConfig
from navsim.agents.transfuser.transfuser_model import AgentHead
from navsim.agents.utils.attn import MemoryEffTransformer
from navsim.agents.utils.nerf import nerf_positional_encoding
from navsim.agents.vadv2.vadv2_config import Vadv2Config
from mmcv.cnn.bricks.transformer import build_positional_encoding
# importing SinePositionalEncoding3D registers it with mmcv's positional-encoding
# registry so build_positional_encoding can resolve it by name
from navsim.agents.utils.positional_encoding import SinePositionalEncoding3D  # noqa: F401


class HydraDetModelPE(nn.Module):
    def __init__(self, config: HydraConfig):
        super().__init__()

        self._query_splits = [
            config.num_bounding_boxes,
        ]

        self._config = config
        assert config.backbone_type in ['vit', 'intern', 'vov', 'resnet', 'eva', 'moe', 'moe_ult32', 'swin']
        if config.backbone_type in ('vit', 'eva'):
            raise ValueError(f'{config.backbone_type} not supported')
        elif config.backbone_type in ('intern', 'vov', 'swin', 'resnet'):
            self._backbone = HydraBackbonePE(config)
        else:
            # 'moe' and 'moe_ult32' pass the assert above but have no PE backbone here;
            # fail loudly instead of leaving self._backbone undefined
            raise ValueError(f'{config.backbone_type} not supported')

        self._keyval_embedding = nn.Embedding(
            config.img_vert_anchors * config.img_horz_anchors, config.tf_d_model
        )  # one learned embedding per image-feature token (img_vert_anchors * img_horz_anchors)
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)

        # project backbone image features to the transformer model dimension
        self.downscale_layer = nn.Conv2d(self._backbone.img_feat_c, config.tf_d_model, kernel_size=1)
        self._status_encoding = nn.Linear((4 + 2 + 2) * config.num_ego_status, config.tf_d_model)

        self.depth_num = 64
        self.depth_start = 1
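        # depth_start is the nearest depth bin (meters); position_range is the
        # [x_min, y_min, z_min, x_max, y_max, z_max] volume (meters) used to
        # normalise the lifted 3D points before inverse_sigmoid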
        self.position_range = [-32.0, -32.0, -10.0, 32.0, 32.0, 10.0]
        self.position_dim = 3 * self.depth_num
        self.embed_dims = 256
        self.sin_positional_encoding = dict(
            type='SinePositionalEncoding3D', num_feats=128, normalize=True)
        self.positional_encoding = build_positional_encoding(
            self.sin_positional_encoding)
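        # SinePositionalEncoding3D(num_feats=128) yields 3 * 128 = 384 channels
        # (= embed_dims * 3 // 2), which adapt_pos3d below projects back to embed_dims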
        self.adapt_pos3d = nn.Sequential(
            nn.Conv2d(self.embed_dims*3//2, self.embed_dims*4, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(self.embed_dims*4, self.embed_dims, kernel_size=1, stride=1, padding=0),
        )
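        # position_encoder maps the stacked frustum coordinates (3 * depth_num = 192
        # channels per token) to a per-token 3D position embedding of size embed_dims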
        self.position_encoder = nn.Sequential(
            nn.Conv2d(self.position_dim, self.embed_dims * 4, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(self.embed_dims * 4, self.embed_dims, kernel_size=1, stride=1, padding=0),
        )
        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )

        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )

        self._trajectory_head = HydraTrajHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            nhead=config.vadv2_head_nhead,
            nlayers=config.vadv2_head_nlayers,
            vocab_path=config.vocab_path,
            config=config
        )

    def inverse_sigmoid(self, x, eps=1e-6):
        """Inverse sigmoid function.

        Args:
            x (Tensor): The input tensor.
            eps (float): A small value to avoid numerical issues.

        Returns:
            Tensor: The logit value of the input.
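
        Example (illustrative):
            self.inverse_sigmoid(torch.tensor([0.5, 0.9])) -> tensor([0.0000, 2.1972])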
        """
        x = x.clamp(min=eps, max=1 - eps)  # Ensure the input is within the valid range
        return torch.log(x / (1 - x))

    def position_embedding(self, features, img_features):
        eps = 1e-5
        img_features = img_features.unsqueeze(1)
        B, N, C, tar_H, tar_W = img_features.shape
        device = img_features.device
        crop_top = 28
        crop_left = 416
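        # per-view token grid: all views share the same height; the horizontal anchors
        # are split in proportion to the stitched image widths (1088 / 1920 / 1088)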
        H = [self._config.img_vert_anchors for _ in range(3)]
        W = [
            self._config.img_horz_anchors * 1088 // (1088 * 2 + 1920),
            self._config.img_horz_anchors * 1920 // (1088 * 2 + 1920),
            self._config.img_horz_anchors * 1088 // (1088 * 2 + 1920)
        ]

        # left view grid (16, 17)
        coords_h_l = torch.arange(H[0], device=device).float() * 1080 / H[0] + crop_top / H[0]
        coords_w_l = torch.arange(W[0], device=device).float() * 1920 / W[0] + crop_left / W[0]
        # front view grid (16, 30)
        coords_h_f = torch.arange(H[1], device=device).float() * 1080 / H[1] + crop_top / H[1]
        coords_w_f = torch.arange(W[1], device=device).float() * 1920 / W[1]
        # right view grid (16, 17)
        coords_h_r = torch.arange(H[2], device=device).float() * 1080 / H[2] + crop_top / H[2]
        coords_w_r = torch.arange(W[2], device=device).float() * 1920 / W[2] + crop_left / W[2]

        index = torch.arange(start=0, end=self.depth_num, step=1, device=img_features.device).float()
        index_1 = index + 1
        bin_size = (self.position_range[3] - self.depth_start) / (self.depth_num * (1 + self.depth_num))
        coords_d = self.depth_start + bin_size * index * index_1
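        # coords_d follows PETR's LID depth discretisation: bin width grows linearly
        # with index, covering depth_start (1 m) up to just under position_range[3] (32 m)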

        D = coords_d.shape[0]
        coords = [None] * 3  # 0, 1, 2 -> left, front, right
        coords[0] = torch.stack(torch.meshgrid([coords_w_l, coords_h_l, coords_d])).permute(1, 2, 3, 0)  # W, H, D, 3
        coords[1] = torch.stack(torch.meshgrid([coords_w_f, coords_h_f, coords_d])).permute(1, 2, 3, 0)  # W, H, D, 3
        coords[2] = torch.stack(torch.meshgrid([coords_w_r, coords_h_r, coords_d])).permute(1, 2, 3, 0)  # W, H, D, 3
        coords[0][..., :2] = coords[0][..., :2] * torch.max(coords[0][..., 2:3],
                                                            torch.ones_like(coords[0][..., 2:3]) * eps)
        coords[1][..., :2] = coords[1][..., :2] * torch.max(coords[1][..., 2:3],
                                                            torch.ones_like(coords[1][..., 2:3]) * eps)
        coords[2][..., :2] = coords[2][..., :2] * torch.max(coords[2][..., 2:3],
                                                            torch.ones_like(coords[2][..., 2:3]) * eps)

        pos_3d_embed = None
        for i in range(3):
            sensor2lidar_rotation = features["sensor2lidar_rotation"][i]
            sensor2lidar_translation = features["sensor2lidar_translation"][i]
            intrinsics = features["intrinsics"][i]
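            # back-project each view's pixel frustum into the ego/lidar frame
            # (PETR-style): X_lidar = R @ K^-1 @ [u*d, v*d, d]^T + t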
            combine = torch.matmul(sensor2lidar_rotation, torch.inverse(intrinsics)).float()  # (B, 3, 3)

            coords3d = coords[i].view(1, N, W[i], H[i], D, 3, 1).repeat(B, 1, 1, 1, 1, 1, 1)  # (B, N, W, H, D, 3, 1)
            combine = combine.view(B, N, 1, 1, 1, 3, 3).repeat(1, 1, W[i], H[i], D, 1, 1)
            coords3d = torch.matmul(combine, coords3d).squeeze(-1)  # (B, N, W, H, D, 3)
            sensor2lidar_translation = sensor2lidar_translation.view(B, N, 1, 1, 1, 3)
            coords3d += sensor2lidar_translation

            coords3d[..., 0:1] = (coords3d[..., 0:1] - self.position_range[0]) / (
                    self.position_range[3] - self.position_range[0])
            coords3d[..., 1:2] = (coords3d[..., 1:2] - self.position_range[1]) / (
                    self.position_range[4] - self.position_range[1])
            coords3d[..., 2:3] = (coords3d[..., 2:3] - self.position_range[2]) / (
                    self.position_range[5] - self.position_range[2])
            # per-view coords3d: (B, N, W_i, H, D, 3); the three views are concatenated
            # along the width axis (dim=2) below, e.g. 17 + 30 + 17 anchors
            if pos_3d_embed is None:
                pos_3d_embed = coords3d
            else:
                pos_3d_embed = torch.cat((pos_3d_embed, coords3d), dim=2)
        # (B, N, W_total, H, D, 3) -> (B * N, D * 3, H, W_total)
        pos_3d_embed = pos_3d_embed.permute(0, 1, 4, 5, 3, 2).contiguous().view(B * N, -1, tar_H, tar_W)
        coords3d = self.inverse_sigmoid(pos_3d_embed)
        coords_position_embedding = self.position_encoder(coords3d)
        return coords_position_embedding.view(B, N, self.embed_dims, tar_H, tar_W)

    def forward(self, features: Dict[str, torch.Tensor],
                interpolated_traj=None) -> Dict[str, torch.Tensor]:
        # TODO: ego status
        camera_feature: torch.Tensor = features["camera_feature"][0]
        # lidar_feature: torch.Tensor = features["lidar_feature"]
        status_feature: torch.Tensor = features["status_feature"]

        batch_size = status_feature.shape[0]
        assert (camera_feature.shape[0] == batch_size)
        img_features = self._backbone(camera_feature)
        img_features = self.downscale_layer(img_features)
        input_img_h, input_img_w = img_features.size(-2), img_features.size(-1)
        masks = img_features.new_ones(
            (img_features.shape[0], 1, input_img_h, input_img_w))

        # PETR-style positional encoding: the camera-frustum 3D coordinate embedding
        # plus a 2D sine embedding over the feature grid
        coords_position_embedding = self.position_embedding(features, img_features)
        sin_embed = self.positional_encoding(masks)
        sin_embed = self.adapt_pos3d(sin_embed.flatten(0, 1)).view(img_features.size())
        pos_embed = coords_position_embedding.squeeze(1) + sin_embed

        pos_embed = pos_embed.flatten(-2, -1)
        pos_embed = pos_embed.permute(0, 2, 1)
        # img_features: (B, tf_d_model, H, W) -> (B, H * W, tf_d_model) token sequence
        img_features = img_features.flatten(-2, -1)
        img_features = img_features.permute(0, 2, 1)

        if self._config.num_ego_status == 1 and status_feature.shape[1] == 32:
            status_encoding = self._status_encoding(status_feature[:, :8])
        else:
            status_encoding = self._status_encoding(status_feature)

        keyval = img_features

        keyval += self._keyval_embedding.weight[None, ...]

        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        agents_query = self._tf_decoder(query, keyval + pos_embed)

        output: Dict[str, torch.Tensor] = {}
        trajectory = self._trajectory_head(keyval, status_encoding, interpolated_traj)
        output.update(trajectory)
        agents = self._agent_head(agents_query)
        output.update(agents)

        return output


class HydraTrajHead(nn.Module):
    def __init__(self, num_poses: int, d_ffn: int, d_model: int, vocab_path: str,
                 nhead: int, nlayers: int, config: Vadv2Config = None
                 ):
        super().__init__()
        self._num_poses = num_poses
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, nhead, d_ffn,
                dropout=0.0, batch_first=True
            ), nlayers
        )
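        # frozen trajectory vocabulary loaded from vocab_path,
        # shape (V, num_poses, 3) with (x, y, heading) per pose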
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False
        )

        self.heads = nn.ModuleDict({
            'noc': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            ),
            'da':
                nn.Sequential(
                    nn.Linear(d_model, d_ffn),
                    nn.ReLU(),
                    nn.Linear(d_ffn, 1),
                ),
            'ttc': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            ),
            'comfort': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            ),
            'progress': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            ),
            'imi': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            )
        })

        self.inference_imi_weight = config.inference_imi_weight
        self.inference_da_weight = config.inference_da_weight
        self.normalize_vocab_pos = config.normalize_vocab_pos
        if self.normalize_vocab_pos:
            self.encoder = MemoryEffTransformer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=0.0
            )
        self.use_nerf = config.use_nerf

        if self.use_nerf:
            self.pos_embed = nn.Sequential(
                nn.Linear(1040, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, d_model),
            )
        else:
            self.pos_embed = nn.Sequential(
                nn.Linear(num_poses * 3, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, d_model),
            )

    def forward(self, bev_feature, status_encoding, interpolated_traj) -> Dict[str, torch.Tensor]:
        # TODO: sinusoidal embedding
        # vocab:          (V, num_poses, 3), e.g. (4096, 40, 3)
        # bev_feature:    (B, num_tokens, C)
        # embedded_vocab: (B, V, C)
        vocab = self.vocab.data
        L, HORIZON, _ = vocab.shape
        B = bev_feature.shape[0]
        if self.use_nerf:
            vocab = torch.cat(
                [
                    nerf_positional_encoding(vocab[..., :2]),
                    torch.cos(vocab[..., -1])[..., None],
                    torch.sin(vocab[..., -1])[..., None],
                ], dim=-1
            )
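            # after flattening, each encoded trajectory is expected to have 1040
            # features, matching the nn.Linear(1040, d_ffn) defined in __init__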

        if self.normalize_vocab_pos:
            embedded_vocab = self.pos_embed(vocab.view(L, -1))[None]
            embedded_vocab = self.encoder(embedded_vocab).repeat(B, 1, 1)
        else:
            embedded_vocab = self.pos_embed(vocab.view(L, -1))[None].repeat(B, 1, 1)
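        # each vocabulary embedding cross-attends to the scene tokens (bev_feature),
        # then the ego-status encoding is broadcast-added to every trajectory query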
        tr_out = self.transformer(embedded_vocab, bev_feature)
        dist_status = tr_out + status_encoding.unsqueeze(1)
        result = {}
        # selected_indices: B,
        for k, head in self.heads.items():
            if k == 'imi':
                result[k] = head(dist_status).squeeze(-1)
            else:
                result[k] = head(dist_status).squeeze(-1).sigmoid()
        # combine the sub-scores into a single trajectory score (weighted log-linear mix)
        scores = (
                0.05 * result['imi'].softmax(-1).log() +
                0.5 * result['noc'].log() +
                0.5 * result['da'].log() +
                8.0 * (5 * result['ttc'] + 2 * result['comfort'] + 5 * result['progress']).log()
        )
        selected_indices = scores.argmax(1)
        result["trajectory"] = self.vocab.data[selected_indices]
        result["trajectory_vocab"] = self.vocab.data
        result["selected_indices"] = selected_indices
        return result