|
from typing import Dict |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
|
|
from navsim.agents.hydra.hydra_backbone_pe import HydraBackbonePE |
|
from navsim.agents.hydra.hydra_config import HydraConfig |
|
from navsim.agents.transfuser.transfuser_model import AgentHead |
|
from navsim.agents.utils.attn import MemoryEffTransformer |
|
from navsim.agents.utils.nerf import nerf_positional_encoding |
|
from navsim.agents.vadv2.vadv2_config import Vadv2Config |
|
|
|
|
|
class HydraModelOffset(nn.Module): |
|
def __init__(self, config: HydraConfig): |
|
super().__init__() |
|
|
|
self._query_splits = [ |
|
config.num_bounding_boxes, |
|
] |
|
|
|
self._config = config |
|
        assert config.backbone_type in ['vit', 'intern', 'vov', 'resnet', 'eva', 'moe', 'moe_ult32', 'swin']
        if config.backbone_type == 'eva':
            raise ValueError(f'{config.backbone_type} not supported')
        elif config.backbone_type in ('intern', 'vov', 'swin', 'vit'):
            self._backbone = HydraBackbonePE(config)
        else:
            # 'resnet', 'moe' and 'moe_ult32' pass the assert but have no backbone
            # constructor wired up here; fail early with a clear message instead of
            # an AttributeError on self._backbone later.
            raise ValueError(f'{config.backbone_type} is not wired up in HydraModelOffset')
|
|
|
img_num = 2 if config.use_back_view else 1 |
|
self._keyval_embedding = nn.Embedding( |
|
config.img_vert_anchors * config.img_horz_anchors * img_num, config.tf_d_model |
|
) |
|
self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model) |
|
|
|
|
|
self.downscale_layer = nn.Conv2d(self._backbone.img_feat_c, config.tf_d_model, kernel_size=1) |
|
self._status_encoding = nn.Linear((4 + 2 + 2) * config.num_ego_status, config.tf_d_model) |
|
|
|
tf_decoder_layer = nn.TransformerDecoderLayer( |
|
d_model=config.tf_d_model, |
|
nhead=config.tf_num_head, |
|
dim_feedforward=config.tf_d_ffn, |
|
dropout=config.tf_dropout, |
|
batch_first=True, |
|
) |
|
|
|
self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers) |
|
self._agent_head = AgentHead( |
|
num_agents=config.num_bounding_boxes, |
|
d_ffn=config.tf_d_ffn, |
|
d_model=config.tf_d_model, |
|
) |
|
|
|
self._trajectory_head = HydraTrajHead( |
|
num_poses=config.trajectory_sampling.num_poses, |
|
d_ffn=config.tf_d_ffn, |
|
d_model=config.tf_d_model, |
|
nhead=config.vadv2_head_nhead, |
|
nlayers=config.vadv2_head_nlayers, |
|
vocab_path=config.vocab_path, |
|
config=config |
|
) |
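        # Note: the modules below (vocab, planner_head, _pos_embed, _encoder and
        # _transformer) are not referenced in this class's forward pass; the
        # trajectory head keeps its own vocabulary and encoders.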
|
|
|
self.vocab = nn.Parameter( |
|
torch.from_numpy(np.load(config.vocab_path)), |
|
requires_grad=False |
|
) |
|
self.planner_head = nn.Sequential( |
|
nn.Linear(config.tf_d_model, config.tf_d_ffn), |
|
|
|
nn.ReLU(), |
|
nn.Linear(config.tf_d_ffn, config.tf_d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(config.tf_d_ffn, config.trajectory_sampling.num_poses * 3), |
|
) |
|
self._pos_embed = nn.Sequential( |
|
nn.Linear(config.trajectory_sampling.num_poses * 3, config.tf_d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(config.tf_d_ffn, config.tf_d_model), |
|
) |
|
self._encoder = MemoryEffTransformer( |
|
d_model=config.tf_d_model, |
|
nhead=config.vadv2_head_nhead, |
|
dim_feedforward=config.tf_d_model * 4, |
|
dropout=0.0 |
|
) |
|
self._transformer = nn.TransformerDecoder( |
|
nn.TransformerDecoderLayer( |
|
config.tf_d_model, config.vadv2_head_nhead, config.tf_d_ffn, |
|
dropout=0.0, batch_first=True |
|
), config.vadv2_head_nlayers |
|
) |
|
|
|
    def img_feat_blc(self, camera_feature):
        # Backbone features -> 1x1 conv down to tf_d_model -> flatten to (B, num_tokens, C) tokens.
        img_features = self._backbone(camera_feature)
        img_features = self.downscale_layer(img_features).flatten(-2, -1)
        img_features = img_features.permute(0, 2, 1)
        return img_features
|
|
|
def forward(self, features: Dict[str, torch.Tensor], |
|
interpolated_traj=None) -> Dict[str, torch.Tensor]: |
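        """Run the camera backbone, transformer decoder and prediction heads.

        `features` must contain "camera_feature" (a tensor, or a list of tensors of
        which only the last is used) and "status_feature"; "camera_feature_back" is
        additionally required when `use_back_view` is set. The returned dict merges
        the agent head outputs with the trajectory head outputs ("trajectory",
        "trajectory_offset", "trajectory_vocab" and the per-metric scores).
        """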
|
camera_feature: torch.Tensor = features["camera_feature"] |
|
status_feature: torch.Tensor = features["status_feature"] |
|
        if isinstance(camera_feature, list):
            # Multi-frame input: only the last frame in the list is used.
            camera_feature = camera_feature[-1]

        batch_size = status_feature.shape[0]
|
|
|
img_features = self.img_feat_blc(camera_feature) |
|
if self._config.use_back_view: |
|
img_features_back = self.img_feat_blc(features["camera_feature_back"]) |
|
img_features = torch.cat([img_features, img_features_back], 1) |
|
|
|
        # The status encoder expects (4 + 2 + 2) * num_ego_status inputs; when a
        # 32-dim status vector is provided but only one ego status is configured,
        # keep just the first 8 dims.
        if self._config.num_ego_status == 1 and status_feature.shape[1] == 32:
            status_encoding = self._status_encoding(status_feature[:, :8])
        else:
            status_encoding = self._status_encoding(status_feature)

        # Image tokens serve as the decoder memory; add the learned positional embedding.
        keyval = img_features
        keyval += self._keyval_embedding.weight[None, ...]

        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
|
agents_query = self._tf_decoder(query, keyval) |
|
|
|
output: Dict[str, torch.Tensor] = {} |
|
trajectory = self._trajectory_head(keyval, status_encoding, interpolated_traj) |
|
output.update(trajectory) |
|
agents = self._agent_head(agents_query) |
|
output.update(agents) |
|
|
|
return output |
|
|
|
|
|
class HydraTrajHead(nn.Module): |
|
def __init__(self, num_poses: int, d_ffn: int, d_model: int, vocab_path: str, |
|
nhead: int, nlayers: int, config: Vadv2Config = None |
|
): |
|
super().__init__() |
|
self._num_poses = num_poses |
|
self.transformer = nn.TransformerDecoder( |
|
nn.TransformerDecoderLayer( |
|
d_model, nhead, d_ffn, |
|
dropout=0.0, batch_first=True |
|
), nlayers |
|
) |
|
self.regression_transformer = nn.TransformerDecoder( |
|
nn.TransformerDecoderLayer( |
|
d_model, nhead, d_ffn, |
|
dropout=0.0, batch_first=True |
|
), nlayers |
|
) |
|
self.imi_transformer = nn.TransformerDecoder( |
|
nn.TransformerDecoderLayer( |
|
d_model, nhead, d_ffn, |
|
dropout=0.0, batch_first=True |
|
), nlayers |
|
) |
|
|
|
        # Refinement offsets are predicted for the second half of the horizon only
        # and are bounded by tanh scaling (offset_*_bound sets the maximum magnitude).
        self.offset_xy_bound = 1
        self.offset_heading_bound = 0.01
        self.offset_xy = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, num_poses * 2 // 2),  # (num_poses // 2) waypoints x (x, y)
            nn.Tanh()
        )
        self.offset_heading = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, num_poses * 1 // 2),  # (num_poses // 2) waypoints x heading
            nn.Tanh()
        )
|
self.imi_regression_head = nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
) |
|
self.vocab = nn.Parameter( |
|
torch.from_numpy(np.load(vocab_path)), |
|
requires_grad=False |
|
) |
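        # `self.vocab` holds the candidate trajectories as a non-trainable parameter
        # of shape (num_candidates, horizon, 3), one (x, y, heading) pose per step.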
|
|
|
self.heads = nn.ModuleDict({ |
|
'noc': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
), |
|
            'da': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            ),
|
'ttc': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
), |
|
'comfort': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
), |
|
'progress': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
), |
|
'imi': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
) |
|
}) |
|
|
|
self.inference_imi_weight = config.inference_imi_weight |
|
self.inference_da_weight = config.inference_da_weight |
|
self.normalize_vocab_pos = config.normalize_vocab_pos |
|
if self.normalize_vocab_pos: |
|
self.encoder = MemoryEffTransformer( |
|
d_model=d_model, |
|
nhead=nhead, |
|
dim_feedforward=d_model * 4, |
|
dropout=0.0 |
|
) |
|
self.use_nerf = config.use_nerf |
|
|
|
        if self.use_nerf:
            self.pos_embed = nn.Sequential(
                # 1040 is hard-coded to match the flattened size of the NeRF-encoded
                # vocabulary built in forward().
                nn.Linear(1040, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, d_model),
            )
|
else: |
|
self.pos_embed = nn.Sequential( |
|
nn.Linear(num_poses * 3, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, d_model), |
|
) |
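        # Note: mlp_pos_embed and encoder_offset below are not referenced in this
        # head's forward pass; they are kept as defined.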
|
self.mlp_pos_embed = nn.Sequential( |
|
nn.Linear(num_poses * 3, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, d_model), |
|
) |
|
self.encoder_offset = MemoryEffTransformer( |
|
d_model=d_model, |
|
nhead=nhead, |
|
dim_feedforward=d_model * 4, |
|
dropout=0.0 |
|
) |
|
|
|
def forward(self, bev_feature, status_encoding, interpolated_traj=None) -> Dict[str, torch.Tensor]: |
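        """Score the trajectory vocabulary against the image tokens and refine the top-K.

        Every vocabulary trajectory is embedded, decoded against `bev_feature`, and
        scored by the imi/noc/da/ttc/comfort/progress heads. The fused score selects
        the best candidate, and the K highest-scoring candidates are refined with
        bounded xy/heading offsets over the second half of the horizon.
        `interpolated_traj` is accepted but not used here.
        """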
|
|
|
|
|
|
|
        # Raw candidate trajectories: (L, HORIZON, 3) = (num candidates, poses, [x, y, heading]).
        vocab = self.vocab.data
        L, HORIZON, _ = vocab.shape
        B = bev_feature.shape[0]

        if self.use_nerf:
            # NeRF-style positional encoding of (x, y) plus (cos, sin) of the heading.
            # Only the embedding uses this encoding; the raw poses in `vocab` are kept
            # for the offset refinement further down.
            vocab_enc = torch.cat(
                [
                    nerf_positional_encoding(vocab[..., :2]),
                    torch.cos(vocab[..., -1])[..., None],
                    torch.sin(vocab[..., -1])[..., None],
                ], dim=-1
            )
        else:
            vocab_enc = vocab

        if self.normalize_vocab_pos:
            embedded_vocab = self.pos_embed(vocab_enc.view(L, -1))[None]
            embedded_vocab = self.encoder(embedded_vocab).repeat(B, 1, 1)
        else:
            embedded_vocab = self.pos_embed(vocab_enc.view(L, -1))[None].repeat(B, 1, 1)
|
tr_out = self.transformer(embedded_vocab, bev_feature) |
|
dist_status = tr_out + status_encoding.unsqueeze(1) |
|
result = {} |
|
|
|
        for k, head in self.heads.items():
            if k == 'imi':
                # The imitation head stays as logits; it is softmax-normalised in the
                # score fusion below.
                result[k] = head(dist_status).squeeze(-1)
            else:
                result[k] = head(dist_status).squeeze(-1).sigmoid()

        # Fuse the sub-scores into a single per-candidate score.
        scores = (
            0.05 * result['imi'].softmax(-1).log() +
            0.5 * result['noc'].log() +
            0.5 * result['da'].log() +
            8.0 * (5 * result['ttc'] + 2 * result['comfort'] + 5 * result['progress']).log()
        )
        selected_indices_raw = scores.argmax(1)

        # Keep the K best vocabulary candidates per sample for offset refinement.
        K = 64
        _, top_k_indices = torch.topk(scores, K, dim=1, largest=True)
        batch_indices = torch.arange(B, device=embedded_vocab.device)[..., None].repeat(1, K)
        embedded_vocab_topk = embedded_vocab[batch_indices, top_k_indices]

        tr_out_topk = (
            self.regression_transformer(embedded_vocab_topk, bev_feature) +
            status_encoding.unsqueeze(1)
        )

        # Bounded (tanh-scaled) offsets, predicted for the second half of the horizon only.
        offset_topk_xy = self.offset_xy(tr_out_topk)
        offset_topk_heading = self.offset_heading(tr_out_topk)
        offset_topk = torch.cat([
            offset_topk_xy.view(B, K, HORIZON // 2, 2) * self.offset_xy_bound,
            offset_topk_heading.view(B, K, HORIZON // 2, 1) * self.offset_heading_bound
        ], -1).contiguous()

        # Zero offset for the first half of the horizon, learned offset for the second half.
        padded_offset = torch.cat([
            torch.zeros_like(offset_topk),
            offset_topk
        ], dim=2)
|
|
|
        # Refined trajectories for the top-K candidates, in raw pose space.
        final_traj = vocab[None, ...].repeat(B, 1, 1, 1)[batch_indices, top_k_indices] + padded_offset
        result["trajectory_offset"] = final_traj

        # Locate the globally best candidate (scores argmax) within the top-K set
        # and return its refined trajectory as the final prediction.
        selected_indices_expanded = selected_indices_raw[:, None].expand(-1, top_k_indices.size(1))
        matches = (top_k_indices == selected_indices_expanded).int()
        positions = torch.argmax(matches, dim=1)
        pred_traj = final_traj[
            torch.arange(final_traj.size(0)),
            positions
        ]
        result["trajectory"] = pred_traj
|
|
|
result["trajectory_vocab"] = self.vocab.data |
|
return result |
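

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The exact HydraConfig fields, the
# vocabulary file and the input image layout are assumptions here, not part of
# this module:
#
#   config = HydraConfig(...)               # must point vocab_path at an .npy file
#   model = HydraModelOffset(config).eval()
#   features = {
#       "camera_feature": torch.randn(1, 3, 256, 1024),               # assumed layout
#       "status_feature": torch.randn(1, 8 * config.num_ego_status),
#   }
#   with torch.no_grad():
#       out = model(features)
#   out["trajectory"]         # (B, num_poses, 3) refined best trajectory
#   out["trajectory_offset"]  # (B, 64, num_poses, 3) refined top-K candidates
# ---------------------------------------------------------------------------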
|
|