|
from typing import Dict, Tuple, List
|
|
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
|
|
from torch import Tensor
|
|
import torch.nn.functional as F
|
|
from navsim.agents.abstract_agent import AbstractAgent
|
|
from navsim.agents.vadv2.vadv2_agent import Vadv2Agent
|
|
from navsim.common.dataclasses import Trajectory
|
|
|
|
|
|
class AgentLightningModule(pl.LightningModule):
    """Lightning wrapper that routes batches to an agent and logs its losses.

    The wrapped agent supplies the forward pass (``forward`` or
    ``forward_train``), the loss (``compute_loss``) and the optimizers
    (``get_optimizers``); this class only wires those into the Lightning hooks.
    """

    def __init__(
        self,
        agent: AbstractAgent,
    ):
        """Keep a reference to the agent that backs every hook of this module.

        :param agent: agent providing the forward pass, loss and optimizers.
        """
        super().__init__()
        self.agent = agent

    def _step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        logging_prefix: str,
    ):
        """Run one forward pass, compute the loss and log every loss term.

        :param batch: ``(features, targets, tokens)`` triple.
        :param logging_prefix: namespace prepended to every logged metric
            (``"train"`` and ``"val"`` are the values used by this class).
        :return: the total loss returned by ``agent.compute_loss``.
        """
        features, targets, tokens = batch

        # Vadv2Agent has a dedicated training/validation forward pass that also
        # consumes the interpolated ground-truth trajectory; every other agent
        # (and every other prefix) goes through the plain forward().
        if logging_prefix in ('train', 'val') and isinstance(self.agent, Vadv2Agent):
            prediction = self.agent.forward_train(features, targets['interpolated_traj'])
        else:
            prediction = self.agent.forward(features)

        loss, loss_dict = self.agent.compute_loss(features, targets, prediction, tokens)

        # Individual loss terms first, then the total under "<prefix>/loss".
        for k, v in loss_dict.items():
            self.log(f"{logging_prefix}/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log(f"{logging_prefix}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def training_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int
    ):
        """Compute and log the training loss for one batch.

        :param batch: ``(features, targets, tokens)`` triple.
        :param batch_idx: index of the batch (unused).
        :return: the loss returned by ``_step``.
        """
        return self._step(batch, "train")

    def validation_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int
    ):
        """Compute and log the validation loss for one batch.

        :param batch: ``(features, targets, tokens)`` triple.
        :param batch_idx: index of the batch (unused).
        :return: the loss returned by ``_step``.
        """
        return self._step(batch, "val")

    def configure_optimizers(self):
        """Return whatever optimizer configuration the agent provides."""
        return self.agent.get_optimizers()

    def predict_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int
    ) -> Dict[str, dict]:
        """Run inference on one batch and collect per-token outputs.

        :param batch: ``(features, targets, tokens)`` triple; the targets are
            not used here.
        :param batch_idx: index of the batch (unused).
        :return: mapping from each token to a dict with the keys
            ``trajectory`` (a ``Trajectory``), ``imi``, ``noc``, ``da``,
            ``ttc``, ``comfort`` and ``progress``. All score entries are
            log-values as numpy data; ``progress`` is ``None`` when the agent
            does not predict it.
        """
        features, _, tokens = batch
        self.agent.eval()
        with torch.no_grad():
            predictions = self.agent.forward(features)
            poses = predictions["trajectory"].cpu().numpy()

            # log_softmax is the fused, numerically stable form of
            # softmax(-1).log(): it cannot underflow to log(0) = -inf for
            # entries with a very small probability.
            imis = predictions["imi"].log_softmax(-1).cpu().numpy()
            # NOTE(review): the remaining heads are only passed through log(),
            # so they are presumably already probabilities - confirm.
            nocs = predictions["noc"].log().cpu().numpy()
            das = predictions["da"].log().cpu().numpy()
            ttcs = predictions["ttc"].log().cpu().numpy()
            comforts = predictions["comfort"].log().cpu().numpy()
            if 'progress' in predictions:
                progresses = predictions["progress"].log().cpu().numpy()
            else:
                # Placeholder per sample so the zip below still yields one
                # entry for every token.
                progresses = [None] * len(tokens)

        # 40 poses over the 4 s horizon means 0.1 s spacing; any other length
        # is treated as 0.5 s spacing.
        if poses.shape[1] == 40:
            interval_length = 0.1
        else:
            interval_length = 0.5

        return {
            token: {
                'trajectory': Trajectory(pose, TrajectorySampling(time_horizon=4, interval_length=interval_length)),
                'imi': imi,
                'noc': noc,
                'da': da,
                'ttc': ttc,
                'comfort': comfort,
                'progress': progress,
            }
            for pose, imi, noc, da, ttc, comfort, progress, token in zip(
                poses, imis, nocs, das, ttcs, comforts, progresses, tokens
            )
        }
|
|
|