from typing import Dict

import numpy as np
import torch
import torch.nn as nn

from navsim.agents.transfuser.transfuser_backbone import TransfuserBackbone
from navsim.agents.transfuser.transfuser_model import AgentHead
from navsim.agents.vadv2.vadv2_config import Vadv2Config


class Vadv2Model(nn.Module):
    """VADv2-style planner: a Transfuser backbone produces BEV features that
    condition an agent-detection head, a BEV semantic head, and a
    vocabulary-scoring trajectory head."""

    def __init__(self, config: Vadv2Config):
        super().__init__()

        self._query_splits = [
            config.num_bounding_boxes,
        ]

        self._config = config
        self._backbone = TransfuserBackbone(config)

        # Learned positional embedding for the 8x8 = 64 flattened BEV tokens.
        self._keyval_embedding = nn.Embedding(8**2, config.tf_d_model)
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)

        # Project the backbone's 512-channel BEV map down to the transformer width.
        self._bev_downscale = nn.Conv2d(512, config.tf_d_model, kernel_size=1)
        # TODO: drop ego status conditioning, as in PlanTF
        self._status_encoding = nn.Linear(4 + 2 + 2, config.tf_d_model)

        self._bev_semantic_head = nn.Sequential(
            nn.Conv2d(
                config.bev_features_channels,
                config.bev_features_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=(1, 1),
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                config.bev_features_channels,
                config.num_bev_classes,
                kernel_size=(1, 1),
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.Upsample(
                size=(config.lidar_resolution_height // 2, config.lidar_resolution_width),
                mode="bilinear",
                align_corners=False,
            ),
        )

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)

        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )

        self._trajectory_head = Vadv2Head(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            nhead=config.vadv2_head_nhead,
            use_ori=config.type == "ori",
            # cb_weight_path=config.cb_weight_path,
            # cb_weight_beta=config.cb_weight_beta,
            nlayers=config.vadv2_head_nlayers,
            d_model=config.tf_d_model,
            vocab_path=config.vocab_path,
        )

    def forward(
        self, features: Dict[str, torch.Tensor], interpolated_traj=None
    ) -> Dict[str, torch.Tensor]:
        # TODO: ego status handling
        camera_feature: torch.Tensor = features["camera_feature"]
        lidar_feature: torch.Tensor = features["lidar_feature"]
        status_feature: torch.Tensor = features["status_feature"]

        batch_size = status_feature.shape[0]

        bev_feature_upscale, bev_feature, _ = self._backbone(camera_feature, lidar_feature)

        # (B, 512, 8, 8) -> (B, d_model, 64) -> (B, 64, d_model)
        bev_feature = self._bev_downscale(bev_feature).flatten(-2, -1)
        bev_feature = bev_feature.permute(0, 2, 1)
        status_encoding = self._status_encoding(status_feature)

        keyval = bev_feature
        keyval += self._keyval_embedding.weight[None, ...]
        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        agents_query = self._tf_decoder(query, keyval)

        bev_semantic_map = self._bev_semantic_head(bev_feature_upscale)
        output: Dict[str, torch.Tensor] = {"bev_semantic_map": bev_semantic_map}

        trajectory = self._trajectory_head(keyval, status_encoding, interpolated_traj)
        output.update(trajectory)

        agents = self._agent_head(agents_query)
        output.update(agents)

        return output


class Vadv2Head(nn.Module):
    """Scores every trajectory in a fixed vocabulary against the BEV tokens and
    returns the highest-scoring one together with the full distribution."""

    def __init__(
        self,
        num_poses: int,
        d_ffn: int,
        d_model: int,
        vocab_path: str,
        # cb_weight_path: str,
        # cb_weight_beta: float,
        nhead: int,
        nlayers: int,
        use_ori=False,
    ):
        super().__init__()
        self.use_ori = use_ori
        self._num_poses = num_poses
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, d_ffn, dropout=0.0, batch_first=True),
            nlayers,
        )
        # Fixed trajectory vocabulary of shape (N_VOCAB, num_poses, 3); frozen.
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)), requires_grad=False
        )
        # self.cb_weight = torch.from_numpy(np.load(cb_weight_path))
        # self.cb_weight = (1 - torch.tensor([cb_weight_beta])) / (
        #     1 - torch.tensor([cb_weight_beta]).pow(self.cb_weight)
        # )
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, 1),
        )
        # TODO: explore sinusoidal embedding
        self.pos_embed = nn.Sequential(
            nn.Linear(num_poses * 3, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, d_model),
        )

    def forward(self, bev_feature, status_encoding, interpolated_traj) -> Dict[str, torch.Tensor]:
        # TODO: sinusoidal embedding
        # vocab:          (N_VOCAB, num_poses, 3), e.g. (4096, 40, 3)
        # bev_feature:    (B, 64, C) flattened BEV tokens (keyval)
        # embedded_vocab: (B, N_VOCAB, C)
        N_VOCAB = self.vocab.data.shape[0]
        vocab = self.vocab.data
        L, HORIZON, _ = vocab.shape
        B = bev_feature.shape[0]

        if self.use_ori and interpolated_traj is not None:
            # Append one interpolated trajectory per batch element, so every
            # sample scores the vocabulary plus all B extra candidates.
            vocab = torch.cat([vocab, interpolated_traj.to(vocab.dtype)], dim=0).contiguous()
            L += B

        # Embed each candidate trajectory, cross-attend to the BEV tokens, then
        # score each candidate with the MLP: dist has shape (B, L, 1).
        embedded_vocab = self.pos_embed(vocab.view(L, -1))[None].repeat(B, 1, 1)
        dist = self.mlp(
            self.transformer(embedded_vocab, bev_feature) + status_encoding.unsqueeze(1)
        )

        # Pick the best-scoring entry of the original vocabulary: shape (B,)
        selected_indices = dist[:, :N_VOCAB].argmax(1).squeeze(1)

        return {
            "trajectory": self.vocab.data[selected_indices],
            "trajectory_distribution": dist.squeeze(-1),
            "trajectory_vocab": vocab,
            # "cb_weight": self.cb_weight,
        }
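

# --- Usage sketch (illustrative; not part of the original module) -----------
# A minimal smoke test, assuming `Vadv2Config` is constructible with defaults
# and that `config.vocab_path` points at a saved (N_VOCAB, num_poses, 3) numpy
# array. The input shapes below are hypothetical placeholders; the real shapes
# come from the NAVSIM dataloader and the backbone configuration.
if __name__ == "__main__":
    config = Vadv2Config()  # assumption: default construction is valid
    model = Vadv2Model(config).eval()

    features = {
        "camera_feature": torch.randn(2, 3, 256, 1024),  # placeholder camera tensor
        "lidar_feature": torch.randn(2, 1, 256, 256),    # placeholder lidar BEV raster
        "status_feature": torch.randn(2, 8),             # 4 + 2 + 2 ego-status dims
    }
    with torch.no_grad():
        out = model(features)
    print({k: tuple(v.shape) for k, v in out.items()})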