from typing import Dict, Tuple, List

import pytorch_lightning as pl
import torch
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
from torch import Tensor

from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.vadv2.vadv2_agent import Vadv2Agent
from navsim.common.dataclasses import Trajectory

# Sampling parameters used when wrapping predicted poses into a Trajectory.
# A prediction with _DENSE_NUM_POSES poses along axis 1 is sampled at the fine
# interval; any other length falls back to the coarse interval.
_TIME_HORIZON = 4
_DENSE_NUM_POSES = 40
_DENSE_INTERVAL_LENGTH = 0.1
_COARSE_INTERVAL_LENGTH = 0.5


class AgentLightningModuleMap(pl.LightningModule):
    """Lightning wrapper that delegates forward, loss and optimizers to an agent."""

    def __init__(
        self,
        agent: AbstractAgent,
    ):
        """Store the wrapped agent.

        :param agent: agent providing ``forward``, ``compute_loss`` and
            ``get_optimizers`` (and ``forward_train`` for ``Vadv2Agent``).
        """
        super().__init__()
        self.agent = agent

    def _step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        logging_prefix: str,
    ):
        """Run one forward pass, compute the loss and log every loss term.

        :param batch: ``(features, targets)`` or ``(features, targets, tokens)``;
            only the first two entries are used here.
        :param logging_prefix: prefix for logged metric names, e.g. ``"train"``.
        :return: the total loss returned by ``agent.compute_loss``.
        """
        # Slice instead of strict two-way unpacking so that batches which also
        # carry the token list (the layout predict_step expects) do not raise.
        features, targets = batch[:2]

        # Vadv2Agent has a dedicated training forward that also consumes the
        # interpolated ground-truth trajectory.
        if logging_prefix in ("train", "val") and isinstance(self.agent, Vadv2Agent):
            prediction = self.agent.forward_train(features, targets["interpolated_traj"])
        else:
            prediction = self.agent.forward(features)

        loss, loss_dict = self.agent.compute_loss(features, targets, prediction)

        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        for name, value in loss_dict.items():
            self.log(f"{logging_prefix}/{name}", value, **log_kwargs)
        self.log(f"{logging_prefix}/loss", loss, **log_kwargs)
        return loss

    def training_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int,
    ):
        """Compute and log the training loss for one batch."""
        return self._step(batch, "train")

    def validation_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int,
    ):
        """Compute and log the validation loss for one batch."""
        return self._step(batch, "val")

    def configure_optimizers(self):
        """Return whatever ``agent.get_optimizers()`` provides."""
        return self.agent.get_optimizers()

    def predict_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Run inference and return per-token trajectories and log-scores.

        :param batch: ``(features, targets, tokens)``; targets are not used.
        :param batch_idx: index of the batch (unused).
        :return: dict mapping each token to a dict with keys ``trajectory``
            (a ``Trajectory``), ``imi``, ``noc``, ``da``, ``ttc``, ``comfort``
            and ``progress`` (numpy arrays of log-scores).
        """
        features, _, tokens = batch

        self.agent.eval()
        with torch.no_grad():
            predictions = self.agent.forward(features)
            poses = predictions["trajectory"].cpu().numpy()

            # "imi" is normalised with a softmax before the log; the remaining
            # heads are logged directly.
            # NOTE(review): presumably those heads already output
            # probabilities — confirm against the agent implementation.
            imis = predictions["imi"].softmax(-1).log().cpu().numpy()
            nocs = predictions["noc"].log().cpu().numpy()
            das = predictions["da"].log().cpu().numpy()
            ttcs = predictions["ttc"].log().cpu().numpy()
            comforts = predictions["comfort"].log().cpu().numpy()
            progresses = predictions["progress"].log().cpu().numpy()

        # Pick the sampling interval from the number of poses along axis 1.
        if poses.shape[1] == _DENSE_NUM_POSES:
            interval_length = _DENSE_INTERVAL_LENGTH
        else:
            interval_length = _COARSE_INTERVAL_LENGTH

        results = {}
        for pose, imi, noc, da, ttc, comfort, progress, token in zip(
            poses, imis, nocs, das, ttcs, comforts, progresses, tokens
        ):
            results[token] = {
                "trajectory": Trajectory(
                    pose,
                    TrajectorySampling(
                        time_horizon=_TIME_HORIZON,
                        interval_length=interval_length,
                    ),
                ),
                "imi": imi,
                "noc": noc,
                "da": da,
                "ttc": ttc,
                "comfort": comfort,
                "progress": progress,
            }
        return results

    # Debugging aid (disabled): prints parameters that received no gradient.
    # def on_after_backward(self) -> None:
    #     print("on_after_backward enter")
    #     for name, param in self.named_parameters():
    #         if param.grad is None:
    #             print(name)
    #     print("on_after_backward exit")