lkllkl's picture
Upload folder using huggingface_hub
da2e2ac verified
import math
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from diffusers import DDIMScheduler
from navsim.agents.dm.backbone import DMBackbone
from navsim.agents.dm.dm_config import DMConfig
from navsim.agents.dm.utils import VerletStandardizer
from navsim.agents.transfuser.transfuser_model import AgentHead
class SinusoidalPosEmb(nn.Module):
    """Map scalar values (e.g. normalised diffusion timesteps) to sinusoidal features.

    With ``h = dim // 2`` frequencies ``f_i = 10000 ** (-i / max(h - 1, 1))``,
    every input value ``v`` becomes ``[sin(v * f_0), ..., sin(v * f_{h-1}),
    cos(v * f_0), ..., cos(v * f_{h-1})]``. The features are appended as a new
    trailing axis, so an input of shape ``S`` yields ``(*S, dim)``; for an odd
    ``dim`` the last channel is zero padding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (any shape) into a ``(*x.shape, dim)`` tensor."""
        half_dim = self.dim // 2
        # max(..., 1): with dim in {2, 3} there is a single frequency and the
        # plain ``half_dim - 1`` denominator would be zero.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        # ``x[..., None]`` broadcasts against the frequencies for any input rank;
        # for ``(B,)`` and ``(B, 1)`` inputs it is identical to ``x[:, None]``.
        args = x[..., None] * freqs
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # Odd ``dim``: pad one zero channel so the width is exactly ``dim``.
            emb = nn.functional.pad(emb, (0, 1))
        return emb
class DMModel(nn.Module):
    """Camera-based model combining an agent-detection head and a diffusion trajectory head.

    Backbone image tokens (front view, optionally concatenated with the back
    view) plus a learned per-token embedding form the keys/values. They are
    attended by a transformer decoder whose output feeds ``AgentHead``, and
    are also passed directly to ``DMTrajHead``.
    """

    # Backbone types for which ``DMBackbone`` is constructed in ``__init__``.
    _DM_BACKBONES = ('intern', 'vov', 'swin', 'vit')

    def __init__(self, config: "DMConfig"):
        super().__init__()
        self._query_splits = [
            config.num_bounding_boxes,
        ]
        self._config = config
        assert config.backbone_type in ['vit', 'intern', 'vov', 'resnet', 'eva', 'moe', 'moe_ult32', 'swin']
        if config.backbone_type not in self._DM_BACKBONES:
            # 'eva', 'resnet', 'moe' and 'moe_ult32' pass the assertion above but
            # no backbone is built for them; previously this surfaced later as an
            # AttributeError on ``self._backbone``. Fail with a clear message.
            raise ValueError(f'{config.backbone_type} not supported')
        self._backbone = DMBackbone(config)

        img_num = 2 if config.use_back_view else 1
        # One learned embedding per image token across all views
        # (img_vert_anchors * img_horz_anchors tokens per view).
        self._keyval_embedding = nn.Embedding(
            config.img_vert_anchors * config.img_horz_anchors * img_num, config.tf_d_model
        )
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)
        # 1x1 conv projecting backbone channels (img_feat_c) to tf_d_model.
        self.downscale_layer = nn.Conv2d(self._backbone.img_feat_c, config.tf_d_model, kernel_size=1)
        # 4 + 2 + 2 = 8 input values per ego-status frame.
        self._status_encoding = nn.Linear((4 + 2 + 2) * config.num_ego_status, config.tf_d_model)
        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )
        self._trajectory_head = DMTrajHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            nhead=config.vadv2_head_nhead,
            nlayers=config.vadv2_head_nlayers,
            vocab_path=config.vocab_path,
            config=config,
        )

    def img_feat_blc(self, camera_feature):
        """Run the backbone and return projected image tokens as (B, L, tf_d_model)."""
        img_features = self._backbone(camera_feature)
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        img_features = self.downscale_layer(img_features).flatten(-2, -1)
        img_features = img_features.permute(0, 2, 1)
        return img_features

    def forward(self, features: Dict[str, torch.Tensor],
                interpolated_traj=None) -> Dict[str, torch.Tensor]:
        """Predict trajectory and agent outputs.

        :param features: must contain "camera_feature", "status_feature" and
            "history_waypoints"; also "camera_feature_back" when
            ``config.use_back_view`` is set. "camera_feature" may be a list,
            in which case only its last entry is used.
        :param interpolated_traj: unused; kept for interface compatibility.
        :return: the merged output dicts of the trajectory head and the agent head.
        """
        camera_feature: torch.Tensor = features["camera_feature"]
        status_feature: torch.Tensor = features["status_feature"]
        if isinstance(camera_feature, list):
            camera_feature = camera_feature[-1]

        batch_size = status_feature.shape[0]
        img_features = self.img_feat_blc(camera_feature)
        if self._config.use_back_view:
            img_features_back = self.img_feat_blc(features["camera_feature_back"])
            img_features = torch.cat([img_features, img_features_back], 1)

        if self._config.num_ego_status == 1 and status_feature.shape[1] == 32:
            # A 32-wide status vector with a single configured frame: keep only
            # the first 8 values, matching the encoder's input width.
            status_encoding = self._status_encoding(status_feature[:, :8])
        else:
            status_encoding = self._status_encoding(status_feature)

        # Out-of-place add: ``+=`` would mutate img_features (a view of the
        # conv output) in place.
        keyval = img_features + self._keyval_embedding.weight[None, ...]

        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        agents_query = self._tf_decoder(query, keyval)

        output: Dict[str, torch.Tensor] = {}
        trajectory = self._trajectory_head(keyval, status_encoding, features['history_waypoints'])
        output.update(trajectory)
        agents = self._agent_head(agents_query)
        output.update(agents)
        return output
