import math
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from diffusers import DDIMScheduler

from navsim.agents.dm.backbone import DMBackbone
from navsim.agents.dm.dm_config import DMConfig
from navsim.agents.dm.utils import VerletStandardizer
from navsim.agents.transfuser.transfuser_model import AgentHead


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of the (normalized) diffusion timestep."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
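    # Shape note: for the (B, 1) timestep input used by DMTrajHead.denoise, the output is
    # (B, 1, dim); the caller squeezes the middle dimension before adding the embedding to the
    # trajectory and status encodings.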
class DMModel(nn.Module):
    def __init__(self, config: DMConfig):
        super().__init__()

        self._query_splits = [
            config.num_bounding_boxes,
        ]

        self._config = config
        assert config.backbone_type in ['vit', 'intern', 'vov', 'resnet', 'eva', 'moe', 'moe_ult32', 'swin']
        if config.backbone_type == 'eva':
            raise ValueError(f'{config.backbone_type} not supported')
        elif config.backbone_type in ('intern', 'vov', 'swin', 'vit'):
            self._backbone = DMBackbone(config)
        else:
            # 'resnet', 'moe' and 'moe_ult32' pass the assert above but have no backbone branch
            # here; fail fast rather than hitting an AttributeError on self._backbone below.
            raise ValueError(f'{config.backbone_type} backbone is not constructed in DMModel')

        img_num = 2 if config.use_back_view else 1
        self._keyval_embedding = nn.Embedding(
            config.img_vert_anchors * config.img_horz_anchors * img_num, config.tf_d_model
        )
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)

        self.downscale_layer = nn.Conv2d(self._backbone.img_feat_c, config.tf_d_model, kernel_size=1)
        self._status_encoding = nn.Linear((4 + 2 + 2) * config.num_ego_status, config.tf_d_model)

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )

        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )

        self._trajectory_head = DMTrajHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            nhead=config.vadv2_head_nhead,
            nlayers=config.vadv2_head_nlayers,
            vocab_path=config.vocab_path,
            config=config,
        )

    def img_feat_blc(self, camera_feature):
        """Backbone image features projected and flattened to (batch, length, channel) tokens."""
        img_features = self._backbone(camera_feature)
        img_features = self.downscale_layer(img_features).flatten(-2, -1)
        img_features = img_features.permute(0, 2, 1)
        return img_features

    def forward(self, features: Dict[str, torch.Tensor],
                interpolated_traj=None) -> Dict[str, torch.Tensor]:
        camera_feature: torch.Tensor = features["camera_feature"]
        status_feature: torch.Tensor = features["status_feature"]
        if isinstance(camera_feature, list):
            camera_feature = camera_feature[-1]

        batch_size = status_feature.shape[0]

        img_features = self.img_feat_blc(camera_feature)
        if self._config.use_back_view:
            img_features_back = self.img_feat_blc(features["camera_feature_back"])
            img_features = torch.cat([img_features, img_features_back], 1)

        # A 32-wide status vector (presumably stacked ego states) is truncated to its first
        # 8 channels when only a single ego status is configured.
        if self._config.num_ego_status == 1 and status_feature.shape[1] == 32:
            status_encoding = self._status_encoding(status_feature[:, :8])
        else:
            status_encoding = self._status_encoding(status_feature)

        keyval = img_features
        keyval += self._keyval_embedding.weight[None, ...]

        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        agents_query = self._tf_decoder(query, keyval)

        output: Dict[str, torch.Tensor] = {}
        trajectory = self._trajectory_head(keyval, status_encoding, features['history_waypoints'])
        output.update(trajectory)
        agents = self._agent_head(agents_query)
        output.update(agents)

        return output
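    # Input contract (from the forward pass above): `features` must provide "camera_feature"
    # (a tensor, or a list whose last element is used), "status_feature" with 8 or 32 channels,
    # and "history_waypoints"; "camera_feature_back" is additionally required when
    # config.use_back_view is set. The returned dict merges the trajectory-head outputs with
    # the AgentHead outputs.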
class DMTrajHead(nn.Module):
    def __init__(self, num_poses: int, d_ffn: int, d_model: int, vocab_path: str,
                 nhead: int, nlayers: int, config: DMConfig = None):
        super().__init__()
        self.d_model = d_model
        self.config = config
        self._num_poses = num_poses
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, nhead, d_ffn,
                dropout=0.0, batch_first=True
            ), nlayers
        )
        # Frozen trajectory vocabulary loaded from disk, kept as a non-trainable parameter.
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False
        )
        self.H = config.trajectory_sampling.num_poses
        self.T = config.T
        self.standardizer = VerletStandardizer()
        self.decoder_mlp = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.H * 3)
        )
        self.encoder_mlp = nn.Sequential(
            nn.Linear(self.H * 3, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.d_model),
        )
        self.sigma_encoder = nn.Sequential(
            SinusoidalPosEmb(self.d_model),
        )

        # DDIM scheduler predicting the added noise (epsilon); inference uses all T steps.
        self.scheduler = DDIMScheduler(
            num_train_timesteps=self.T,
            beta_schedule='scaled_linear',
            prediction_type='epsilon',
        )
        self.scheduler.set_timesteps(self.T)
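    # The vocabulary loaded above (self.vocab) is not referenced anywhere else in this module;
    # it is presumably consumed by scoring or post-processing code elsewhere in the agent.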
    def denoise(self, ego_trajectory, env_features, status_encoding, timesteps):
        """Predict the noise residual for a batch of flattened (noisy) ego trajectories."""
        B = ego_trajectory.shape[0]
        ego_trajectory = ego_trajectory.reshape(B, -1).to(torch.float32)
        # Normalize the timestep(s) by T and embed them sinusoidally.
        sigma = timesteps.reshape(-1, 1)
        if sigma.numel() == 1:
            sigma = sigma.repeat(B, 1)
        sigma = sigma.float() / self.T
        sigma_embeddings = self.sigma_encoder(sigma).squeeze(1)

        # Condition the trajectory token on ego status and timestep, then cross-attend to the
        # image key/value features.
        ego_emb = self.encoder_mlp(ego_trajectory) + status_encoding + sigma_embeddings
        ego_attn = self.transformer(ego_emb[:, None], env_features)
        out = self.decoder_mlp(ego_attn).reshape(B, -1)
        return out
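    # Training-side usage is not shown in this file. A typical epsilon-prediction objective
    # (a sketch only; the real loss lives elsewhere and may differ) would standardize the
    # ground-truth trajectory, corrupt it with self.scheduler.add_noise(traj, noise, t) for a
    # random timestep t, call self.denoise(noisy_traj, env_features, status_encoding, t), and
    # regress the prediction against the sampled noise with an MSE loss.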
    def forward(self, bev_feature, status_encoding, history_waypoints) -> Dict[str, torch.Tensor]:
        B = bev_feature.shape[0]
        result = {}

        if not self.config.is_training:
            # Inference: run the reverse DDIM process, starting from Gaussian noise.
            ego_trajectory = torch.randn((B, self.H * 3), device=bev_feature.device)
            timesteps = self.scheduler.timesteps
            for t in timesteps:
                with torch.no_grad():
                    # Predict the noise residual for the current trajectory estimate.
                    residual = self.denoise(
                        ego_trajectory,
                        bev_feature,
                        status_encoding,
                        t.to(ego_trajectory.device)
                    )

                # Take one reverse DDIM step using the predicted noise.
                out = self.scheduler.step(residual, t, ego_trajectory)
                ego_trajectory = out.prev_sample

            # Map the sampled (standardized) trajectory back to waypoint space.
            ego_trajectory = self.standardizer.untransform_features(ego_trajectory, history_waypoints)
            result["trajectory"] = ego_trajectory.reshape(B, self.H, 3)

        # Constant placeholder scores for the vocabulary-scoring outputs; 4096 presumably matches
        # the size of the trajectory vocabulary.
        result['imi'], result['noc'], result['da'], result['ttc'], result['comfort'], result['progress'] = (
            torch.ones((B, 4096)) for _ in range(6)
        )
        result['history_waypoints'] = history_waypoints
        result['env_features'] = bev_feature
        result['status_encoding'] = status_encoding
        return result
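    # Note: "trajectory" is only populated when config.is_training is False. During training the
    # head returns the conditioning tensors (env_features, status_encoding, history_waypoints),
    # presumably so the denoising objective can be computed by the surrounding training code,
    # which is not part of this file.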