import os
from functools import partial
from typing import Any, Optional, Union
from typing import Dict, List

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.dreamer.backbone import Backbone
from navsim.agents.dreamer.dreamer_network import DreamerNetwork
from navsim.agents.dreamer.dreamer_network_cond import DreamerNetworkCondition
from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.agents.dreamer.hydra_dreamer_loss_fn import latent_wm_loss
from navsim.agents.dreamer.hydra_dreamer_wm_features import HydraDreamerWmFeatureBuilder, HydraDreamerWmTargetBuilder
from navsim.agents.utils.layers import Mlp, NestedTensorBlock as Block
from navsim.common.dataclasses import SensorConfig
from navsim.planning.training.abstract_feature_target_builder import (
    AbstractFeatureBuilder,
    AbstractTargetBuilder,
)

NAVSIM_EXP_ROOT = os.getenv('NAVSIM_EXP_ROOT')
DEVKIT_ROOT = os.getenv('NAVSIM_DEVKIT_ROOT')
TRAJ_PDM_ROOT = os.getenv('NAVSIM_TRAJPDM_ROOT')


class HydraDreamerWmAgent(AbstractAgent):
    """NAVSIM agent wrapping a (optionally conditional) Dreamer world-model network."""

    def __init__(
            self,
            config: HydraDreamerConfig,
            lr: float,
            checkpoint_path: Optional[str] = None,
            pdm_split=None,
            metrics=None,
            conditional=False
    ):
        """Build the agent and its Dreamer network.

        :param config: agent configuration. NOTE: this object is mutated in
            place -- its ``trajectory_pdm_weight`` attribute is overwritten
            below (with ``progress`` taken from ``config.progress_weight``).
        :param lr: base learning rate used by :meth:`get_optimizers`.
        :param checkpoint_path: checkpoint file read by :meth:`initialize`;
            may be ``None`` if :meth:`initialize` is never called.
        :param pdm_split: accepted for interface compatibility; not used here.
        :param metrics: stored on ``self.metrics``; not otherwise used here.
        :param conditional: if true, use ``DreamerNetworkCondition`` instead of
            ``DreamerNetwork``.
        """
        super().__init__()
        config.trajectory_pdm_weight = {
            'noc': 3.0,
            'da': 3.0,
            'ttc': 2.0,
            'progress': config.progress_weight,
            'comfort': 1.0,
        }
        self._config = config
        self._lr = lr
        self.metrics = metrics
        self._checkpoint_path = checkpoint_path
        self.vocab_size = config.vocab_size
        self.backbone_wd = config.backbone_wd
        self.conditional = conditional
        if conditional:
            self.dreamer_network = DreamerNetworkCondition(config)
        else:
            self.dreamer_network = DreamerNetwork(config)

    def name(self) -> str:
        """Return the class name as the agent's name."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Load weights from ``self._checkpoint_path``.

        The checkpoint is loaded onto CPU and its ``"state_dict"`` entry is
        loaded into this agent after removing a leading ``"agent."`` from each
        key (presumably the attribute name under which the training wrapper
        held this agent -- TODO confirm).

        :raises ValueError: if no checkpoint path was supplied.
        """
        if self._checkpoint_path is None:
            raise ValueError(f"{self.name()}.initialize() requires a checkpoint_path, got None.")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict: Dict[str, Any] = torch.load(
            self._checkpoint_path, map_location=torch.device("cpu")
        )["state_dict"]
        prefix = "agent."
        # Strip only a *leading* prefix: str.replace would also rewrite any
        # "agent." occurring inside a parameter path (e.g. "subagent.weight").
        self.load_state_dict(
            {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
        )

    def get_sensor_config(self) -> SensorConfig:
        """Request all eight cameras and no lidar frames (``lidar_pc=[]``)."""
        return SensorConfig(
            cam_f0=True,
            cam_l0=True,
            cam_l1=True,
            cam_l2=True,
            cam_r0=True,
            cam_r1=True,
            cam_r2=True,
            cam_b0=True,
            lidar_pc=[],
        )

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the single world-model target builder for this config."""
        return [HydraDreamerWmTargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the single world-model feature builder for this config."""
        return [HydraDreamerWmFeatureBuilder(config=self._config)]

    def _forward(self, features):
        """Run the Dreamer network on ``features``."""
        return self.dreamer_network(features)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Inference forward pass; delegates to the Dreamer network."""
        return self._forward(features)

    def forward_train(self, features, interpolated_traj):
        """Training forward pass; ``interpolated_traj`` is accepted but unused."""
        return self._forward(features)

    def compute_loss(
            self,
            features: Dict[str, torch.Tensor],
            targets: Dict[str, torch.Tensor],
            predictions: Dict[str, torch.Tensor],
            tokens=None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the latent world-model loss.

        ``features`` and ``tokens`` are accepted for interface compatibility
        but not used; the network's ``fixed_vit`` is passed to the loss.
        """
        return latent_wm_loss(targets, predictions, self._config, self.dreamer_network.fixed_vit)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """Build an Adam optimizer with two parameter groups.

        Parameters whose name contains ``'siamese_vit'`` use
        ``lr * config.lr_mult_backbone`` and ``weight_decay=config.backbone_wd``;
        all other parameters use the base ``lr`` and Adam's default weight decay.
        """
        backbone_params_name = 'siamese_vit'
        default_params = []
        img_backbone_params = []
        # Single pass keeps the original parameter order within each group.
        for param_name, param in self.dreamer_network.named_parameters():
            if backbone_params_name in param_name:
                img_backbone_params.append(param)
            else:
                default_params.append(param)
        params_lr_dict = [
            {'params': default_params},
            {
                'params': img_backbone_params,
                'lr': self._lr * self._config.lr_mult_backbone,
                'weight_decay': self.backbone_wd
            }
        ]
        return torch.optim.Adam(params_lr_dict, lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """Return a checkpoint callback keeping the 30 lowest ``val/loss_epoch`` models."""
        return [
            ModelCheckpoint(
                save_top_k=30,
                monitor="val/loss_epoch",
                mode="min",
                # NOTE(review): NAVSIM_EXP_ROOT is read at call time; if it is
                # unset this path begins with the literal string "None".
                dirpath=f"{os.environ.get('NAVSIM_EXP_ROOT')}/{self._config.ckpt_path}/",
                filename="{epoch:02d}-{step:04d}",
            )
        ]