navsim_ours / navsim /agents /dreamer /hydra_dreamer_wm_agent.py
lkllkl's picture
Upload folder using huggingface_hub
da2e2ac verified
raw
history blame
5.04 kB
import os
from functools import partial
from typing import Any, Union
from typing import Dict, List
import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.dreamer.backbone import Backbone
from navsim.agents.dreamer.dreamer_network import DreamerNetwork
from navsim.agents.dreamer.dreamer_network_cond import DreamerNetworkCondition
from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.agents.dreamer.hydra_dreamer_loss_fn import latent_wm_loss
from navsim.agents.dreamer.hydra_dreamer_wm_features import HydraDreamerWmFeatureBuilder, HydraDreamerWmTargetBuilder
from navsim.agents.utils.layers import Mlp, NestedTensorBlock as Block
from navsim.common.dataclasses import SensorConfig
from navsim.planning.training.abstract_feature_target_builder import (
AbstractFeatureBuilder,
AbstractTargetBuilder,
)
# Root directories read from the environment once, at import time.
# Each is None when the corresponding variable is not set.
NAVSIM_EXP_ROOT, DEVKIT_ROOT, TRAJ_PDM_ROOT = (
    os.environ.get(env_name)
    for env_name in ('NAVSIM_EXP_ROOT', 'NAVSIM_DEVKIT_ROOT', 'NAVSIM_TRAJPDM_ROOT')
)
class HydraDreamerWmAgent(AbstractAgent):
    """Agent wrapping a Dreamer world-model network.

    Builds a ``DreamerNetworkCondition`` (when ``conditional`` is true) or a
    ``DreamerNetwork`` from ``config``, computes its training loss with
    ``latent_wm_loss``, and optimizes it with Adam. Parameters whose name
    contains ``'siamese_vit'`` get their own learning rate and weight decay.
    """

    def __init__(
        self,
        config: HydraDreamerConfig,
        lr: float,
        checkpoint_path: Union[str, None] = None,
        pdm_split=None,
        metrics=None,
        conditional=False
    ):
        """Create the agent and its network.

        :param config: agent configuration. NOTE: its ``trajectory_pdm_weight``
            attribute is overwritten in place before the network is built.
        :param lr: base learning rate used by ``get_optimizers``.
        :param checkpoint_path: checkpoint file read by ``initialize``.
        :param pdm_split: accepted for interface compatibility; not used here.
        :param metrics: stored as ``self.metrics``; not used in this class.
        :param conditional: build ``DreamerNetworkCondition`` instead of
            ``DreamerNetwork``.
        """
        super().__init__()
        # Mutates the caller's config object. Presumably the network/loss
        # read these weights from it -- TODO confirm the consumers.
        config.trajectory_pdm_weight = {
            'noc': 3.0,
            'da': 3.0,
            'ttc': 2.0,
            'progress': config.progress_weight,
            'comfort': 1.0,
        }
        self._config = config
        self._lr = lr
        self.metrics = metrics
        self._checkpoint_path = checkpoint_path
        self.vocab_size = config.vocab_size
        self.backbone_wd = config.backbone_wd
        self.conditional = conditional
        if conditional:
            self.dreamer_network = DreamerNetworkCondition(config)
        else:
            self.dreamer_network = DreamerNetwork(config)

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads ``self._checkpoint_path`` onto the CPU and restores this agent's
        weights from its ``"state_dict"`` entry. A leading ``"agent."`` prefix
        is removed from each key (presumably added by the training wrapper
        that holds this agent as ``agent`` -- verify against the trainer).

        :raises ValueError: if no ``checkpoint_path`` was given.
        """
        if self._checkpoint_path is None:
            raise ValueError(
                f"{self.name()}.initialize() requires a checkpoint_path, but none was given."
            )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self._checkpoint_path, map_location=torch.device("cpu"))
        state_dict: Dict[str, Any] = checkpoint["state_dict"]
        prefix = "agent."
        # Strip the prefix only at the start of each key. str.replace would also
        # rewrite inner occurrences (e.g. a submodule named "*_agent").
        self.load_state_dict(
            {(key[len(prefix):] if key.startswith(prefix) else key): value
             for key, value in state_dict.items()}
        )

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass.

        Enables all eight cameras (f0, l0-l2, r0-r2, b0). ``lidar_pc`` is an
        empty list, which presumably requests no lidar frames.
        """
        return SensorConfig(
            cam_f0=True,
            cam_l0=True,
            cam_l1=True,
            cam_l2=True,
            cam_r0=True,
            cam_r1=True,
            cam_r2=True,
            cam_b0=True,
            lidar_pc=[],
        )

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the single world-model target builder for this config."""
        return [HydraDreamerWmTargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the single world-model feature builder for this config."""
        return [HydraDreamerWmFeatureBuilder(config=self._config)]

    def _forward(self, features):
        """Run the underlying Dreamer network on ``features``."""
        return self.dreamer_network(features)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Inherited, see superclass. Delegates to the Dreamer network."""
        return self._forward(features)

    def forward_train(self, features, interpolated_traj):
        """Training forward pass; ``interpolated_traj`` is accepted but ignored."""
        return self._forward(features)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the latent world-model loss.

        ``features`` and ``tokens`` are accepted for interface compatibility but
        unused. The network's ``fixed_vit`` is passed through to the loss.
        """
        return latent_wm_loss(targets, predictions, self._config, self.dreamer_network.fixed_vit)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """Build an Adam optimizer with two parameter groups.

        Parameters whose name contains ``'siamese_vit'`` use
        ``lr * config.lr_mult_backbone`` and ``weight_decay=backbone_wd``. All
        other parameters use the base ``lr`` and Adam's default weight decay.
        """
        backbone_params_name = 'siamese_vit'
        default_params, img_backbone_params = [], []
        # Single pass over the parameters; order within each group matches
        # named_parameters() order.
        for param_name, param in self.dreamer_network.named_parameters():
            if backbone_params_name in param_name:
                img_backbone_params.append(param)
            else:
                default_params.append(param)
        params_lr_dict = [
            {'params': default_params},
            {
                'params': img_backbone_params,
                'lr': self._lr * self._config.lr_mult_backbone,
                'weight_decay': self.backbone_wd
            }
        ]
        return torch.optim.Adam(params_lr_dict, lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """Return a checkpoint callback keeping the 30 lowest ``val/loss_epoch`` models.

        Checkpoints are written under ``$NAVSIM_EXP_ROOT/<config.ckpt_path>/``.

        :raises EnvironmentError: if ``NAVSIM_EXP_ROOT`` is not set. Without this
            check the checkpoints would silently go to a directory named "None".
        """
        exp_root = os.environ.get('NAVSIM_EXP_ROOT')
        if exp_root is None:
            raise EnvironmentError(
                "NAVSIM_EXP_ROOT must be set to determine the checkpoint directory."
            )
        return [
            ModelCheckpoint(
                save_top_k=30,
                monitor="val/loss_epoch",
                mode="min",
                dirpath=f"{exp_root}/{self._config.ckpt_path}/",
                filename="{epoch:02d}-{step:04d}",
            )
        ]