class DMTrajHead(nn.Module):
    """DDIM-based trajectory head.

    At inference (``config.is_training`` false) it starts from Gaussian noise
    of shape (B, H * 3), iteratively denoises it with a transformer decoder
    conditioned on the scene tokens and the ego-status encoding, and maps the
    result back with ``VerletStandardizer.untransform_features``. In training
    mode it only returns the conditioning tensors; presumably the training
    loss calls :meth:`denoise` itself (not visible here — confirm).
    """

    def __init__(self, num_poses: int, d_ffn: int, d_model: int, vocab_path: str,
                 nhead: int, nlayers: int, config: "DMConfig" = None
                 ):
        super().__init__()
        self.d_model = d_model
        # NOTE(review): despite the None default, ``config`` is dereferenced below and is required.
        self.config = config
        self._num_poses = num_poses
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, nhead, d_ffn,
                dropout=0.0, batch_first=True
            ), nlayers
        )
        # Frozen trajectory vocabulary loaded from disk (registered as a parameter,
        # so it is part of the state dict). Not used by forward().
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False
        )
        self.H = config.trajectory_sampling.num_poses
        self.T = config.T
        self.standardizer = VerletStandardizer()
        self.decoder_mlp = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.H * 3)
        )
        self.encoder_mlp = nn.Sequential(
            nn.Linear(self.H * 3, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, self.d_model),
        )
        self.sigma_encoder = nn.Sequential(
            SinusoidalPosEmb(self.d_model),
        )
        self.scheduler = DDIMScheduler(
            num_train_timesteps=self.T,
            beta_schedule='scaled_linear',
            prediction_type='epsilon',
        )
        # Inference walks through all T timesteps.
        self.scheduler.set_timesteps(self.T)

    def denoise(self, ego_trajectory, env_features, status_encoding, timesteps):
        """Predict the noise (epsilon) in ``ego_trajectory`` at ``timesteps``.

        :param ego_trajectory: noisy trajectories, flattened to (B, H * 3).
        :param env_features: scene tokens used as transformer memory.
        :param status_encoding: ego-status embedding; expected (B, d_model) — confirm.
        :param timesteps: a single timestep or one per sample.
        :return: predicted noise of shape (B, H * 3).
        """
        B = ego_trajectory.shape[0]
        ego_trajectory = ego_trajectory.reshape(B, -1).to(torch.float32)
        sigma = timesteps.reshape(-1, 1)
        if sigma.numel() == 1:
            # Broadcast a shared timestep to the whole batch.
            sigma = sigma.repeat(B, 1)
        sigma = sigma.float() / self.T
        sigma_embeddings = self.sigma_encoder(sigma).squeeze(1)
        ego_emb = self.encoder_mlp(ego_trajectory) + status_encoding + sigma_embeddings
        # One query token per sample attends over the scene tokens.
        ego_attn = self.transformer(ego_emb[:, None], env_features)
        out = self.decoder_mlp(ego_attn).reshape(B, -1)
        return out

    def forward(self, bev_feature, status_encoding, history_waypoints) -> Dict[str, torch.Tensor]:
        """Sample a trajectory (inference) and return it with the conditioning tensors.

        :param bev_feature: scene tokens, presumably (B, N, C) — confirm.
        :param status_encoding: ego-status embedding passed to :meth:`denoise`.
        :param history_waypoints: passed to the standardizer to undo normalisation.
        :return: dict with "history_waypoints", "env_features" and
            "status_encoding"; at inference also "trajectory" (B, H, 3) and
            all-ones placeholder scores "imi", "noc", "da", "ttc", "comfort",
            "progress" of shape (B, 4096).
        """
        B = bev_feature.shape[0]
        device = bev_feature.device
        result = {}
        if not self.config.is_training:
            ego_trajectory = torch.randn((B, self.H * 3), device=device)
            with torch.no_grad():
                for t in self.scheduler.timesteps:
                    # Each DDIM step needs the noise predicted for the *current*
                    # sample and timestep. The previous code accumulated all past
                    # predictions (``residual += ...``), which is not what the
                    # epsilon-prediction scheduler expects.
                    noise_pred = self.denoise(
                        ego_trajectory,
                        bev_feature,
                        status_encoding,
                        t.to(device),
                    )
                    ego_trajectory = self.scheduler.step(noise_pred, t, ego_trajectory).prev_sample
            ego_trajectory = self.standardizer.untransform_features(ego_trajectory, history_waypoints)
            result["trajectory"] = ego_trajectory.reshape(B, self.H, 3)
            # Placeholder all-ones scores, created on the input device so they
            # sit alongside the other outputs.
            result['imi'], result['noc'], result['da'], result['ttc'], result['comfort'], result['progress'] = (
                torch.ones((B, 4096), device=device) for _ in range(6)
            )
        result['history_waypoints'] = history_waypoints
        result['env_features'] = bev_feature
        result['status_encoding'] = status_encoding
        return result