from typing import Optional

import torch
import torch.nn as nn
from timm.models.layers import DropPath
from torch import Tensor


class PointsEncoder(nn.Module):
    """Encode each padded set of points into a single feature vector.

    Two stages, each a per-point MLP followed by a max over the point axis:
    stage one lifts every valid point to 256 channels; stage two sees each
    point's feature concatenated with the stage-one max-pooled set feature
    (256 + 256 = 512 channels) and projects it to ``encoder_channel``.

    Args:
        feat_channel: number of input channels per point.
        encoder_channel: number of output channels per set.
    """

    def __init__(self, feat_channel, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        self.first_mlp = nn.Sequential(
            nn.Linear(feat_channel, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
        )
        self.second_mlp = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.encoder_channel),
        )

    def forward(self, x, mask=None):
        """Encode point sets.

        Args:
            x: tensor of shape (B, N, feat_channel).
            mask: boolean tensor of shape (B, N); True marks a valid point.
                ``None`` treats every point as valid.

        Returns:
            Tensor of shape (B, encoder_channel).

        Notes:
            Padded (masked-out) slots are filled with zeros and still take
            part in both max-pools, so a set containing padding pools to
            values >= 0, and a set with no valid point yields all zeros.
        """
        bs, n, _ = x.shape
        if mask is None:
            mask = torch.ones(bs, n, dtype=torch.bool, device=x.device)

        # Only valid points go through the MLP (BatchNorm stats ignore padding).
        x_valid = self.first_mlp(x[mask].to(torch.float32))  # (num_valid, 256)
        # Allocate with the MLP output's dtype/device so the masked scatter also
        # works under autocast, where the output may be float16/bfloat16.
        x_features = x_valid.new_zeros(bs, n, x_valid.shape[-1])
        x_features[mask] = x_valid

        pooled_feature = x_features.max(dim=1)[0]  # (B, 256)
        # expand() avoids materialising an extra copy before cat().
        x_features = torch.cat(
            [x_features, pooled_feature.unsqueeze(1).expand(-1, n, -1)], dim=-1
        )  # (B, N, 512)

        x_features_valid = self.second_mlp(x_features[mask])
        res = x_features_valid.new_zeros(bs, n, self.encoder_channel)
        res[mask] = x_features_valid

        return res.max(dim=1)[0]


class MapEncoder(nn.Module):
    """Encode map polygons into one feature vector per polygon.

    Per-point features are built from the first entry along dim 2 of the
    point tensors, encoded with ``PointsEncoder``, and a learned polygon-type
    embedding (3 types) is added.

    Args:
        polygon_channel: per-point input channels for the points encoder.
            ``forward`` builds position offset + vector + (cos, sin) features,
            i.e. 6 channels assuming 2-D positions/vectors.
        dim: output feature size per polygon.
    """

    def __init__(
            self,
            polygon_channel=6,
            dim=128,
    ) -> None:
        super().__init__()

        self.dim = dim
        self.polygon_encoder = PointsEncoder(polygon_channel, dim)
        # NOTE(review): not used in forward(); presumably kept so existing
        # checkpoints still load — confirm before removing.
        self.speed_limit_emb = nn.Sequential(
            nn.Linear(1, dim), nn.ReLU(), nn.Linear(dim, dim)
        )
        self.type_emb = nn.Embedding(3, dim)

    def forward(self, data) -> torch.Tensor:
        """Encode polygons.

        Args:
            data: mapping with keys "polygon_center", "polygon_type",
                "point_position", "point_vector", "point_orientation" and
                "valid_mask". The point tensors are indexed as
                (B, M, K, P, ...) with only index 0 of dim 2 used; presumably
                the polyline centerline — TODO confirm. "valid_mask" must have
                B * M * P elements.

        Returns:
            Tensor of shape (B, M, dim).
        """
        polygon_center = data["polygon_center"]
        polygon_type = data["polygon_type"].long()
        point_position = data["point_position"]
        point_vector = data["point_vector"]
        point_orientation = data["point_orientation"]
        valid_mask = data["valid_mask"]

        orientation = point_orientation[:, :, 0]
        polygon_feature = torch.cat(
            [
                # Point positions relative to the polygon center (first 2 coords).
                point_position[:, :, 0] - polygon_center[..., None, :2],
                point_vector[:, :, 0],
                torch.stack([orientation.cos(), orientation.sin()], dim=-1),
            ],
            dim=-1,
        )

        bs, M, P, C = polygon_feature.shape
        # reshape (not view) so non-contiguous masks are accepted too.
        valid_mask = valid_mask.reshape(bs * M, P)
        polygon_feature = polygon_feature.reshape(bs * M, P, C)

        x_polygon = self.polygon_encoder(polygon_feature, valid_mask).view(bs, M, -1)

        return x_polygon + self.type_emb(polygon_type)


