from typing import Dict

import numpy as np
import torch
import torch.nn as nn

from navsim.agents.hydra_plantf.hydra_plantf_config import HydraPlantfConfig
from navsim.agents.hydra_plantf.model_utils import MapEncoder, AgentEncoder, CustomTransformerEncoderLayer
from navsim.agents.utils.attn import MemoryEffTransformer
from navsim.agents.utils.nerf import nerf_positional_encoding
from navsim.agents.vadv2.vadv2_config import Vadv2Config


class HydraPlantfModel(nn.Module):
    def __init__(self, config: HydraPlantfConfig):
        super().__init__()
        self._config = config
        self.map_encoder = MapEncoder(
            dim=config.tf_d_model,
            polygon_channel=6,
        )
        self.agent_encoder = AgentEncoder(
            agent_channel=8,
            dim=config.tf_d_model,
        )
        # tf_num_encoder_layers encoder blocks with drop-path rates
        # increasing linearly from 0 to 0.2
        self.blocks = nn.ModuleList(
            CustomTransformerEncoderLayer(dim=config.tf_d_model, num_heads=config.tf_num_head, drop_path=dp)
            for dp in [x.item() for x in torch.linspace(0, 0.2, config.tf_num_encoder_layers)]
        )
        self.norm = nn.LayerNorm(config.tf_d_model)
        # ego status: driving command (4) + velocity (2) + acceleration (2)
        self._status_encoding = nn.Linear((4 + 2 + 2) * config.num_ego_status, config.tf_d_model)
        self._trajectory_head = HydraTrajPlantfHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            nhead=config.vadv2_head_nhead,
            nlayers=config.vadv2_head_nlayers,
            vocab_path=config.vocab_path,
            config=config,
        )

    def forward(self, features: Dict[str, torch.Tensor], interpolated_traj=None) -> Dict[str, torch.Tensor]:
        status_feature: torch.Tensor = features["status_feature"]
        # with a single ego status, keep only the first 8 of the 32 stacked status channels
        if self._config.num_ego_status == 1 and status_feature.shape[1] == 32:
            status_encoding = self._status_encoding(status_feature[:, :8])
        else:
            status_encoding = self._status_encoding(status_feature)

        agent_features = self.agent_encoder(features['agent'])
        map_features = self.map_encoder(features['map'])
        # mask out invalid agent and map tokens in attention
        key_padding_mask = torch.cat([
            ~(features['agent']['valid_mask']),
            ~(features['map']['valid_mask'].any(-1)),
        ], dim=-1)
        x = torch.cat([agent_features, map_features], dim=1)
        for blk in self.blocks:
            x = blk(x, key_padding_mask=key_padding_mask)
        keyval = self.norm(x)

        output: Dict[str, torch.Tensor] = {}
        trajectory = self._trajectory_head(keyval, status_encoding)
        output.update(trajectory)
        return output


class HydraTrajPlantfHead(nn.Module):
    def __init__(
        self,
        num_poses: int,
        d_ffn: int,
        d_model: int,
        vocab_path: str,
        nhead: int,
        nlayers: int,
        config: Vadv2Config = None,
    ):
        super().__init__()
        self._num_poses = num_poses
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, d_ffn, dropout=0.0, batch_first=True),
            nlayers,
        )
        # frozen trajectory vocabulary, e.g. (4096, num_poses, 3)
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False,
        )

        # one scoring head per metric, plus a deeper imitation head
        def _metric_head() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            )

        self.heads = nn.ModuleDict({
            'noc': _metric_head(),
            'da': _metric_head(),
            'ttc': _metric_head(),
            'comfort': _metric_head(),
            'progress': _metric_head(),
            'imi': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            ),
        })
        self.inference_imi_weight = config.inference_imi_weight
        self.inference_da_weight = config.inference_da_weight
        self.normalize_vocab_pos = config.normalize_vocab_pos
        if self.normalize_vocab_pos:
            self.encoder = MemoryEffTransformer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=0.0,
            )
        self.use_nerf = config.use_nerf
        if self.use_nerf:
            # NeRF-style positional features: 1040 input channels
            self.pos_embed = nn.Sequential(
                nn.Linear(1040, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, d_model),
            )
        else:
            # raw flattened poses: num_poses * (x, y, heading)
            self.pos_embed = nn.Sequential(
                nn.Linear(num_poses * 3, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, d_model),
            )

    def forward(self, bev_feature, status_encoding, interpolated_traj=None) -> Dict[str, torch.Tensor]:
        # todo sinusoidal embedding
        # vocab: 4096, 40, 3
        # bev_feature: B, 32, C
        # embedded_vocab: B, 4096, C
        vocab = self.vocab.data
        L, HORIZON, _ = vocab.shape
        B = bev_feature.shape[0]

        if self.use_nerf:
            # encode (x, y) with NeRF frequencies; heading as (cos, sin)
            vocab = torch.cat(
                [
                    nerf_positional_encoding(vocab[..., :2]),
                    torch.cos(vocab[..., -1])[..., None],
                    torch.sin(vocab[..., -1])[..., None],
                ],
                dim=-1,
            )

        if self.normalize_vocab_pos:
            embedded_vocab = self.pos_embed(vocab.view(L, -1))[None]
            embedded_vocab = self.encoder(embedded_vocab).repeat(B, 1, 1)
        else:
            embedded_vocab = self.pos_embed(vocab.view(L, -1))[None].repeat(B, 1, 1)

        # cross-attend the embedded vocabulary to the scene tokens
        tr_out = self.transformer(embedded_vocab, bev_feature)
        dist_status = tr_out + status_encoding.unsqueeze(1)

        result = {}
        for k, head in self.heads.items():
            if k == 'imi':
                # imitation head stays a raw logit (softmaxed below)
                result[k] = head(dist_status).squeeze(-1)
            else:
                # metric heads predict probabilities in (0, 1)
                result[k] = head(dist_status).squeeze(-1).sigmoid()

        # fixed fusion weights; inference_imi_weight / inference_da_weight
        # are stored in __init__ but not used here
        scores = (
            0.05 * result['imi'].softmax(-1).log()
            + 0.5 * result['noc'].log()
            + 0.5 * result['da'].log()
            + 8.0 * (5 * result['ttc'] + 2 * result['comfort'] + 5 * result['progress']).log()
        )
        # selected_indices: B,
        selected_indices = scores.argmax(1)
        result["trajectory"] = self.vocab.data[selected_indices]
        result["trajectory_vocab"] = self.vocab.data
        result["selected_indices"] = selected_indices
        return result
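

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: it replays the score
# fusion from HydraTrajPlantfHead.forward on random tensors so the selection
# rule can be inspected in isolation. The fusion weights (0.05 / 0.5 / 0.5 /
# 8.0 and the 5/2/5 metric mix) are copied from the head above; the tensors
# and the B/L values are made up.
if __name__ == "__main__":
    B, L = 2, 4096  # dummy batch size and vocabulary size
    imi = torch.randn(B, L)  # raw imitation logits (no sigmoid, as above)
    # the metric heads emit sigmoid probabilities in (0, 1)
    noc, da, ttc, comfort, progress = (torch.randn(B, L).sigmoid() for _ in range(5))
    scores = (
        0.05 * imi.softmax(-1).log()
        + 0.5 * noc.log()
        + 0.5 * da.log()
        + 8.0 * (5 * ttc + 2 * comfort + 5 * progress).log()
    )
    print(scores.argmax(1))  # per-sample index of the highest-scoring trajectory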