|
from typing import Dict |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn.bricks.transformer import build_positional_encoding |
|
|
|
from navsim.agents.hydra.hydra_backbone_pe import HydraBackbonePE |
|
from navsim.agents.hydra.hydra_config import HydraConfig |
|
from navsim.agents.utils.attn import MemoryEffTransformer |
|
from navsim.agents.utils.nerf import nerf_positional_encoding |
|
from navsim.agents.vadv2.vadv2_config import Vadv2Config |
|
|
|
|
|
class HydraModelPENoDet(nn.Module): |
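    """Hydra planning model variant with PETR-style 3D camera position embeddings and no detection head."""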
|
def __init__(self, config: HydraConfig): |
|
super().__init__() |
|
|
|
self._config = config |
|
        assert config.backbone_type in ['vit', 'intern', 'vov', 'resnet', 'eva', 'moe', 'moe_ult32', 'swin']
        if config.backbone_type in ('vit', 'eva'):
            raise ValueError(f'{config.backbone_type} not supported')
        elif config.backbone_type in ('intern', 'vov', 'swin', 'resnet'):
            self._backbone = HydraBackbonePE(config)
|
|
|
        # learnable positional embedding for the flattened image-feature (key/value) tokens
        self._keyval_embedding = nn.Embedding(
            config.img_vert_anchors * config.img_horz_anchors, config.tf_d_model
        )
|
|
|
|
|
        self.downscale_layer = nn.Conv2d(self._backbone.img_feat_c, config.tf_d_model, kernel_size=1)
        # 8-dim ego status (driving command 4 + velocity 2 + acceleration 2) per status frame
        self._status_encoding = nn.Linear((4 + 2 + 2) * config.num_ego_status, config.tf_d_model)
|
|
|
        # PETR-style 3D position embedding: 64 depth bins starting at 1 m,
        # covering x/y in [-32, 32] m and z in [-10, 10] m
        self.depth_num = 64
        self.depth_start = 1
        self.position_range = [-32.0, -32.0, -10.0, 32.0, 32.0, 10.0]
        self.position_dim = 3 * self.depth_num
        self.embed_dims = 256
        self.sin_positional_encoding = dict(
            type='SinePositionalEncoding3D', num_feats=128, normalize=True)
        self.positional_encoding = build_positional_encoding(
            self.sin_positional_encoding)
|
        # projects the 3 * num_feats sine embedding down to the model width
        self.adapt_pos3d = nn.Sequential(
            nn.Conv2d(self.embed_dims * 3 // 2, self.embed_dims * 4, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(self.embed_dims * 4, self.embed_dims, kernel_size=1, stride=1, padding=0),
        )
        # maps the normalized 3D frustum coordinates (3 * depth_num channels) to the model width
        self.position_encoder = nn.Sequential(
            nn.Conv2d(self.position_dim, self.embed_dims * 4, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(self.embed_dims * 4, self.embed_dims, kernel_size=1, stride=1, padding=0),
        )
|
|
|
self._trajectory_head = HydraTrajHead( |
|
num_poses=config.trajectory_sampling.num_poses, |
|
d_ffn=config.tf_d_ffn, |
|
d_model=config.tf_d_model, |
|
nhead=config.vadv2_head_nhead, |
|
nlayers=config.vadv2_head_nlayers, |
|
vocab_path=config.vocab_path, |
|
config=config |
|
) |
|
|
|
def inverse_sigmoid(self, x, eps=1e-6): |
|
"""Inverse sigmoid function. |
|
|
|
Args: |
|
x (Tensor): The input tensor. |
|
eps (float): A small value to avoid numerical issues. |
|
|
|
Returns: |
|
Tensor: The logit value of the input. |
|
""" |
|
x = x.clamp(min=eps, max=1 - eps) |
|
return torch.log(x / (1 - x)) |
|
|
|
def position_embedding(self, features, img_features): |
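        """Build a PETR-style 3D position embedding for the stitched camera feature map.

        Args:
            features: Batch dict providing per-camera ``sensor2lidar_rotation``,
                ``sensor2lidar_translation`` and ``intrinsics`` tensors.
            img_features: Backbone feature map of shape (B, C, H, W) over the
                stitched (left, front, right) camera image.

        Returns:
            Position embedding of shape (B, 1, embed_dims, H, W).
        """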
|
        eps = 1e-5
        img_features = img_features.unsqueeze(1)
        B, N, C, tar_H, tar_W = img_features.shape
        device = img_features.device
        # crop applied to the raw camera images before stitching
        crop_top = 28
        crop_left = 416
        # the stitched input is (left 1088 px, front 1920 px, right 1088 px) wide, so
        # split the horizontal feature anchors proportionally across the three cameras
        H = [self._config.img_vert_anchors for _ in range(3)]
        W = [
            self._config.img_horz_anchors * 1088 // (1088 * 2 + 1920),
            self._config.img_horz_anchors * 1920 // (1088 * 2 + 1920),
            self._config.img_horz_anchors * 1088 // (1088 * 2 + 1920)
        ]
|
|
|
|
|
        # map feature-grid indices back to pixel coordinates in each original camera image
        coords_h_l = torch.arange(H[0], device=device).float() * 1080 / H[0] + crop_top / H[0]
        coords_w_l = torch.arange(W[0], device=device).float() * 1920 / W[0] + crop_left / W[0]

        coords_h_f = torch.arange(H[1], device=device).float() * 1080 / H[1] + crop_top / H[1]
        coords_w_f = torch.arange(W[1], device=device).float() * 1920 / W[1]

        coords_h_r = torch.arange(H[2], device=device).float() * 1080 / H[2] + crop_top / H[2]
        coords_w_r = torch.arange(W[2], device=device).float() * 1920 / W[2] + crop_left / W[2]

        # linear-increasing depth discretization (LID): bin widths grow with depth
        index = torch.arange(start=0, end=self.depth_num, step=1, device=device).float()
        index_1 = index + 1
        bin_size = (self.position_range[3] - self.depth_start) / (self.depth_num * (1 + self.depth_num))
        coords_d = self.depth_start + bin_size * index * index_1
|
|
|
        D = coords_d.shape[0]
        # per-camera frustum grids of (u, v, d) coordinates, each of shape (W, H, D, 3)
        coords = [None] * 3
        coords[0] = torch.stack(torch.meshgrid([coords_w_l, coords_h_l, coords_d])).permute(1, 2, 3, 0)
        coords[1] = torch.stack(torch.meshgrid([coords_w_f, coords_h_f, coords_d])).permute(1, 2, 3, 0)
        coords[2] = torch.stack(torch.meshgrid([coords_w_r, coords_h_r, coords_d])).permute(1, 2, 3, 0)

        # scale pixel coordinates by depth so that K^-1 @ (u*d, v*d, d) back-projects to 3D
        for cam_coords in coords:
            cam_coords[..., :2] = cam_coords[..., :2] * torch.max(
                cam_coords[..., 2:3], torch.ones_like(cam_coords[..., 2:3]) * eps)
|
|
|
|
|
|
|
        pos_3d_embed = None
        for i in range(3):
            # back-project each camera's frustum into the lidar frame:
            # x_lidar = R_sensor2lidar @ K^-1 @ (u*d, v*d, d) + t_sensor2lidar
            sensor2lidar_rotation = features["sensor2lidar_rotation"][i]
            sensor2lidar_translation = features["sensor2lidar_translation"][i]
            intrinsics = features["intrinsics"][i]
            combine = torch.matmul(sensor2lidar_rotation, torch.inverse(intrinsics)).float()

            coords3d = coords[i].view(1, N, W[i], H[i], D, 3, 1).repeat(B, 1, 1, 1, 1, 1, 1)
            combine = combine.view(B, N, 1, 1, 1, 3, 3).repeat(1, 1, W[i], H[i], D, 1, 1)
            coords3d = torch.matmul(combine, coords3d).squeeze(-1)
            sensor2lidar_translation = sensor2lidar_translation.view(B, N, 1, 1, 1, 3)
            coords3d += sensor2lidar_translation

            # normalize to [0, 1] inside the configured position range
            coords3d[..., 0:1] = (coords3d[..., 0:1] - self.position_range[0]) / (
                    self.position_range[3] - self.position_range[0])
            coords3d[..., 1:2] = (coords3d[..., 1:2] - self.position_range[1]) / (
                    self.position_range[4] - self.position_range[1])
            coords3d[..., 2:3] = (coords3d[..., 2:3] - self.position_range[2]) / (
                    self.position_range[5] - self.position_range[2])

            # concatenate the per-camera grids along the width dimension
            if pos_3d_embed is None:
                pos_3d_embed = coords3d
            else:
                pos_3d_embed = torch.cat((pos_3d_embed, coords3d), dim=2)
|
|
|
|
|
        # (B, N, W, H, D, 3) -> (B * N, D * 3, H, W), then encode to embed_dims channels
        pos_3d_embed = pos_3d_embed.permute(0, 1, 4, 5, 3, 2).contiguous().view(B * N, -1, tar_H, tar_W)
        coords3d = self.inverse_sigmoid(pos_3d_embed)
        coords_position_embedding = self.position_encoder(coords3d)
        return coords_position_embedding.view(B, N, self.embed_dims, tar_H, tar_W)
|
|
|
def forward(self, features: Dict[str, torch.Tensor], |
|
interpolated_traj=None) -> Dict[str, torch.Tensor]: |
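        """Encode the stitched camera image, add position embeddings, and score the trajectory vocabulary.

        Args:
            features: Input dict with ``camera_feature``, ``status_feature`` and the
                camera calibration tensors consumed by ``position_embedding``.
            interpolated_traj: Optional pre-interpolated trajectories forwarded to the head.

        Returns:
            Dict with the trajectory head outputs (per-metric scores, selected trajectory, indices).
        """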
|
|
|
        camera_feature: torch.Tensor = features["camera_feature"][0]
        status_feature: torch.Tensor = features["status_feature"]

        batch_size = status_feature.shape[0]
        assert camera_feature.shape[0] == batch_size

        # backbone features over the stitched camera image, projected to tf_d_model channels
        img_features = self._backbone(camera_feature)
        img_features = self.downscale_layer(img_features)
        input_img_h, input_img_w = img_features.size(-2), img_features.size(-1)
        # all-ones mask over the feature map, consumed by the sine positional encoding
        masks = img_features.new_ones(
            (img_features.shape[0], 1, input_img_h, input_img_w))
|
|
|
        # add the 3D frustum position embedding and the adapted sine embedding to the image features
        coords_position_embedding = self.position_embedding(features, img_features)
        sin_embed = self.positional_encoding(masks)
        sin_embed = self.adapt_pos3d(sin_embed.flatten(0, 1)).view(img_features.size())
        pos_embed = coords_position_embedding.squeeze(1) + sin_embed

        img_features = img_features + pos_embed
        # flatten to a (B, H * W, C) token sequence
        img_features = img_features.flatten(-2, -1)
        img_features = img_features.permute(0, 2, 1)
|
|
|
        # if only one ego-status frame is configured but a stacked 32-dim status is provided,
        # use only the first 8 status dims
        if self._config.num_ego_status == 1 and status_feature.shape[1] == 32:
            status_encoding = self._status_encoding(status_feature[:, :8])
        else:
            status_encoding = self._status_encoding(status_feature)
|
|
|
keyval = img_features |
|
|
|
keyval += self._keyval_embedding.weight[None, ...] |
|
|
|
output: Dict[str, torch.Tensor] = {} |
|
trajectory = self._trajectory_head(keyval, status_encoding, interpolated_traj) |
|
output.update(trajectory) |
|
|
|
return output |
|
|
|
|
|
class HydraTrajHead(nn.Module): |
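    """Trajectory head that scores a fixed trajectory vocabulary against the encoded scene tokens."""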
|
def __init__(self, num_poses: int, d_ffn: int, d_model: int, vocab_path: str, |
|
nhead: int, nlayers: int, config: Vadv2Config = None |
|
): |
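        """Build the vocabulary-scoring head.

        Args:
            num_poses: Number of poses per vocabulary trajectory.
            d_ffn: Hidden width of the decoder feed-forward and scoring MLPs.
            d_model: Token width shared with the image features.
            vocab_path: Path to an .npy file holding the trajectory vocabulary.
            nhead: Number of attention heads in the decoder.
            nlayers: Number of decoder layers.
            config: Vadv2Config supplying inference weights and feature flags.
        """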
|
super().__init__() |
|
self._num_poses = num_poses |
|
self.transformer = nn.TransformerDecoder( |
|
nn.TransformerDecoderLayer( |
|
d_model, nhead, d_ffn, |
|
dropout=0.0, batch_first=True |
|
), nlayers |
|
) |
|
        # frozen trajectory vocabulary of shape (vocab_size, num_poses, 3), loaded from disk
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False
        )
|
|
|
        # per-metric scoring heads: no-collision (noc), drivable-area (da),
        # time-to-collision (ttc), comfort, progress, plus an imitation logit (imi)
        self.heads = nn.ModuleDict({
            'noc': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            ),
            'da': nn.Sequential(
                nn.Linear(d_model, d_ffn),
                nn.ReLU(),
                nn.Linear(d_ffn, 1),
            ),
|
'ttc': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
), |
|
'comfort': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
), |
|
'progress': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
), |
|
'imi': nn.Sequential( |
|
nn.Linear(d_model, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, 1), |
|
) |
|
}) |
|
|
|
self.inference_imi_weight = config.inference_imi_weight |
|
self.inference_da_weight = config.inference_da_weight |
|
self.normalize_vocab_pos = config.normalize_vocab_pos |
|
if self.normalize_vocab_pos: |
|
self.encoder = MemoryEffTransformer( |
|
d_model=d_model, |
|
nhead=nhead, |
|
dim_feedforward=d_model * 4, |
|
dropout=0.0 |
|
) |
|
self.use_nerf = config.use_nerf |
|
|
|
if self.use_nerf: |
|
self.pos_embed = nn.Sequential( |
|
nn.Linear(1040, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, d_model), |
|
) |
|
else: |
|
self.pos_embed = nn.Sequential( |
|
nn.Linear(num_poses * 3, d_ffn), |
|
nn.ReLU(), |
|
nn.Linear(d_ffn, d_model), |
|
) |
|
|
|
def forward(self, bev_feature, status_encoding, interpolated_traj) -> Dict[str, torch.Tensor]: |
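        """Score every vocabulary trajectory and return per-metric scores plus the selected trajectory."""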
|
|
|
|
|
|
|
|
|
        # score every candidate trajectory in the vocabulary
        vocab = self.vocab.data
        L, HORIZON, _ = vocab.shape
        B = bev_feature.shape[0]
|
        if self.use_nerf:
            # NeRF-style Fourier features for x/y plus sin/cos of the heading angle
            vocab = torch.cat(
                [
                    nerf_positional_encoding(vocab[..., :2]),
                    torch.cos(vocab[..., -1])[..., None],
                    torch.sin(vocab[..., -1])[..., None],
                ], dim=-1
            )
|
|
|
        # embed the vocabulary, optionally self-attend over it, cross-attend to the
        # image tokens, then add the ego-status encoding
        if self.normalize_vocab_pos:
            embedded_vocab = self.pos_embed(vocab.view(L, -1))[None]
            embedded_vocab = self.encoder(embedded_vocab).repeat(B, 1, 1)
        else:
            embedded_vocab = self.pos_embed(vocab.view(L, -1))[None].repeat(B, 1, 1)
        tr_out = self.transformer(embedded_vocab, bev_feature)
        dist_status = tr_out + status_encoding.unsqueeze(1)
        result = {}
|
|
|
        # sub-metric heads predict probabilities (sigmoid); the imitation head stays a raw logit
        for k, head in self.heads.items():
            if k == 'imi':
                result[k] = head(dist_status).squeeze(-1)
            else:
                result[k] = head(dist_status).squeeze(-1).sigmoid()
|
|
|
        # combine imitation and rule-based sub-scores; the weights are fixed here rather
        # than taken from inference_imi_weight / inference_da_weight
        scores = (
            0.05 * result['imi'].softmax(-1).log() +
            0.5 * result['noc'].log() +
            0.5 * result['da'].log() +
            8.0 * (5 * result['ttc'] + 2 * result['comfort'] + 5 * result['progress']).log()
        )
        selected_indices = scores.argmax(1)
        result["trajectory"] = self.vocab.data[selected_indices]
        result["trajectory_vocab"] = self.vocab.data
        result["selected_indices"] = selected_indices
        return result
|
|