import os
import pickle
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.transfuser.transfuser_callback import TransfuserCallback
from navsim.agents.vadv2.vadv2_features import (
    Vadv2FeatureBuilder,
    Vadv2TargetBuilder,
)
from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import (
    vadv2_loss_ori,
    vadv2_loss_center,
    vadv2_loss_center_woper,
)
from navsim.agents.vadv2.vadv2_model import Vadv2Model
from navsim.common.dataclasses import SensorConfig
from navsim.planning.training.abstract_feature_target_builder import (
    AbstractFeatureBuilder,
    AbstractTargetBuilder,
)

DEVKIT_ROOT = os.getenv("NAVSIM_DEVKIT_ROOT")


class Vadv2Agent(AbstractAgent):
    """NAVSIM agent wrapping the Vadv2 trajectory-vocabulary scoring model."""

    def __init__(
        self,
        config: Vadv2Config,
        lr: float,
        checkpoint_path: Optional[str] = None,
        split: Optional[str] = None,
        vocab_size: int = 4096,
        closest: bool = False,
        ori: bool = False,
    ):
        super().__init__()
        self._config = config
        self._lr = lr
        self._checkpoint_path = checkpoint_path
        self.vadv2_model = Vadv2Model(config)
        # Pre-computed PDM scores of the trajectory vocabulary, keyed by sample token.
        with open(f"{DEVKIT_ROOT}/vocab_score_local/{split}.pkl", "rb") as f:
            self.vocab_pdm_score = pickle.load(f)
        self.vocab_size = vocab_size

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass."""
        if torch.cuda.is_available():
            state_dict: Dict[str, Any] = torch.load(self._checkpoint_path)["state_dict"]
        else:
            state_dict: Dict[str, Any] = torch.load(
                self._checkpoint_path, map_location=torch.device("cpu")
            )["state_dict"]
        # Strip the "agent." prefix added by the Lightning module wrapper.
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_mm_sensors()

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        return [Vadv2TargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        return [Vadv2FeatureBuilder(config=self._config)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self.vadv2_model(features)

    def forward_train(self, features, interpolated_traj):
        return self.vadv2_model(features, interpolated_traj)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Look up the pre-computed PDM scores for each sample token, falling back
        # to zeros for tokens missing from the cached vocabulary scores.
        dummy_score = np.zeros(self._config.vocab_size, dtype=np.float32)
        curr_vocab_pdm_score = [
            self.vocab_pdm_score.get(token, dummy_score)[None] for token in tokens
        ]
        curr_vocab_pdm_score = np.concatenate(curr_vocab_pdm_score, axis=0)
        if self._config.type == "ori":
            return vadv2_loss_ori(targets, predictions, self._config, curr_vocab_pdm_score)
        elif self._config.type == "center":
            return vadv2_loss_center(targets, predictions, self._config, curr_vocab_pdm_score)
        elif self._config.type == "center_woper":
            return vadv2_loss_center_woper(targets, predictions, self._config, curr_vocab_pdm_score)
        else:
            raise NotImplementedError(f"Unknown loss type: {self._config.type}")

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        return torch.optim.Adam(self.vadv2_model.parameters(), lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        return [
            TransfuserCallback(self._config),
            ModelCheckpoint(
                save_top_k=15,
                monitor="val/loss_epoch",
                mode="min",
                dirpath=f"{os.environ.get('NAVSIM_EXP_ROOT')}/{self._config.ckpt_path}/",
                filename="{epoch:02d}-{step:04d}",
            ),
        ]