import math
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from diffusers import DDIMScheduler

from navsim.agents.dm.backbone import DMBackbone
from navsim.agents.dm.dm_config import DMConfig
from navsim.agents.dm.utils import VerletStandardizer
from navsim.agents.transfuser.transfuser_model import AgentHead


class SinusoidalPosEmb(nn.Module):
    # Sinusoidal embedding used to encode the (normalized) diffusion timestep.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class DMModel(nn.Module):
    # Image backbone features + ego status feed a transformer decoder with an
    # agent detection head and a diffusion-based trajectory head.
    def __init__(self, config: DMConfig):
        super().__init__()
        self._query_splits = [
            config.num_bounding_boxes,
        ]
        self._config = config

        assert config.backbone_type in ['vit', 'intern', 'vov', 'resnet', 'eva', 'moe', 'moe_ult32', 'swin']
        if config.backbone_type == 'eva':
            raise ValueError(f'{config.backbone_type} not supported')
        elif config.backbone_type in ('intern', 'vov', 'swin', 'vit'):
            self._backbone = DMBackbone(config)

        img_num = 2 if config.use_back_view else 1
        self._keyval_embedding = nn.Embedding(
            config.img_vert_anchors * config.img_horz_anchors * img_num, config.tf_d_model
        )  # 8x8 feature grid + trajectory
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)

        # usually, the BEV features are variable in size.
        self.downscale_layer = nn.Conv2d(self._backbone.img_feat_c, config.tf_d_model, kernel_size=1)
        self._status_encoding = nn.Linear((4 + 2 + 2) * config.num_ego_status, config.tf_d_model)

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)

        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )
        self._trajectory_head = DMTrajHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            nhead=config.vadv2_head_nhead,
            nlayers=config.vadv2_head_nlayers,
            vocab_path=config.vocab_path,
            config=config,
        )

    def img_feat_blc(self, camera_feature):
        # Backbone features -> 1x1 conv to tf_d_model channels -> flatten the
        # spatial grid into a token sequence of shape (B, L, C).
        img_features = self._backbone(camera_feature)
        img_features = self.downscale_layer(img_features).flatten(-2, -1)
        img_features = img_features.permute(0, 2, 1)
        return img_features

    def forward(self, features: Dict[str, torch.Tensor], interpolated_traj=None) -> Dict[str, torch.Tensor]:
        camera_feature: torch.Tensor = features["camera_feature"]
        status_feature: torch.Tensor = features["status_feature"]
        if isinstance(camera_feature, list):
            camera_feature = camera_feature[-1]  # todo temp fix
        # status_feature[:, 0] = 0.0
        # status_feature[:, 1] = 1.0
        # status_feature[:, 2] = 0.0
        # status_feature[:, 3] = 0.0

        batch_size = status_feature.shape[0]

        img_features = self.img_feat_blc(camera_feature)
        if self._config.use_back_view:
            img_features_back = self.img_feat_blc(features["camera_feature_back"])
            img_features = torch.cat([img_features, img_features_back], 1)

        # With a single ego status, keep only the first 8 channels (4 + 2 + 2) of the stacked 32-dim status vector.
        if self._config.num_ego_status == 1 and status_feature.shape[1] == 32:
            status_encoding = self._status_encoding(status_feature[:, :8])
        else:
            status_encoding = self._status_encoding(status_feature)

        keyval = img_features
        keyval += self._keyval_embedding.weight[None, ...]

        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        agents_query = self._tf_decoder(query, keyval)

        output: Dict[str, torch.Tensor] = {}

        trajectory = self._trajectory_head(keyval, status_encoding, features['history_waypoints'])
        output.update(trajectory)

        agents = self._agent_head(agents_query)
        output.update(agents)

        return output


class DMTrajHead(nn.Module):
    # Diffusion-based trajectory head: denoises a flattened (H, 3) ego trajectory
    # with a DDIM scheduler, conditioned on image tokens and the ego status.
    def __init__(
        self,
        num_poses: int,
        d_ffn: int,
        d_model: int,
        vocab_path: str,
        nhead: int,
        nlayers: int,
        config: DMConfig = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.config = config
        self._num_poses = num_poses
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, d_ffn, dropout=0.0, batch_first=True),
            nlayers,
        )
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False,
        )
        self.H = config.trajectory_sampling.num_poses
        self.T = config.T
        self.standardizer = VerletStandardizer()
        self.decoder_mlp = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.H * 3),
        )
        self.encoder_mlp = nn.Sequential(
            nn.Linear(self.H * 3, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.d_model),
        )
        self.sigma_encoder = nn.Sequential(
            SinusoidalPosEmb(self.d_model),
        )
        self.scheduler = DDIMScheduler(
            num_train_timesteps=self.T,
            beta_schedule='scaled_linear',
            prediction_type='epsilon',
        )
        self.scheduler.set_timesteps(self.T)

    def denoise(self, ego_trajectory, env_features, status_encoding, timesteps):
        # Predict the noise residual (epsilon) for the noisy flattened trajectory at the given timestep(s).
        B = ego_trajectory.shape[0]
        ego_trajectory = ego_trajectory.reshape(B, -1).to(torch.float32)
        sigma = timesteps.reshape(-1, 1)
        if sigma.numel() == 1:
            sigma = sigma.repeat(B, 1)
        sigma = sigma.float() / self.T
        sigma_embeddings = self.sigma_encoder(sigma).squeeze(1)
        ego_emb = self.encoder_mlp(ego_trajectory) + status_encoding + sigma_embeddings
        ego_attn = self.transformer(ego_emb[:, None], env_features)
        out = self.decoder_mlp(ego_attn).reshape(B, -1)
        return out

    def forward(self, bev_feature, status_encoding, history_waypoints) -> Dict[str, torch.Tensor]:
        # todo sinusoidal embedding
        # vocab: 4096, 40, 3
        # bev_feature: B, 32, C
        # embedded_vocab: B, 4096, C
        B = bev_feature.shape[0]
        result = {}
        if not self.config.is_training:
            # DDIM sampling: start from Gaussian noise over the flattened trajectory
            # and iteratively denoise it.
            ego_trajectory = torch.randn((B, self.H * 3), device=bev_feature.device)
            timesteps = self.scheduler.timesteps
            for t in timesteps:
                with torch.no_grad():
                    # Fresh epsilon prediction for the current sample at this timestep.
                    residual = self.denoise(
                        ego_trajectory, bev_feature, status_encoding, t.to(ego_trajectory.device)
                    )
                out = self.scheduler.step(residual, t, ego_trajectory)
                ego_trajectory = out.prev_sample
            ego_trajectory = self.standardizer.untransform_features(ego_trajectory, history_waypoints)
            result["trajectory"] = ego_trajectory.reshape(B, self.H, 3)
            # Constant placeholders for the vocabulary scores expected downstream.
            result['imi'], result['noc'], result['da'], result['ttc'], result['comfort'], result['progress'] = (
                torch.ones((B, 4096), device=bev_feature.device)
                for _ in range(6)
            )
        result['history_waypoints'] = history_waypoints
        result['env_features'] = bev_feature
        result['status_encoding'] = status_encoding
        return result
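

# Minimal sketch of the DDIM sampling pattern that DMTrajHead.forward runs at
# inference time, using only the `diffusers` scheduler API imported above.
# `toy_denoiser` and the sizes B, H, T below are placeholders standing in for
# DMTrajHead.denoise and the real config values; they are not part of the model.
if __name__ == "__main__":
    B, H, T = 2, 8, 10  # batch size, number of poses, diffusion steps (placeholders)

    scheduler = DDIMScheduler(
        num_train_timesteps=T,
        beta_schedule='scaled_linear',
        prediction_type='epsilon',
    )
    scheduler.set_timesteps(T)

    def toy_denoiser(sample: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Stand-in for the transformer denoiser: returns an epsilon prediction of the same shape.
        return torch.zeros_like(sample)

    # Start from Gaussian noise over the flattened (H, 3) trajectory and denoise step by step.
    sample = torch.randn(B, H * 3)
    for t in scheduler.timesteps:
        eps = toy_denoiser(sample, t)
        sample = scheduler.step(eps, t, sample).prev_sample

    print("denoised trajectory shape:", tuple(sample.reshape(B, H, 3).shape))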