|
from typing import Dict
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from navsim.agents.transfuser.transfuser_backbone import TransfuserBackbone
|
|
from navsim.agents.transfuser.transfuser_model import AgentHead
|
|
from navsim.agents.vadv2.vadv2_config import Vadv2Config
|
|
|
|
|
|
class Vadv2Model(nn.Module):
    """Multi-head model on top of a ``TransfuserBackbone``.

    One forward pass produces three groups of outputs, merged into a single
    dict: a BEV semantic map, the outputs of the trajectory head
    (``Vadv2Head``) and the outputs of the agent head (``AgentHead``).
    """

    def __init__(self, config: Vadv2Config):
        """Create the backbone, token embeddings, decoder and output heads.

        :param config: model configuration; the fields read here are the
            transformer sizes (``tf_*``), BEV head sizes, lidar resolution,
            ``num_bounding_boxes``, ``trajectory_sampling``, ``type``,
            ``vocab_path`` and the ``vadv2_head_*`` settings.
        """
        super().__init__()

        self._config = config
        # A single query group: one learned query per bounding box.
        self._query_splits = [config.num_bounding_boxes]

        self._backbone = TransfuserBackbone(config)

        model_dim = config.tf_d_model

        # 64 learned embeddings added to the BEV tokens in forward(); the add
        # only broadcasts if the flattened BEV map has 64 positions
        # (presumably an 8x8 grid -- TODO confirm against the backbone).
        self._keyval_embedding = nn.Embedding(8 ** 2, model_dim)
        self._query_embedding = nn.Embedding(sum(self._query_splits), model_dim)

        # 1x1 convolution projecting the 512-channel BEV feature to d_model.
        self._bev_downscale = nn.Conv2d(512, model_dim, kernel_size=1)

        # The status vector has 4 + 2 + 2 = 8 entries (presumably command,
        # velocity and acceleration -- verify against the feature builder).
        self._status_encoding = nn.Linear(4 + 2 + 2, model_dim)

        bev_channels = config.bev_features_channels
        semantic_conv = nn.Conv2d(
            bev_channels,
            bev_channels,
            kernel_size=(3, 3),
            stride=1,
            padding=(1, 1),
            bias=True,
        )
        semantic_classifier = nn.Conv2d(
            bev_channels,
            config.num_bev_classes,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            bias=True,
        )
        semantic_upsample = nn.Upsample(
            size=(config.lidar_resolution_height // 2, config.lidar_resolution_width),
            mode="bilinear",
            align_corners=False,
        )
        self._bev_semantic_head = nn.Sequential(
            semantic_conv,
            nn.ReLU(inplace=True),
            semantic_classifier,
            semantic_upsample,
        )

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=model_dim,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self._tf_decoder = nn.TransformerDecoder(decoder_layer, config.tf_num_layers)

        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=model_dim,
        )

        self._trajectory_head = Vadv2Head(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=model_dim,
            vocab_path=config.vocab_path,
            nhead=config.vadv2_head_nhead,
            nlayers=config.vadv2_head_nlayers,
            use_ori=config.type == 'ori',
        )

    def forward(self, features: Dict[str, torch.Tensor],
                interpolated_traj=None) -> Dict[str, torch.Tensor]:
        """Run the backbone and all heads.

        :param features: dict with the keys ``"camera_feature"``,
            ``"lidar_feature"`` and ``"status_feature"``.
        :param interpolated_traj: optional extra trajectories, passed through
            unchanged to the trajectory head.
        :return: dict holding ``"bev_semantic_map"`` plus every entry returned
            by the trajectory head and by the agent head.
        """
        camera = features["camera_feature"]
        lidar = features["lidar_feature"]
        status = features["status_feature"]
        num_samples = status.shape[0]

        bev_upscaled, bev_tokens, _ = self._backbone(camera, lidar)

        # (B, C, H, W) -> (B, H*W, C): one token per BEV cell.
        bev_tokens = self._bev_downscale(bev_tokens)
        bev_tokens = bev_tokens.flatten(-2, -1).permute(0, 2, 1)
        status_encoding = self._status_encoding(status)

        # The learned embedding is added in place to the BEV tokens.
        keyval = bev_tokens
        keyval += self._keyval_embedding.weight.unsqueeze(0)

        queries = self._query_embedding.weight.unsqueeze(0).repeat(num_samples, 1, 1)
        agent_queries = self._tf_decoder(queries, keyval)

        semantic_map = self._bev_semantic_head(bev_upscaled)

        # The trajectory head runs before the agent head, as in the original flow.
        trajectory_outputs = self._trajectory_head(keyval, status_encoding, interpolated_traj)
        agent_outputs = self._agent_head(agent_queries)

        return {
            "bev_semantic_map": semantic_map,
            **trajectory_outputs,
            **agent_outputs,
        }
