from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

try:  # current timm location of the layer helpers
    from timm.layers import DropPath
except ImportError:
    try:  # legacy location used by older timm releases
        from timm.models.layers import DropPath
    except ImportError:
        # timm is only needed when stochastic depth is requested
        # (drop_path > 0); see _build_drop_path.
        DropPath = None


def _build_drop_path(drop_prob: float) -> nn.Module:
    """Return a DropPath layer for ``drop_prob > 0``, otherwise an identity."""
    if drop_prob > 0.0:
        if DropPath is None:
            raise ImportError(
                "drop_path > 0 requires the `timm` package (DropPath layer)."
            )
        return DropPath(drop_prob)
    return nn.Identity()


class PointsEncoder(nn.Module):
    """Encode a masked set of points into one feature vector per set.

    Two point-wise MLPs are applied to the valid points only. Between them,
    the max-pooled feature of the set is concatenated to every point feature.
    A final max over the point axis produces the set feature.
    """

    def __init__(self, feat_channel, encoder_channel):
        """
        Args:
            feat_channel: size of the last dimension of the input points.
            encoder_channel: size of the returned per-set feature.
        """
        super().__init__()
        self.encoder_channel = encoder_channel
        self.first_mlp = nn.Sequential(
            nn.Linear(feat_channel, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
        )
        # Input is 512 wide: 256 per-point features + 256 pooled features.
        self.second_mlp = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.encoder_channel),
        )

    def forward(self, x, mask=None):
        """
        x : B M C   (C == feat_channel)
        mask: B M   boolean, True marks a valid point; None means all valid
        -----------------
        feature_global : B encoder_channel
        """
        bs, n, _ = x.shape
        device = x.device

        if mask is None:
            # Indexing with None would add an axis instead of selecting every
            # point, so build an explicit all-valid mask.
            mask = torch.ones(bs, n, dtype=torch.bool, device=device)

        # Only valid points go through the MLP (and its BatchNorm statistics).
        point_feat = self.first_mlp(x[mask].to(torch.float32))
        hidden = point_feat.shape[-1]

        # Scatter back to the dense layout; masked-out slots stay zero and
        # therefore take part in the max below as zeros.
        dense_feat = torch.zeros(bs, n, hidden, device=device)
        dense_feat[mask] = point_feat
        pooled = dense_feat.max(dim=1)[0]

        # expand() broadcasts without copying; torch.cat materializes it.
        dense_feat = torch.cat(
            [dense_feat, pooled.unsqueeze(1).expand(-1, n, -1)], dim=-1
        )

        encoded = self.second_mlp(dense_feat[mask])
        res = torch.zeros(bs, n, self.encoder_channel, device=device)
        res[mask] = encoded
        return res.max(dim=1)[0]


class MapEncoder(nn.Module):
    """Encode map polygons with a PointsEncoder plus a learned type embedding."""

    def __init__(
        self,
        polygon_channel=6,
        dim=128,
    ) -> None:
        """
        Args:
            polygon_channel: per-point feature size fed to the points encoder
                (forward builds 2 position + 2 vector + 2 orientation values).
            dim: size of the output feature per polygon.
        """
        super().__init__()
        self.dim = dim
        self.polygon_encoder = PointsEncoder(polygon_channel, dim)
        # NOTE(review): not used by forward() in this file; kept so that
        # existing state_dict keys stay unchanged.
        self.speed_limit_emb = nn.Sequential(
            nn.Linear(1, dim), nn.ReLU(), nn.Linear(dim, dim)
        )
        self.type_emb = nn.Embedding(3, dim)

    def forward(self, data) -> torch.Tensor:
        """Return one feature per polygon, shape (bs, M, dim).

        Args:
            data: mapping with the keys "polygon_center", "polygon_type",
                "point_position", "point_vector", "point_orientation" and
                "valid_mask". Only index 0 along dim 2 of the three point_*
                tensors is used. NOTE(review): the exact tensor layouts are
                defined by the caller - confirm against the data pipeline.
        """
        polygon_center = data["polygon_center"]
        polygon_type = data["polygon_type"].long()
        point_position = data["point_position"]
        point_vector = data["point_vector"]
        point_orientation = data["point_orientation"]
        valid_mask = data["valid_mask"]

        heading = point_orientation[:, :, 0]
        # Positions are expressed relative to the polygon center (first two
        # components); orientation is encoded as (cos, sin).
        polygon_feature = torch.cat(
            [
                point_position[:, :, 0] - polygon_center[..., None, :2],
                point_vector[:, :, 0],
                torch.stack([heading.cos(), heading.sin()], dim=-1),
            ],
            dim=-1,
        )

        bs, M, P, C = polygon_feature.shape
        # reshape (not view) so that non-contiguous masks are accepted too.
        flat_mask = valid_mask.reshape(bs * M, P)
        flat_feature = polygon_feature.reshape(bs * M, P, C)

        x_polygon = self.polygon_encoder(flat_feature, flat_mask).view(bs, M, -1)
        return x_polygon + self.type_emb(polygon_type)


