# navsim_ours/navsim/agents/hydra_plantf/hydra_plantf_model.py
# Hugging Face Hub page residue: uploaded by lkllkl using huggingface_hub,
# commit da2e2ac (verified), file size 7.02 kB.
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from navsim.agents.hydra_plantf.hydra_plantf_config import HydraPlantfConfig
from navsim.agents.hydra_plantf.model_utils import MapEncoder, AgentEncoder, CustomTransformerEncoderLayer
from navsim.agents.utils.attn import MemoryEffTransformer
from navsim.agents.utils.nerf import nerf_positional_encoding
from navsim.agents.vadv2.vadv2_config import Vadv2Config
class HydraPlantfModel(nn.Module):
    """Encode agents and map elements, then score a trajectory vocabulary.

    Agent and map inputs are turned into token sequences by ``AgentEncoder``
    and ``MapEncoder``, concatenated, passed through a stack of
    ``CustomTransformerEncoderLayer`` blocks and layer-normalised.  The
    resulting tokens, together with a linear encoding of the ego status, are
    handed to :class:`HydraTrajPlantfHead`, whose output dictionary is
    returned as-is.
    """

    def __init__(self, config: HydraPlantfConfig):
        """Create every sub-module from ``config``.

        Args:
            config: Supplies ``tf_d_model``, ``tf_num_head``,
                ``tf_num_encoder_layers``, ``tf_d_ffn``, ``num_ego_status``,
                ``trajectory_sampling.num_poses``, ``vadv2_head_nhead``,
                ``vadv2_head_nlayers`` and ``vocab_path``; it is also
                forwarded unchanged to the trajectory head.
        """
        super().__init__()
        self._config = config
        d_model = config.tf_d_model

        self.map_encoder = MapEncoder(dim=d_model, polygon_channel=6)
        self.agent_encoder = AgentEncoder(agent_channel=8, dim=d_model)

        # One encoder layer per entry; the drop-path rate rises linearly
        # from 0 to 0.2 over ``tf_num_encoder_layers`` layers.
        drop_path_rates = torch.linspace(
            0, 0.2, config.tf_num_encoder_layers
        ).tolist()
        encoder_layers = [
            CustomTransformerEncoderLayer(
                dim=d_model,
                num_heads=config.tf_num_head,
                drop_path=rate,
            )
            for rate in drop_path_rates
        ]
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = nn.LayerNorm(d_model)

        # NOTE(review): 4 + 2 + 2 presumably counts the per-step ego status
        # channels (e.g. command / velocity / acceleration) -- confirm.
        status_dim = (4 + 2 + 2) * config.num_ego_status
        self._status_encoding = nn.Linear(status_dim, d_model)

        head_kwargs = dict(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=d_model,
            nhead=config.vadv2_head_nhead,
            nlayers=config.vadv2_head_nlayers,
            vocab_path=config.vocab_path,
            config=config,
        )
        self._trajectory_head = HydraTrajPlantfHead(**head_kwargs)

    def forward(self, features: Dict[str, torch.Tensor],
                interpolated_traj=None) -> Dict[str, torch.Tensor]:
        """Run the encoder stack and the trajectory head.

        Args:
            features: Must contain ``"status_feature"``, ``"agent"`` and
                ``"map"``; the latter two are passed to their encoders and
                must each expose a boolean ``"valid_mask"`` entry.
            interpolated_traj: Accepted for interface compatibility; it is
                not used by this method.

        Returns:
            A new dictionary holding everything the trajectory head returned.
        """
        status_feature: torch.Tensor = features["status_feature"]

        # With a single ego status configured, a 32-wide status vector is
        # cut down to its first 8 channels before being encoded.
        truncate_status = (
            self._config.num_ego_status == 1 and status_feature.shape[1] == 32
        )
        status_input = status_feature[:, :8] if truncate_status else status_feature
        status_encoding = self._status_encoding(status_input)

        agent_tokens = self.agent_encoder(features['agent'])
        map_tokens = self.map_encoder(features['map'])

        # True marks a token to ignore; a map element counts as valid as
        # soon as any entry along its last mask axis is valid.
        agent_padding = ~features['agent']['valid_mask']
        map_padding = ~features['map']['valid_mask'].any(-1)
        key_padding_mask = torch.cat([agent_padding, map_padding], dim=-1)

        tokens = torch.cat([agent_tokens, map_tokens], dim=1)
        for layer in self.blocks:
            tokens = layer(tokens, key_padding_mask=key_padding_mask)
        keyval = self.norm(tokens)

        return dict(self._trajectory_head(keyval, status_encoding))
