"""Agent wrapper and training loss built around ``HydraModelPENoDetBeta``."""
import os
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.hydra.hydra_config import HydraConfig
from navsim.agents.hydra.hydra_features import HydraFeatureBuilder, HydraTargetBuilder
from navsim.agents.hydra.hydra_model_pe_nodet_beta import HydraModelPENoDetBeta
from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import three_to_two_classes
from navsim.common.dataclasses import SensorConfig
from navsim.planning.training.abstract_feature_target_builder import (
    AbstractFeatureBuilder,
    AbstractTargetBuilder,
)

# Best-effort import: the name is never used in this module, so it is
# presumably imported for a side effect (e.g. registration) -- confirm.
# ``except Exception`` keeps the original "never fail here" behaviour without
# also swallowing KeyboardInterrupt / SystemExit like a bare ``except`` does.
try:
    from navsim.agents.utils.positional_encoding import SinePositionalEncoding3D  # noqa: F401
except Exception:
    print('sine pe not registered')

DEVKIT_ROOT = os.getenv('NAVSIM_DEVKIT_ROOT')
TRAJ_PDM_ROOT = os.getenv('NAVSIM_TRAJPDM_ROOT')


def hydra_nodet_beta_loss(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Compute the weighted sum of the imitation loss and the five PDM-score losses.

    :param targets: dictionary of name tensor pairings; reads ``"trajectory"``
    :param predictions: dictionary of name tensor pairings; reads ``'noc'``,
        ``'da'``, ``'ttc'``, ``'comfort'``, ``'progress'``, ``'imi'`` and
        ``"trajectory_vocab"``
    :param config: config object; reads ``sigma``, ``trajectory_imi_weight``
        and the ``trajectory_pdm_weight`` mapping
    :param vocab_pdm_score: mapping with keys ``'noc'``, ``'da'``, ``'ttc'``,
        ``'comfort'`` and ``'progress'`` holding the supervision tensors for
        the corresponding prediction heads
    :return: tuple ``(loss, loss_dict)`` where ``loss`` is the sum of all
        weighted terms and ``loss_dict`` maps a name to each weighted term
    """
    noc = predictions['noc']
    da = predictions['da']
    ttc = predictions['ttc']
    comfort = predictions['comfort']
    progress = predictions['progress']
    imi = predictions['imi']

    # Binary heads. Each target is cast to the dtype of the prediction it is
    # compared against (identical to casting everything to ``da.dtype``
    # whenever all heads share one dtype).
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(ttc.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(comfort.dtype))
    # 'noc' scores go through ``three_to_two_classes`` before the BCE.
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(noc.dtype)))
    # Progress is regressed with an L1 loss.
    progress_loss = F.l1_loss(progress, vocab_pdm_score['progress'].to(progress.dtype))

    vocab = predictions["trajectory_vocab"]
    # Original note on the target: "B, 8 (4 secs, 0.5Hz), 3".
    target_traj = targets["trajectory"]
    # Indices 4, 9, ..., 39 along dim 1 of the vocabulary.
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]

    # Negative squared distance between every vocabulary entry and every
    # target, scaled by sigma. Broadcasting (1, V, ...) against (B, 1, ...)
    # yields the same values as an explicit ``repeat`` over the batch without
    # materialising the repeated vocabulary.
    l2_distance = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    # The softmax over the vocabulary axis is used as a soft target
    # distribution for the imitation logits.
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))

    pdm_weight = config.trajectory_pdm_weight
    imi_loss_final = config.trajectory_imi_weight * imi_loss
    noc_loss_final = pdm_weight['noc'] * noc_loss
    da_loss_final = pdm_weight['da'] * da_loss
    ttc_loss_final = pdm_weight['ttc'] * ttc_loss
    progress_loss_final = pdm_weight['progress'] * progress_loss
    comfort_loss_final = pdm_weight['comfort'] * comfort_loss

    loss = (
            imi_loss_final
            + noc_loss_final
            + da_loss_final
            + ttc_loss_final
            + progress_loss_final
            + comfort_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        'pdm_progress_loss': progress_loss_final,
        'pdm_comfort_loss': comfort_loss_final
    }


class HydraAgentPENoDetBeta(AbstractAgent):
    """Agent that wraps ``HydraModelPENoDetBeta`` and supervises it with precomputed PDM scores."""

    def __init__(
            self,
            config: HydraConfig,
            lr: float,
            checkpoint_path: Optional[str] = None,
            pdm_split=None,
            metrics=None,
    ):
        """
        Build the model and load the precomputed per-token score table.

        :param config: model/training config; ``trajectory_pdm_weight`` is
            (over)written on this object
        :param lr: base learning rate handed to the optimizer
        :param checkpoint_path: checkpoint read by :meth:`initialize`
        :param pdm_split: stem of the ``.pkl`` file to load from
            ``$NAVSIM_TRAJPDM_ROOT/vocab_score_full_<vocab_size>_navtrain/``
        :param metrics: iterable of score names looked up in
            :meth:`compute_loss`
        """
        super().__init__()
        # NOTE: mutates the caller's config object in place.
        config.trajectory_pdm_weight = {
            'noc': 3.0,
            'da': 3.0,
            'ttc': config.ttc_weight,
            'progress': config.progress_weight,
            'comfort': 1.0,
        }
        self._config = config
        self._lr = lr
        self.metrics = metrics
        self._checkpoint_path = checkpoint_path
        self.vadv2_model = HydraModelPENoDetBeta(config)
        self.vocab_size = config.vocab_size
        self.backbone_wd = config.backbone_wd

        new_pkl_dir = f'vocab_score_full_{self.vocab_size}_navtrain'
        # NOTE(security): pickle executes arbitrary code on load; this file
        # must come from a trusted location.
        # The context manager closes the handle, which the previous
        # ``pickle.load(open(...))`` form never did.
        with open(f'{TRAJ_PDM_ROOT}/{new_pkl_dir}/{pdm_split}.pkl', 'rb') as score_file:
            self.vocab_pdm_score_full = pickle.load(score_file)

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass."""
        # Always deserialize onto the CPU, regardless of CUDA availability.
        state_dict: Dict[str, Any] = torch.load(
            self._checkpoint_path, map_location=torch.device("cpu")
        )["state_dict"]
        # Drop the "agent." substring from every key before loading.
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig(
            cam_f0=[0, 1, 2, 3],
            cam_l0=[0, 1, 2, 3],
            cam_l1=[0, 1, 2, 3],
            cam_l2=[0, 1, 2, 3],
            cam_r0=[0, 1, 2, 3],
            cam_r1=[0, 1, 2, 3],
            cam_r2=[0, 1, 2, 3],
            cam_b0=[0, 1, 2, 3],
            lidar_pc=[],
        )

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return a single ``HydraTargetBuilder`` configured with this agent's config."""
        return [HydraTargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return a single ``HydraFeatureBuilder`` configured with this agent's config."""
        return [HydraFeatureBuilder(config=self._config)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on ``features``."""
        return self.vadv2_model(features)

    def forward_train(self, features, interpolated_traj):
        """Run the wrapped model on ``features`` together with ``interpolated_traj``."""
        return self.vadv2_model(features, interpolated_traj)

    def compute_loss(
            self,
            features: Dict[str, torch.Tensor],
            targets: Dict[str, torch.Tensor],
            predictions: Dict[str, torch.Tensor],
            tokens=None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Look up the precomputed scores for ``tokens`` and evaluate the loss.

        :param features: unused here
        :param targets: forwarded to :func:`hydra_nodet_beta_loss`
        :param predictions: model outputs; ``predictions['trajectory']`` also
            determines the device the scores are moved to
        :param tokens: iterable of keys into the loaded score table, one per
            sample (despite the ``None`` default it must be provided)
        :return: ``(loss, loss_dict)`` as returned by
            :func:`hydra_nodet_beta_loss`
        """
        device = predictions['trajectory'].device
        # get the pdm score by tokens
        scores = {}
        for k in self.metrics:
            # Add a leading axis to each per-token entry, then join along it.
            tmp = [self.vocab_pdm_score_full[token][k][None] for token in tokens]
            scores[k] = torch.from_numpy(np.concatenate(tmp, axis=0)).to(device)
        return hydra_nodet_beta_loss(targets, predictions, self._config, scores)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """
        Build an Adam optimizer with two parameter groups.

        Parameters whose name contains ``'_backbone.image_encoder'`` get
        ``lr * config.lr_mult_backbone`` and ``weight_decay=backbone_wd``;
        all remaining parameters use the optimizer defaults with ``lr``.
        """
        backbone_params_name = '_backbone.image_encoder'
        named_params = list(self.vadv2_model.named_parameters())
        img_backbone_params = [p for n, p in named_params if backbone_params_name in n]
        default_params = [p for n, p in named_params if backbone_params_name not in n]
        params_lr_dict = [
            {'params': default_params},
            {
                'params': img_backbone_params,
                'lr': self._lr * self._config.lr_mult_backbone,
                'weight_decay': self.backbone_wd
            }
        ]
        return torch.optim.Adam(params_lr_dict, lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """Return a ``ModelCheckpoint`` keeping the 30 best checkpoints by minimal ``val/loss_epoch``."""
        return [
            ModelCheckpoint(
                save_top_k=30,
                monitor="val/loss_epoch",
                mode="min",
                dirpath=f"{os.environ.get('NAVSIM_EXP_ROOT')}/{self._config.ckpt_path}/",
                filename="{epoch:02d}-{step:04d}",
            )
        ]