class AgentEncoder(nn.Module):
    """Embed per-agent state vectors and add a learned category embedding.

    Valid agents pass through two small MLP stages (agent_channel -> 256 ->
    dim). Slots of invalid agents are zero before the category embedding
    is added, so their output equals their category embedding.

    Args:
        agent_channel: number of state channels per agent.
        dim: output feature size per agent.
    """

    def __init__(self, agent_channel=9, dim=128) -> None:
        super().__init__()
        self.dim = dim
        self.first_mlp = nn.Sequential(
            nn.Linear(agent_channel, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
        )
        self.second_mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, dim),
        )
        # 4 agent categories; indices come from data["categories"].
        self.type_emb = nn.Embedding(4, dim)

    def forward(self, data):
        """Encode agents.

        Args:
            data: mapping with "states" of shape (B, A, agent_channel),
                "valid_mask" indexing (B, A) (presumably boolean) and
                "categories" of shape (B, A) with integer category ids.

        Returns:
            Tensor of shape (B, A, dim).
        """
        states = data["states"]
        keep = data["valid_mask"]
        batch_size, num_agents, _ = states.shape

        hidden = self.first_mlp(states[keep].to(torch.float32))
        encoded = self.second_mlp(hidden)  # (num_valid, dim)

        out = encoded.new_zeros(batch_size, num_agents, self.dim)
        out[keep] = encoded
        return out + self.type_emb(data["categories"].long())

class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> act -> Dropout -> Linear -> Dropout.

    This is the MLP sub-layer shape used in Vision Transformer, MLP-Mixer and
    related networks.

    Args:
        in_features: input size.
        hidden_features: hidden size; any falsy value falls back to in_features.
        out_features: output size; any falsy value falls back to in_features.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability applied after the activation and the output.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        for layer in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            x = layer(x)
        return x


class CustomTransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer built on ``torch.nn.MultiheadAttention``.

    Computes ``x = x + DropPath(SelfAttn(norm1(x)))`` followed by
    ``x = x + DropPath(Mlp(norm2(x)))``. Inputs are batch-first.

    Args:
        dim: embedding size.
        num_heads: number of attention heads.
        mlp_ratio: hidden size of the MLP relative to ``dim``.
        qkv_bias: forwarded to ``add_bias_kv`` of MultiheadAttention.
        drop: dropout inside the MLP.
        attn_drop: dropout on attention weights.
        drop_path: stochastic-depth rate; 0 disables it (Identity).
        act_layer: activation class for the MLP.
        norm_layer: normalisation class applied before each sub-layer.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.0,
            qkv_bias=False,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        def make_drop_path():
            # Only build a DropPath module when a positive rate is requested.
            return DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm1 = norm_layer(dim)
        # NOTE(review): ``add_bias_kv`` appends learned bias vectors to the
        # key/value sequences; it does not toggle the projection biases
        # (``bias=`` keeps its default). Left as-is so parameter names and
        # checkpoints stay compatible — confirm this is the intended meaning.
        self.attn = torch.nn.MultiheadAttention(
            dim,
            num_heads=num_heads,
            add_bias_kv=qkv_bias,
            dropout=attn_drop,
            batch_first=True,
        )
        self.drop_path1 = make_drop_path()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.drop_path2 = make_drop_path()

    def forward(
            self,
            src,
            mask: Optional[Tensor] = None,
            key_padding_mask: Optional[Tensor] = None,
    ):
        """Apply the layer.

        Args:
            src: batch-first input sequence.
            mask: optional attention mask, passed as ``attn_mask``.
            key_padding_mask: optional mask of padded keys.

        Returns:
            Tensor with the same shape as ``src``.
        """
        normed = self.norm1(src)
        attn_out, _ = self.attn(
            normed,
            normed,
            normed,
            attn_mask=mask,
            key_padding_mask=key_padding_mask,
        )
        src = src + self.drop_path1(attn_out)

        ffn_out = self.mlp(self.norm2(src))
        return src + self.drop_path2(ffn_out)