from typing import Any, Dict, List, Optional, Union

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import pytorch_lightning as pl

from navsim.agents.abstract_agent import AbstractAgent
from navsim.common.dataclasses import SensorConfig
from navsim.planning.training.abstract_feature_target_builder import (
    AbstractFeatureBuilder,
    AbstractTargetBuilder,
)
from navsim.agents.transfuser.transfuser_config import TransfuserConfig
from navsim.agents.transfuser.transfuser_model import TransfuserModel
from navsim.agents.transfuser.transfuser_callback import TransfuserCallback
from navsim.agents.transfuser.transfuser_loss import transfuser_loss
from navsim.agents.transfuser.transfuser_features import (
    TransfuserFeatureBuilder,
    TransfuserTargetBuilder,
)


class TransfuserAgent(AbstractAgent):
    """Agent that wraps a TransfuserModel together with its builders, loss, optimizer and callback."""

    def __init__(
        self,
        config: TransfuserConfig,
        lr: float,
        checkpoint_path: Optional[str] = None,
    ):
        """
        Build the agent and its underlying TransfuserModel.

        :param config: configuration shared by the model, feature/target builders, loss and callback.
        :param lr: learning rate of the Adam optimizer returned by get_optimizers().
        :param checkpoint_path: checkpoint file read by initialize(); may be None when the
            agent is only trained and never initialized from a checkpoint.
        """
        super().__init__()
        self._config = config
        self._lr = lr
        self._checkpoint_path = checkpoint_path
        self._transfuser_model = TransfuserModel(config)

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads the weights stored under the "state_dict" key of the checkpoint at
        ``checkpoint_path`` into this agent.

        :raises ValueError: if the agent was constructed without a checkpoint path.
        """
        if self._checkpoint_path is None:
            raise ValueError(f"{self.name()} cannot be initialized: no checkpoint_path was given.")

        # Without CUDA, remap tensors that may have been saved on GPU onto the CPU;
        # with CUDA, keep torch.load's default placement (map_location=None).
        map_location = None if torch.cuda.is_available() else torch.device("cpu")
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path, map_location=map_location)["state_dict"]

        # Checkpoint keys carry a leading "agent." (presumably because the checkpoint was
        # saved by a wrapper that holds this agent as `agent` — confirm). Strip only that
        # leading prefix so "agent." occurring later in a key name is left untouched.
        prefix = "agent."
        self.load_state_dict(
            {(key[len(prefix):] if key.startswith(prefix) else key): value for key, value in state_dict.items()}
        )

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        # NOTE(review): `include=[3]` restricts which sensor data build_all_sensors loads;
        # presumably index 3 is the current frame — confirm against SensorConfig.
        return SensorConfig.build_all_sensors(include=[3])

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the single TransfuserTargetBuilder configured with this agent's config."""
        return [TransfuserTargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the single TransfuserFeatureBuilder configured with this agent's config."""
        return [TransfuserFeatureBuilder(config=self._config)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped TransfuserModel on the feature dict and return its prediction dict."""
        return self._transfuser_model(features)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the Transfuser loss from targets and predictions.

        ``features`` is accepted for interface compatibility but is not used here.
        """
        return transfuser_loss(targets, predictions, self._config)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """Return an Adam optimizer over the wrapped model's parameters (no LR scheduler)."""
        return torch.optim.Adam(self._transfuser_model.parameters(), lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """Return the training callbacks: a single TransfuserCallback."""
        return [TransfuserCallback(self._config)]