|
|
|
|
|
|
class Vadv2Head(nn.Module):
    """Score a fixed vocabulary of candidate trajectories and pick the best one.

    Every vocabulary entry is flattened, embedded by an MLP and used as a
    query of a transformer decoder that attends to the BEV tokens. The status
    encoding is added to each decoded query and a final MLP reduces it to a
    single scalar score per candidate.
    """

    def __init__(self, num_poses: int, d_ffn: int, d_model: int, vocab_path: str,
                 nhead: int, nlayers: int, use_ori=False):
        """Build the scoring network and load the trajectory vocabulary.

        :param num_poses: poses per trajectory; each flattened vocabulary
            entry must contain ``num_poses * 3`` values.
        :param d_ffn: hidden width of the MLPs and of the decoder feed-forward.
        :param d_model: token width shared with the BEV features.
        :param vocab_path: path of a ``.npy`` file read with ``np.load``;
            it is unpacked as a 3-D array in ``forward``.
        :param nhead: number of attention heads of the decoder layers.
        :param nlayers: number of decoder layers.
        :param use_ori: when true, trajectories given to ``forward`` through
            ``interpolated_traj`` are appended to the vocabulary for scoring.
        """
        super().__init__()
        self.use_ori = use_ori
        self._num_poses = num_poses

        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, nhead, d_ffn,
                dropout=0.0, batch_first=True
            ), nlayers
        )

        # Stored as a frozen parameter so it follows the module across devices
        # and is saved in the state dict, but is never trained.
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False
        )

        # Maps a decoded candidate token to one scalar score.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, 1),
        )

        # Embeds a flattened trajectory (num_poses * 3 values) as a query token.
        self.pos_embed = nn.Sequential(
            nn.Linear(num_poses * 3, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, d_model),
        )

    def forward(self, bev_feature, status_encoding, interpolated_traj) -> Dict[str, torch.Tensor]:
        """Score all candidate trajectories for every sample of the batch.

        :param bev_feature: decoder memory; its first dimension is the batch
            size (assumed ``(B, tokens, d_model)`` -- the decoder is built
            with ``batch_first=True``).
        :param status_encoding: per-sample encoding, broadcast over all
            candidates (assumed ``(B, d_model)``).
        :param interpolated_traj: optional extra candidates with the same
            trailing dimensions as the vocabulary, or ``None``. They are only
            used when ``use_ori`` is true, and any number of them may be given.
        :return: dict with
            ``"trajectory"``: best-scoring entry of the *stored* vocabulary
            for each sample (appended candidates are never selected);
            ``"trajectory_distribution"``: raw scores of shape
            ``(B, num_candidates)`` -- no softmax is applied here;
            ``"trajectory_vocab"``: the candidates that were scored, i.e. the
            stored vocabulary followed by any appended trajectories.
        """
        stored_vocab = self.vocab.data
        # Unpacking keeps the original requirement that the vocabulary is 3-D.
        num_stored, _, _ = stored_vocab.shape
        batch_size = bev_feature.shape[0]

        vocab = stored_vocab
        if self.use_ori and interpolated_traj is not None:
            # Match device as well as dtype so the concatenation cannot fail
            # when the extra trajectories live on another device.
            extra = interpolated_traj.to(device=vocab.device, dtype=vocab.dtype)
            vocab = torch.cat([vocab, extra], dim=0).contiguous()

        # Take the candidate count from the tensor that is actually scored,
        # so the number of appended trajectories need not equal the batch size.
        num_candidates = vocab.shape[0]

        embedded_vocab = self.pos_embed(vocab.view(num_candidates, -1))[None].repeat(batch_size, 1, 1)
        decoded = self.transformer(embedded_vocab, bev_feature)
        dist = self.mlp(decoded + status_encoding.unsqueeze(1))

        # Only the stored vocabulary competes for the final prediction.
        selected_indices = dist[:, :num_stored].argmax(1).squeeze(1)
        return {
            "trajectory": self.vocab.data[selected_indices],
            "trajectory_distribution": dist.squeeze(-1),
            "trajectory_vocab": vocab,
        }
|
|
|