# lkllkl's picture
# Upload folder using huggingface_hub
# da2e2ac verified
from typing import Optional
import torch
import torch.nn as nn
from timm.models.layers import DropPath
from torch import Tensor
class PointsEncoder(nn.Module):
    """Encode a variable-length point set into a single feature vector.

    A per-point MLP lifts each valid point to 256-d, a masked max-pool
    builds a global context vector, and a second MLP fuses the local and
    global features before a final max-pool over the point dimension
    (PointNet-style encoding).
    """

    def __init__(self, feat_channel, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        self.first_mlp = nn.Sequential(
            nn.Linear(feat_channel, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
        )
        self.second_mlp = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.encoder_channel),
        )

    def forward(self, x, mask=None):
        """
        x : (B, N, feat_channel) point features
        mask : (B, N) boolean validity mask, or None for "all points valid"
        -----------------
        returns : (B, encoder_channel) global feature
        """
        bs, n, _ = x.shape
        device = x.device
        # Fix: the signature advertises mask=None, but the original body
        # crashed on x[None-mask]; treat None as an all-valid mask.
        if mask is None:
            mask = torch.ones(bs, n, dtype=torch.bool, device=device)
        x_valid = self.first_mlp(x[mask].to(torch.float32))  # (num_valid, 256)
        x_features = torch.zeros(bs, n, 256, device=device)
        x_features[mask] = x_valid
        # Global context: max over the point dimension (invalid slots stay 0).
        pooled_feature = x_features.max(dim=1)[0]
        x_features = torch.cat(
            [x_features, pooled_feature.unsqueeze(1).repeat(1, n, 1)], dim=-1
        )
        x_features_valid = self.second_mlp(x_features[mask])
        res = torch.zeros(bs, n, self.encoder_channel, device=device)
        res[mask] = x_features_valid
        res = res.max(dim=1)[0]
        return res
class MapEncoder(nn.Module):
    """Embed map polygons into per-polygon feature vectors.

    Builds a per-point feature (position relative to the polygon center,
    point vector, and orientation as cos/sin), encodes each polygon's
    point set with a PointsEncoder, and adds a learned type embedding.

    NOTE(review): ``speed_limit_emb`` is created but never used in
    ``forward`` — kept so existing checkpoints still load.
    """

    def __init__(
        self,
        polygon_channel=6,
        dim=128,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.polygon_encoder = PointsEncoder(polygon_channel, dim)
        self.speed_limit_emb = nn.Sequential(
            nn.Linear(1, dim), nn.ReLU(), nn.Linear(dim, dim)
        )
        self.type_emb = nn.Embedding(3, dim)

    def forward(self, data) -> torch.Tensor:
        center = data["polygon_center"]
        poly_type = data["polygon_type"].long()
        # Only the first point-dimension slice (index 0) of each attribute
        # is used as the per-point descriptor.
        first_pos = data["point_position"][:, :, 0]
        first_vec = data["point_vector"][:, :, 0]
        first_orient = data["point_orientation"][:, :, 0]
        feats = torch.cat(
            [
                first_pos - center[..., None, :2],
                first_vec,
                torch.stack([first_orient.cos(), first_orient.sin()], dim=-1),
            ],
            dim=-1,
        )
        bs, num_poly, num_pts, ch = feats.shape
        # Flatten (batch, polygon) so each polygon is an independent point set.
        flat_feats = feats.reshape(bs * num_poly, num_pts, ch)
        flat_mask = data["valid_mask"].view(bs * num_poly, num_pts)
        embedding = self.polygon_encoder(flat_feats, flat_mask).view(bs, num_poly, -1)
        return embedding + self.type_emb(poly_type)
class AgentEncoder(nn.Module):
    """Embed per-agent state vectors and add a learned category embedding.

    Only rows flagged by ``valid_mask`` are pushed through the two MLPs;
    invalid rows contribute a zero feature before the category embedding
    is added.
    """

    def __init__(
        self,
        agent_channel=9,
        dim=128,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.first_mlp = nn.Sequential(
            nn.Linear(agent_channel, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
        )
        self.second_mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.dim),
        )
        self.type_emb = nn.Embedding(4, dim)

    def forward(self, data):
        categories = data["categories"].long()
        states = data["states"]
        valid = data["valid_mask"]
        bs, num_agents, _ = states.shape
        # Run both MLPs on valid rows only; BatchNorm1d sees a flat
        # (num_valid, C) batch.
        encoded = self.second_mlp(self.first_mlp(states[valid].to(torch.float32)))
        out = torch.zeros(
            bs,
            num_agents,
            self.dim,
            device=encoded.device,
            dtype=encoded.dtype,
        )
        out[valid] = encoded
        return out + self.type_emb(categories)
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> act -> dropout -> fc2 -> dropout.

    The form used in Vision Transformer, MLP-Mixer and related networks.
    Hidden and output widths default to ``in_features`` when not given.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
class CustomTransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer.

    LayerNorm -> self-attention -> residual (with optional stochastic
    depth), then LayerNorm -> MLP -> residual.

    NOTE(review): ``qkv_bias`` is wired to MultiheadAttention's
    ``add_bias_kv`` (extra learnable bias appended to key/value
    sequences), not to the in-projection bias — confirm this is intended,
    since ViT-style layers usually mean the projection bias by this name.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = torch.nn.MultiheadAttention(
            dim,
            num_heads=num_heads,
            add_bias_kv=qkv_bias,
            dropout=attn_drop,
            batch_first=True,
        )
        # Stochastic depth only when a positive rate is configured.
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ):
        # Pre-norm self-attention sub-block.
        normed = self.norm1(src)
        attn_out, _ = self.attn(
            query=normed,
            key=normed,
            value=normed,
            attn_mask=mask,
            key_padding_mask=key_padding_mask,
        )
        src = src + self.drop_path1(attn_out)
        # Pre-norm feed-forward sub-block.
        return src + self.drop_path2(self.mlp(self.norm2(src)))