navsim_ours / navsim /planning /training /agent_lightning_module.py
lkllkl's picture
Upload folder using huggingface_hub
da2e2ac verified
raw
history blame
6.19 kB
from typing import Dict, Tuple, List
import pytorch_lightning as pl
import torch
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
from torch import Tensor
import torch.nn.functional as F
from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.vadv2.vadv2_agent import Vadv2Agent
from navsim.common.dataclasses import Trajectory
class AgentLightningModule(pl.LightningModule):
    """Lightning wrapper that trains, validates and runs prediction for an agent.

    The forward pass, the loss computation and the optimizer setup are all
    delegated to the wrapped ``agent``; this class only wires them into the
    Lightning training / validation / prediction hooks and handles logging.
    """

    def __init__(
        self,
        agent: AbstractAgent,
    ):
        """
        :param agent: agent providing ``forward``, ``compute_loss`` and
            ``get_optimizers`` (and ``forward_train`` for ``Vadv2Agent``).
        """
        super().__init__()
        self.agent = agent

    def _step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        logging_prefix: str,
    ):
        """Run one forward pass, compute the loss and log every loss term.

        :param batch: ``(features, targets, tokens)`` triple.
        :param logging_prefix: prefix of the logged metric names
            (``"train"`` or ``"val"`` for the callers in this class).
        :return: the total loss returned by ``agent.compute_loss``.
        """
        features, targets, tokens = batch

        # Vadv2 agents have a dedicated training forward pass that also
        # consumes the interpolated ground-truth trajectory.
        if logging_prefix in ("train", "val") and isinstance(self.agent, Vadv2Agent):
            prediction = self.agent.forward_train(features, targets['interpolated_traj'])
        else:
            prediction = self.agent.forward(features)

        loss, loss_dict = self.agent.compute_loss(features, targets, prediction, tokens)

        for k, v in loss_dict.items():
            self.log(f"{logging_prefix}/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log(f"{logging_prefix}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def training_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Lightning training hook; logs under the ``train/`` prefix.

        :param batch: ``(features, targets, tokens)`` triple.
        :param batch_idx: index of the batch (unused).
        :return: the training loss.
        """
        return self._step(batch, "train")

    def validation_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Lightning validation hook; logs under the ``val/`` prefix.

        :param batch: ``(features, targets, tokens)`` triple.
        :param batch_idx: index of the batch (unused).
        :return: the validation loss.
        """
        return self._step(batch, "val")

    def configure_optimizers(self):
        """Return the optimizer configuration supplied by the wrapped agent."""
        return self.agent.get_optimizers()

    # ablate overall pdm score
    # def predict_step(
    #         self,
    #         batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
    #         batch_idx: int
    # ):
    #     features, targets, tokens = batch
    #     self.agent.eval()
    #     with torch.no_grad():
    #         predictions = self.agent.forward(features)
    #         poses = predictions["trajectory"].cpu().numpy()
    #     if poses.shape[1] == 40:
    #         interval_length = 0.1
    #     else:
    #         interval_length = 0.5
    #     return {token: {
    #         'trajectory': Trajectory(pose, TrajectorySampling(time_horizon=4, interval_length=interval_length)),
    #     } for pose, token in zip(poses, tokens)}

    # ablate post-processing
    # def predict_step(
    #         self,
    #         batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
    #         batch_idx: int
    # ):
    #     features, _, tokens = batch
    #     self.agent.eval()
    #     K = 100
    #     # N_VOCAB, 40, 3
    #     vocab = self.agent.vadv2_model._trajectory_head.vocab
    #     with torch.no_grad():
    #         predictions = self.agent.forward(features)
    #         # poses = predictions["trajectory"].cpu().numpy()
    #         # B, N_VOCAB
    #         imi_score = predictions["trajectory_distribution"].softmax(-1).log()
    #         # B, K
    #         topk_scores, topk_inds = imi_score.topk(K, -1)
    #         # B, K, 40->20, 3->2
    #         topk_trajs = vocab[topk_inds][:, :, :20, :2]
    #         # B, 30, 5 (x,y,h,l,w)
    #         agents = predictions["agent_states"].cpu().numpy()
    #         # B, 7, H=128, W=256
    #         map = predictions["bev_semantic_map"].softmax(1).log().cpu().numpy()
    #         B, _, H, W = map.shape
    #         post_scores = topk_scores.clone()
    #         # normalize trajs
    #         topk_trajs[..., 0] = topk_trajs[..., 0] / 32
    #         topk_trajs[..., 1] = topk_trajs[..., 1] / 32
    #         # B, H, W
    #         good_locs = map[:, 1:2]
    #         bad_locs = map[:, 2:3]
    #         post_scores += F.grid_sample(good_locs, topk_trajs, mode='nearest').sum((-1,)).squeeze(1)
    #         post_scores -= F.grid_sample(bad_locs, topk_trajs, mode='nearest').sum((-1,)).squeeze(1)
    #         post_ind = post_scores.argmax(-1)
    #         poses = vocab[topk_inds[post_ind]].cpu().numpy()
    #     if poses.shape[1] == 40:
    #         interval_length = 0.1
    #     else:
    #         interval_length = 0.5
    #     return {token: {
    #         'trajectory': Trajectory(pose, TrajectorySampling(time_horizon=4, interval_length=interval_length)),
    #     } for pose, token in zip(poses, tokens)}

    # hydra-pdm
    def predict_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ) -> Dict[str, Dict[str, object]]:
        """Predict a trajectory and per-metric log-scores for every sample.

        :param batch: ``(features, targets, tokens)`` triple; the targets are
            not used for prediction.
        :param batch_idx: index of the batch (unused).
        :return: mapping from each token to a dict holding the predicted
            ``Trajectory`` under ``'trajectory'`` and the log-scores under
            ``'imi'``, ``'noc'``, ``'da'``, ``'ttc'``, ``'comfort'`` and
            ``'progress'`` (``None`` when the agent emits no progress head).
        """
        features, _, tokens = batch

        self.agent.eval()
        with torch.no_grad():
            predictions = self.agent.forward(features)

            poses = predictions["trajectory"].cpu().numpy()
            # log_softmax is mathematically equal to softmax(-1).log() but is
            # numerically stable: it does not yield -inf on underflow.
            imis = predictions["imi"].log_softmax(-1).cpu().numpy()
            nocs = predictions["noc"].log().cpu().numpy()
            das = predictions["da"].log().cpu().numpy()
            ttcs = predictions["ttc"].log().cpu().numpy()
            comforts = predictions["comfort"].log().cpu().numpy()
            if 'progress' in predictions:
                progresses = predictions["progress"].log().cpu().numpy()
            else:
                # Placeholder per sample so the zip below keeps its length.
                progresses = [None] * len(tokens)

        # The horizon is fixed at 4 s: 40 poses means 0.1 s spacing,
        # any other pose count is treated as 0.5 s spacing.
        interval_length = 0.1 if poses.shape[1] == 40 else 0.5

        return {
            token: {
                'trajectory': Trajectory(
                    pose,
                    TrajectorySampling(time_horizon=4, interval_length=interval_length),
                ),
                'imi': imi,
                'noc': noc,
                'da': da,
                'ttc': ttc,
                'comfort': comfort,
                'progress': progress,
            }
            for pose, imi, noc, da, ttc, comfort, progress, token in zip(
                poses, imis, nocs, das, ttcs, comforts, progresses, tokens
            )
        }