from typing import Dict, List, Tuple

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
from torch import Tensor

from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.vadv2.vadv2_agent import Vadv2Agent
from navsim.common.dataclasses import Trajectory


class AgentLightningModule(pl.LightningModule):
    """Lightning wrapper that trains, validates and runs inference for an agent.

    The forward pass, loss and optimizers are all delegated to ``self.agent``;
    this module only wires them into Lightning's step hooks and metric logging.
    """

    def __init__(
        self,
        agent: AbstractAgent,
    ):
        """
        :param agent: agent providing ``forward``, ``compute_loss`` and
            ``get_optimizers`` (and ``forward_train`` if it is a Vadv2Agent).
        """
        super().__init__()
        self.agent = agent

    def _step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        logging_prefix: str,
    ):
        """Run a forward pass, compute the loss and log every loss term.

        :param batch: ``(features, targets, tokens)`` as yielded by the dataloader.
        :param logging_prefix: prefix for every logged metric name
            (``"train"`` / ``"val"`` from the Lightning hooks below).
        :return: the ``loss`` returned by ``agent.compute_loss``.
        """
        features, targets, tokens = batch
        # Vadv2Agent has a dedicated training forward that also consumes the
        # target "interpolated_traj"; it is used for both train and val steps.
        if logging_prefix in ("train", "val") and isinstance(self.agent, Vadv2Agent):
            prediction = self.agent.forward_train(features, targets["interpolated_traj"])
        else:
            prediction = self.agent.forward(features)

        loss, loss_dict = self.agent.compute_loss(features, targets, prediction, tokens)

        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        for name, value in loss_dict.items():
            self.log(f"{logging_prefix}/{name}", value, **log_kwargs)
        self.log(f"{logging_prefix}/loss", loss, **log_kwargs)
        return loss

    def training_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Lightning training hook; see ``_step``."""
        return self._step(batch, "train")

    def validation_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Lightning validation hook; see ``_step``."""
        return self._step(batch, "val")

    def configure_optimizers(self):
        """Delegate optimizer (and scheduler) construction to the agent."""
        return self.agent.get_optimizers()

    # --- Ablation: overall PDM score (plain trajectory output) -------------
    # def predict_step(
    #     self,
    #     batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
    #     batch_idx: int,
    # ):
    #     features, targets, tokens = batch
    #     self.agent.eval()
    #     with torch.no_grad():
    #         predictions = self.agent.forward(features)
    #         poses = predictions["trajectory"].cpu().numpy()
    #     if poses.shape[1] == 40:
    #         interval_length = 0.1
    #     else:
    #         interval_length = 0.5
    #     return {token: {
    #         'trajectory': Trajectory(pose, TrajectorySampling(time_horizon=4, interval_length=interval_length)),
    #     } for pose, token in zip(poses, tokens)}

    # --- Ablation: post-processing with BEV semantic map -------------------
    # NOTE(review): this variant would not run as written:
    #   * `map` (and hence good_locs / bad_locs) is a numpy array, but
    #     F.grid_sample requires tensors on the same device as topk_trajs;
    #   * `topk_inds[post_ind]` indexes the batch dimension; a per-sample pick
    #     needs `topk_inds[torch.arange(B), post_ind]`;
    #   * `map` shadows the builtin, `agents` is unused, and
    #     `softmax().log()` is better written as `log_softmax()`.
    # def predict_step(
    #     self,
    #     batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
    #     batch_idx: int,
    # ):
    #     features, _, tokens = batch
    #     self.agent.eval()
    #     K = 100
    #     # N_VOCAB, 40, 3
    #     vocab = self.agent.vadv2_model._trajectory_head.vocab
    #     with torch.no_grad():
    #         predictions = self.agent.forward(features)
    #         # poses = predictions["trajectory"].cpu().numpy()
    #         # B, N_VOCAB
    #         imi_score = predictions["trajectory_distribution"].softmax(-1).log()
    #         # B, K
    #         topk_scores, topk_inds = imi_score.topk(K, -1)
    #         # B, K, 40->20, 3->2
    #         topk_trajs = vocab[topk_inds][:, :, :20, :2]
    #         # B, 30, 5 (x,y,h,l,w)
    #         agents = predictions["agent_states"].cpu().numpy()
    #         # B, 7, H=128, W=256
    #         map = predictions["bev_semantic_map"].softmax(1).log().cpu().numpy()
    #         B, _, H, W = map.shape
    #         post_scores = topk_scores.clone()
    #         # normalize trajs
    #         topk_trajs[..., 0] = topk_trajs[..., 0] / 32
    #         topk_trajs[..., 1] = topk_trajs[..., 1] / 32
    #         # B, H, W
    #         good_locs = map[:, 1:2]
    #         bad_locs = map[:, 2:3]
    #         post_scores += F.grid_sample(good_locs, topk_trajs, mode='nearest').sum((-1,)).squeeze(1)
    #         post_scores -= F.grid_sample(bad_locs, topk_trajs, mode='nearest').sum((-1,)).squeeze(1)
    #         post_ind = post_scores.argmax(-1)
    #         poses = vocab[topk_inds[post_ind]].cpu().numpy()
    #     if poses.shape[1] == 40:
    #         interval_length = 0.1
    #     else:
    #         interval_length = 0.5
    #     return {token: {
    #         'trajectory': Trajectory(pose, TrajectorySampling(time_horizon=4, interval_length=interval_length)),
    #     } for pose, token in zip(poses, tokens)}

    # --- Active: hydra-pdm --------------------------------------------------
    def predict_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Run inference and return, per token, the trajectory and log sub-scores.

        :param batch: ``(features, targets, tokens)``; targets are not used.
        :param batch_idx: unused, required by the Lightning hook signature.
        :return: ``{token: {"trajectory": Trajectory, "imi", "noc", "da",
            "ttc", "comfort", "progress"}}``. ``"imi"`` holds log-softmax values
            over the last axis; the other score entries hold the elementwise
            log of the model outputs as numpy arrays. ``"progress"`` is
            ``None`` when the model output has no ``"progress"`` key.
        """
        features, _, tokens = batch
        self.agent.eval()
        with torch.no_grad():
            predictions = self.agent.forward(features)
            poses = predictions["trajectory"].cpu().numpy()
            # log_softmax is the numerically stable form of softmax().log():
            # it cannot collapse to -inf when a softmax entry underflows to 0.
            imis = predictions["imi"].log_softmax(-1).cpu().numpy()
            nocs = predictions["noc"].log().cpu().numpy()
            das = predictions["da"].log().cpu().numpy()
            ttcs = predictions["ttc"].log().cpu().numpy()
            comforts = predictions["comfort"].log().cpu().numpy()
            if "progress" in predictions:
                progresses = predictions["progress"].log().cpu().numpy()
            else:
                # Placeholder so the zip below still yields one entry per token.
                progresses = [None] * len(tokens)

        # 40 poses over the 4 s horizon means 0.1 s spacing; any other pose
        # count is assumed to be sampled at 0.5 s.
        interval_length = 0.1 if poses.shape[1] == 40 else 0.5

        return {
            token: {
                "trajectory": Trajectory(
                    pose,
                    TrajectorySampling(time_horizon=4, interval_length=interval_length),
                ),
                "imi": imi,
                "noc": noc,
                "da": da,
                "ttc": ttc,
                "comfort": comfort,
                "progress": progress,
            }
            for pose, imi, noc, da, ttc, comfort, progress, token in zip(
                poses, imis, nocs, das, ttcs, comforts, progresses, tokens
            )
        }