class HydraTrajPlantfHead(nn.Module):
    """Score a fixed trajectory vocabulary against scene tokens and pick one.

    The vocabulary loaded from ``vocab_path`` is embedded (optionally via a
    nerf positional encoding and an extra ``MemoryEffTransformer`` encoder),
    used as the query sequence of a transformer decoder over the scene
    tokens, offset by the ego status encoding, and fed to six small MLP
    heads (``noc``, ``da``, ``ttc``, ``comfort``, ``progress``, ``imi``).
    The per-head outputs are combined into one log-score per vocabulary
    entry and the arg-max entry is returned as the chosen trajectory.
    """

    def __init__(self, num_poses: int, d_ffn: int, d_model: int, vocab_path: str,
                 nhead: int, nlayers: int, config: "Vadv2Config" = None
                 ):
        """Build the decoder, the vocabulary embedding and the scoring heads.

        Args:
            num_poses: Number of poses per vocabulary trajectory; sets the
                embedding input width (``num_poses * 3``) when nerf encoding
                is disabled.
            d_ffn: Hidden width of the decoder feed-forward and of the MLPs.
            d_model: Token width.
            vocab_path: Path of a ``.npy`` file read with ``np.load``.
            nhead: Attention heads of the decoder (and optional encoder).
            nlayers: Number of decoder layers.
            config: Object exposing ``inference_imi_weight``,
                ``inference_da_weight``, ``normalize_vocab_pos`` and
                ``use_nerf``.

        Raises:
            ValueError: If ``config`` is not supplied.
        """
        super().__init__()
        if config is None:
            # The signature defaults to None, but every option below is read
            # from the config, so fail with an explicit message.
            raise ValueError("HydraTrajPlantfHead requires a config object")

        self._num_poses = num_poses
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, nhead, d_ffn,
                dropout=0.0, batch_first=True
            ), nlayers
        )
        # Stored as a frozen parameter so it moves with the module and is
        # part of the state dict.
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False
        )

        # Creation order and layer indices match the original hand-written
        # definitions, keeping state-dict keys and init order unchanged.
        self.heads = nn.ModuleDict({
            'noc': self._make_mlp(d_model, d_ffn),
            'da': self._make_mlp(d_model, d_ffn),
            'ttc': self._make_mlp(d_model, d_ffn),
            'comfort': self._make_mlp(d_model, d_ffn),
            'progress': self._make_mlp(d_model, d_ffn),
            'imi': self._make_mlp(d_model, d_ffn, extra_hidden=1),
        })

        # NOTE(review): these two weights are stored but the scoring in
        # ``_log_scores`` uses fixed coefficients -- confirm whether they
        # were meant to be applied there.
        self.inference_imi_weight = config.inference_imi_weight
        self.inference_da_weight = config.inference_da_weight

        self.normalize_vocab_pos = config.normalize_vocab_pos
        if self.normalize_vocab_pos:
            self.encoder = MemoryEffTransformer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=0.0
            )

        self.use_nerf = config.use_nerf
        # NOTE(review): 1040 is a hard-coded input width for the nerf path;
        # presumably it equals the flattened per-trajectory nerf feature size
        # for the default number of poses -- confirm before changing
        # ``num_poses``.
        embed_in = 1040 if self.use_nerf else num_poses * 3
        self.pos_embed = nn.Sequential(
            nn.Linear(embed_in, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, d_model),
        )

    @staticmethod
    def _make_mlp(d_model: int, d_ffn: int, extra_hidden: int = 0) -> nn.Sequential:
        """Return Linear-ReLU[-Linear-ReLU]*extra_hidden-Linear mapping to one value."""
        layers = [nn.Linear(d_model, d_ffn), nn.ReLU()]
        for _ in range(extra_hidden):
            layers += [nn.Linear(d_ffn, d_ffn), nn.ReLU()]
        layers.append(nn.Linear(d_ffn, 1))
        return nn.Sequential(*layers)

    @staticmethod
    def _log_scores(logits: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Combine raw head outputs into one log-score per vocabulary entry.

        Mathematically equal to
        ``0.05*log(softmax(imi)) + 0.5*log(sig(noc)) + 0.5*log(sig(da))
        + 8*log(5*sig(ttc) + 2*sig(comfort) + 5*sig(progress))``
        but evaluated entirely in log-space, so saturated logits give large
        negative finite scores instead of ``-inf`` and the ranking between
        candidates is preserved.
        """
        log_sig = nn.functional.logsigmoid
        log_imi = logits['imi'].log_softmax(-1)
        log_noc = log_sig(logits['noc'])
        log_da = log_sig(logits['da'])
        # log(5*a + 2*b + 5*c) == logsumexp(log5 + log a, log2 + log b, log5 + log c)
        log5 = float(np.log(5.0))
        log2 = float(np.log(2.0))
        log_soft = torch.logsumexp(
            torch.stack(
                [
                    log_sig(logits['ttc']) + log5,
                    log_sig(logits['comfort']) + log2,
                    log_sig(logits['progress']) + log5,
                ],
                dim=-1,
            ),
            dim=-1,
        )
        return 0.05 * log_imi + 0.5 * log_noc + 0.5 * log_da + 8.0 * log_soft

    def forward(self, bev_feature, status_encoding, interpolated_traj=None) -> Dict[str, torch.Tensor]:
        """Score every vocabulary trajectory and select the best per sample.

        Args:
            bev_feature: Scene tokens used as decoder memory; the leading
                axis is the batch.  (Original notes give ``B, 32, C`` --
                TODO confirm against the caller.)
            status_encoding: Ego status encoding; it is unsqueezed at axis 1
                and added to every decoded vocabulary token.
            interpolated_traj: Accepted for interface compatibility; unused.

        Returns:
            Dict with one entry per head (``imi`` as raw logits, all others
            after a sigmoid), plus ``"trajectory"`` (selected vocabulary
            entries), ``"trajectory_vocab"`` (the full vocabulary) and
            ``"selected_indices"``.
        """
        # todo sinusoidal embedding
        # Original notes: vocab is (4096, 40, 3) -- TODO confirm for the
        # vocabulary file actually in use.
        vocab = self.vocab.data
        num_traj, _, _ = vocab.shape
        batch_size = bev_feature.shape[0]

        if self.use_nerf:
            heading = vocab[..., -1]
            vocab = torch.cat(
                [
                    nerf_positional_encoding(vocab[..., :2]),
                    torch.cos(heading)[..., None],
                    torch.sin(heading)[..., None],
                ], dim=-1
            )

        # The embedding is batch independent: compute it once with a leading
        # axis of 1 and replicate it afterwards.
        embedded_vocab = self.pos_embed(vocab.view(num_traj, -1))[None]
        if self.normalize_vocab_pos:
            embedded_vocab = self.encoder(embedded_vocab)
        embedded_vocab = embedded_vocab.repeat(batch_size, 1, 1)

        tr_out = self.transformer(embedded_vocab, bev_feature)
        dist_status = tr_out + status_encoding.unsqueeze(1)

        logits = {
            name: head(dist_status).squeeze(-1)
            for name, head in self.heads.items()
        }
        # 'imi' is returned as raw logits; every other head as a probability.
        result = {
            name: value if name == 'imi' else value.sigmoid()
            for name, value in logits.items()
        }

        scores = self._log_scores(logits)
        selected_indices = scores.argmax(1)
        result["trajectory"] = self.vocab.data[selected_indices]
        result["trajectory_vocab"] = self.vocab.data
        result["selected_indices"] = selected_indices
        return result