from typing import Optional

import torch
import torch.nn as nn
from timm.models.layers import DropPath
from torch import Tensor


class PointsEncoder(nn.Module):
    def __init__(self, feat_channel, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        # Per-point MLP; valid points are flattened to (N, C), the 2D
        # layout that BatchNorm1d expects.
        self.first_mlp = nn.Sequential(
            nn.Linear(feat_channel, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
        )
        # Fuses per-point features (256) with the max-pooled global
        # feature (256), hence the 512-channel input.
        self.second_mlp = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.encoder_channel),
        )

    def forward(self, x, mask=None):
        """
        x : B N feat_channel
        mask : B N (bool; None means all points are valid)
        -----------------
        feature_global : B encoder_channel
        """
        bs, n, _ = x.shape
        device = x.device
        if mask is None:
            # Treat every point as valid; indexing with None would add a
            # dimension instead of selecting points.
            mask = torch.ones(bs, n, dtype=torch.bool, device=device)

        # Encode valid points only, then scatter back into a zero-padded
        # (B, N, 256) tensor.
        x_valid = self.first_mlp(x[mask].to(torch.float32))
        x_features = torch.zeros(bs, n, 256, device=device)
        x_features[mask] = x_valid

        # PointNet-style aggregation: max-pool a global feature and
        # concatenate it back onto every per-point feature.
        pooled_feature = x_features.max(dim=1)[0]
        x_features = torch.cat(
            [x_features, pooled_feature.unsqueeze(1).repeat(1, n, 1)], dim=-1
        )

        x_features_valid = self.second_mlp(x_features[mask])
        res = torch.zeros(bs, n, self.encoder_channel, device=device)
        res[mask] = x_features_valid

        # A final max-pool over points yields one global feature per set.
        res = res.max(dim=1)[0]
        return res
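

def _points_encoder_example():
    # Usage sketch, not part of the original module: the shapes and the
    # 20% dropout of points below are illustrative assumptions.
    enc = PointsEncoder(feat_channel=6, encoder_channel=128)
    pts = torch.randn(4, 20, 6)  # (B, N, feat_channel)
    mask = torch.rand(4, 20) > 0.2  # (B, N) boolean validity mask
    return enc(pts, mask)  # -> (4, 128), one global feature per point set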


class MapEncoder(nn.Module):
    def __init__(
        self,
        polygon_channel=6,
        dim=128,
    ) -> None:
        super().__init__()

        self.dim = dim
        self.polygon_encoder = PointsEncoder(polygon_channel, dim)
        self.speed_limit_emb = nn.Sequential(
            nn.Linear(1, dim), nn.ReLU(), nn.Linear(dim, dim)
        )
        self.type_emb = nn.Embedding(3, dim)

    def forward(self, data) -> torch.Tensor:
        polygon_center = data["polygon_center"]
        polygon_type = data["polygon_type"].long()
        point_position = data["point_position"]
        point_vector = data["point_vector"]
        point_orientation = data["point_orientation"]
        valid_mask = data["valid_mask"]

        # Per-point features: position relative to the polygon center, the
        # point's direction vector, and its orientation as (cos, sin).
        polygon_feature = torch.cat(
            [
                point_position[:, :, 0] - polygon_center[..., None, :2],
                point_vector[:, :, 0],
                torch.stack(
                    [
                        point_orientation[:, :, 0].cos(),
                        point_orientation[:, :, 0].sin(),
                    ],
                    dim=-1,
                ),
            ],
            dim=-1,
        )

        # Flatten the polygon axis so each polygon is encoded independently
        # as a point set.
        bs, M, P, C = polygon_feature.shape
        valid_mask = valid_mask.view(bs * M, P)
        polygon_feature = polygon_feature.reshape(bs * M, P, C)

        x_polygon = self.polygon_encoder(polygon_feature, valid_mask).view(bs, M, -1)

        # Add a learned embedding for the polygon type.
        x_type = self.type_emb(polygon_type)

        x_polygon += x_type

        return x_polygon
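

def _map_encoder_example():
    # Usage sketch, not part of the original module. The dict layout is
    # inferred from the indexing in MapEncoder.forward and may differ from
    # the real dataset: a leading axis of size 3 is assumed for the
    # point_* tensors, of which only index 0 is used.
    B, M, P = 2, 5, 20
    data = {
        "polygon_center": torch.randn(B, M, 3),  # assumed (x, y, heading)
        "polygon_type": torch.randint(0, 3, (B, M)),
        "point_position": torch.randn(B, M, 3, P, 2),
        "point_vector": torch.randn(B, M, 3, P, 2),
        "point_orientation": torch.randn(B, M, 3, P),
        "valid_mask": torch.rand(B, M, P) > 0.1,
    }
    enc = MapEncoder(polygon_channel=6, dim=128)
    return enc(data)  # -> (B, M, 128)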


class AgentEncoder(nn.Module):
    def __init__(
        self,
        agent_channel=9,
        dim=128,
    ) -> None:
        super().__init__()
        self.dim = dim
        # Same two-stage MLP as PointsEncoder, but without the global
        # max-pool fusion: each agent is encoded independently.
        self.first_mlp = nn.Sequential(
            nn.Linear(agent_channel, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
        )
        self.second_mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.dim),
        )
        self.type_emb = nn.Embedding(4, dim)

    def forward(self, data):
        category = data["categories"].long()
        agent_feature = data["states"]
        valid_mask = data["valid_mask"]

        # Encode valid agents only, then scatter back into a zero-padded
        # (B, A, dim) tensor.
        bs, A, _ = agent_feature.shape
        agent_feature = self.second_mlp(
            self.first_mlp(agent_feature[valid_mask].to(torch.float32))
        )
        res = torch.zeros(
            bs, A, self.dim, device=agent_feature.device, dtype=agent_feature.dtype
        )
        res[valid_mask] = agent_feature

        # Add a learned embedding for the agent category.
        x_type = self.type_emb(category)
        return res + x_type
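

def _agent_encoder_example():
    # Usage sketch, not part of the original module: 9 state channels and
    # 4 agent categories match the defaults above; batch sizes are assumed.
    B, A = 2, 8
    data = {
        "categories": torch.randint(0, 4, (B, A)),
        "states": torch.randn(B, A, 9),
        "valid_mask": torch.rand(B, A) > 0.1,
    }
    enc = AgentEncoder(agent_channel=9, dim=128)
    return enc(data)  # -> (B, A, 128)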


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class CustomTransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(
            dim,
            num_heads=num_heads,
            # Note: add_bias_kv appends learnable bias tokens to the key and
            # value sequences; it is not the same as timm's qkv_bias, which
            # corresponds to nn.MultiheadAttention's `bias` argument (the
            # bias of the QKV/output projections).
            add_bias_kv=qkv_bias,
            dropout=attn_drop,
            batch_first=True,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ):
        # Pre-norm residual block: LayerNorm -> self-attention -> residual,
        # then LayerNorm -> MLP -> residual.
        src2 = self.norm1(src)
        src2 = self.attn(
            query=src2,
            key=src2,
            value=src2,
            attn_mask=mask,
            key_padding_mask=key_padding_mask,
        )[0]
        src = src + self.drop_path1(src2)
        src = src + self.drop_path2(self.mlp(self.norm2(src)))
        return src
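

if __name__ == "__main__":
    # Smoke test (sketch): run every module once on random inputs using the
    # assumed shapes from the example functions above.
    print(_points_encoder_example().shape)  # torch.Size([4, 128])
    print(_map_encoder_example().shape)  # torch.Size([2, 5, 128])
    print(_agent_encoder_example().shape)  # torch.Size([2, 8, 128])
    layer = CustomTransformerEncoderLayer(dim=128, num_heads=8)
    tokens = torch.randn(2, 13, 128)  # (B, T, dim)
    print(layer(tokens).shape)  # torch.Size([2, 13, 128])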