class AgentEncoder(nn.Module):
    """Encode per-agent state vectors with two MLPs plus a type embedding."""

    def __init__(
        self,
        agent_channel=9,
        dim=128,
    ) -> None:
        """
        Args:
            agent_channel: size of the last dimension of ``data["states"]``.
            dim: size of the output feature per agent.
        """
        super().__init__()
        self.dim = dim
        self.first_mlp = nn.Sequential(
            nn.Linear(agent_channel, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
        )
        self.second_mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.dim),
        )
        self.type_emb = nn.Embedding(4, dim)

    def forward(self, data):
        """Return agent features of shape (bs, A, dim).

        Args:
            data: mapping with "categories" (indices into the type
                embedding), "states" (bs, A, agent_channel) and "valid_mask"
                (boolean, bs x A). Invalid agents receive only their type
                embedding.
        """
        category = data["categories"].long()
        states = data["states"]
        valid_mask = data["valid_mask"]

        bs, A, _ = states.shape
        # Only valid agents are passed through the MLPs.
        encoded = self.second_mlp(
            self.first_mlp(states[valid_mask].to(torch.float32))
        )

        # B, A, C - dense output, zeros where the agent is invalid.
        res = torch.zeros(
            bs, A, self.dim, device=encoded.device, dtype=encoded.dtype
        )
        res[valid_mask] = encoded

        return res + self.type_emb(category)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        """
        Args:
            in_features: input width.
            hidden_features: hidden width; falls back to ``in_features``.
            out_features: output width; falls back to ``in_features``.
            act_layer: activation module class, instantiated with no args.
            drop: dropout probability used after both linear layers.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1 -> act -> dropout -> fc2 -> dropout."""
        x = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(x))


class CustomTransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention, then an MLP.

    Both sub-layers are wrapped in a residual connection with optional
    stochastic depth (DropPath).
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        """
        Args:
            dim: embedding width of the layer.
            num_heads: number of attention heads.
            mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
            qkv_bias: forwarded to ``add_bias_kv`` of MultiheadAttention.
            drop: dropout probability inside the MLP.
            attn_drop: dropout probability on the attention weights.
            drop_path: stochastic-depth probability; values > 0 need timm.
            act_layer: activation module class for the MLP.
            norm_layer: normalization module class, called with ``dim``.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        # NOTE(review): `add_bias_kv` appends learned bias entries to the
        # key/value sequences; it is not the projection bias (`bias=`).
        # Possibly `bias=qkv_bias` was intended - left as is because changing
        # it alters the parameter set of existing checkpoints. Confirm.
        self.attn = torch.nn.MultiheadAttention(
            dim,
            num_heads=num_heads,
            add_bias_kv=qkv_bias,
            dropout=attn_drop,
            batch_first=True,
        )
        self.drop_path1 = _build_drop_path(drop_path)

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.drop_path2 = _build_drop_path(drop_path)

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ):
        """Run the layer on ``src`` (batch-first) and return the same shape.

        Args:
            src: input sequence, (batch, seq, dim).
            mask: optional attention mask, passed as ``attn_mask``.
            key_padding_mask: optional padding mask for the keys.
        """
        normed = self.norm1(src)
        # [0] keeps the attention output and drops the attention weights.
        attended = self.attn(
            query=normed,
            key=normed,
            value=normed,
            attn_mask=mask,
            key_padding_mask=key_padding_mask,
        )[0]
        src = src + self.drop_path1(attended)
        src = src + self.drop_path2(self.mlp(self.norm2(src)))
        return src