|
import os
|
|
import pickle
|
|
from typing import Any, List, Dict, Union, Optional
|
|
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import LRScheduler
|
|
|
|
from navsim.agents.abstract_agent import AbstractAgent
|
|
from navsim.agents.transfuser.transfuser_callback import TransfuserCallback
|
|
from navsim.agents.vadv2.vadv2_features import (
|
|
Vadv2FeatureBuilder,
|
|
Vadv2TargetBuilder,
|
|
)
|
|
from navsim.agents.vadv2.vadv2_config import Vadv2Config
|
|
from navsim.agents.vadv2.vadv2_loss import vadv2_loss_ori, vadv2_loss_center, vadv2_loss_center_woper
|
|
from navsim.agents.vadv2.vadv2_model import Vadv2Model
|
|
from navsim.common.dataclasses import SensorConfig
|
|
from navsim.planning.training.abstract_feature_target_builder import (
|
|
AbstractFeatureBuilder,
|
|
AbstractTargetBuilder,
|
|
)
|
|
|
|
# Root directory of the NAVSIM devkit, read from the environment once at
# import time. The value is None when NAVSIM_DEVKIT_ROOT is not set.
DEVKIT_ROOT = os.environ.get('NAVSIM_DEVKIT_ROOT')
|
|
|
|
|
|
class Vadv2Agent(AbstractAgent):
    """Agent that wraps a :class:`Vadv2Model` for training and inference.

    The agent owns the model, loads a pickled per-token score table from
    ``{DEVKIT_ROOT}/vocab_score_local/{split}.pkl`` at construction time and
    selects one of the VADv2 loss functions based on ``config.type``.
    """

    def __init__(
        self,
        config: Vadv2Config,
        lr: float,
        checkpoint_path: Optional[str] = None,
        split=None,
        vocab_size=4096,
        closest=False,
        ori=False,
    ):
        """Build the model and load the score table for ``split``.

        :param config: model/loss configuration, also handed to the builders.
        :param lr: learning rate used by :meth:`get_optimizers`.
        :param checkpoint_path: checkpoint file read by :meth:`initialize`.
        :param split: name of the pickle file (without extension) to load
            from ``{DEVKIT_ROOT}/vocab_score_local/``.
        :param vocab_size: stored on the instance as ``self.vocab_size``.
        :param closest: accepted for interface compatibility; unused here.
        :param ori: accepted for interface compatibility; unused here.
        """
        super().__init__()

        self._config = config
        self._lr = lr

        self._checkpoint_path = checkpoint_path
        self.vadv2_model = Vadv2Model(config)

        # SECURITY NOTE: pickle can execute arbitrary code while loading;
        # only point this at score files produced by a trusted pipeline.
        # The context manager guarantees the file handle is closed.
        with open(f'{DEVKIT_ROOT}/vocab_score_local/{split}.pkl', 'rb') as score_file:
            self.vocab_pdm_score = pickle.load(score_file)

        # NOTE(review): compute_loss sizes its fallback row from
        # self._config.vocab_size, not from this attribute -- confirm that
        # the two are meant to agree.
        self.vocab_size = vocab_size

    def name(self) -> str:
        """Inherited, see superclass."""

        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads the ``"state_dict"`` entry of the checkpoint and strips every
        ``"agent."`` substring from its keys before loading it into this
        module.

        :raises ValueError: if no checkpoint path was given at construction.
        """
        if self._checkpoint_path is None:
            raise ValueError("Vadv2Agent.initialize() requires a checkpoint_path, but none was provided.")

        # Without CUDA the tensors must be remapped to the CPU; with CUDA the
        # default placement (map_location=None) is kept.
        map_location = None if torch.cuda.is_available() else torch.device("cpu")
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path, map_location=map_location)["state_dict"]
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_mm_sensors()

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the single VADv2 target builder configured with this agent's config."""
        return [Vadv2TargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the single VADv2 feature builder configured with this agent's config."""
        return [Vadv2FeatureBuilder(config=self._config)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on ``features`` and return its output."""
        return self.vadv2_model(features)

    def forward_train(self, features, interpolated_traj):
        """Run the wrapped model with ``interpolated_traj`` as a second positional input."""
        return self.vadv2_model(features, interpolated_traj)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the loss selected by ``config.type`` for one batch.

        :param features: unused by this implementation.
        :param targets: forwarded to the loss function.
        :param predictions: forwarded to the loss function.
        :param tokens: iterable of keys into the loaded score table, one per
            sample; tokens missing from the table get an all-zero row.
        :return: whatever the selected loss function returns.
        :raises TypeError: if ``tokens`` is not provided.
        :raises NotImplementedError: if ``config.type`` is not one of
            ``'ori'``, ``'center'`` or ``'center_woper'``.
        """
        if tokens is None:
            # The original code failed here with an opaque "NoneType is not
            # iterable"; keep the exception type but explain the cause.
            raise TypeError("compute_loss() requires `tokens` to look up the per-sample vocabulary scores.")

        # Fallback row for tokens that have no entry in the score table.
        dummy_score = np.zeros(self._config.vocab_size, dtype=np.float32)
        per_sample_scores = [self.vocab_pdm_score.get(token, dummy_score)[None] for token in tokens]
        curr_vocab_pdm_score = np.concatenate(per_sample_scores, axis=0)

        loss_functions = {
            'ori': vadv2_loss_ori,
            'center': vadv2_loss_center,
            'center_woper': vadv2_loss_center_woper,
        }
        loss_type = self._config.type
        loss_fn = loss_functions.get(loss_type)
        if loss_fn is None:
            raise NotImplementedError(f"Unsupported loss type: {loss_type!r}")
        return loss_fn(targets, predictions, self._config, curr_vocab_pdm_score)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.vadv2_model.parameters(), lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """Return the training callbacks: a TransfuserCallback and a checkpoint writer.

        Checkpoints are ranked by ``val/loss_epoch`` (lower is better), the
        best 15 are kept, and they are written below
        ``$NAVSIM_EXP_ROOT/<config.ckpt_path>/``.
        """
        checkpoint_callback = ModelCheckpoint(
            save_top_k=15,
            monitor="val/loss_epoch",
            mode="min",
            dirpath=f"{os.environ.get('NAVSIM_EXP_ROOT')}/{self._config.ckpt_path}/",
            filename="{epoch:02d}-{step:04d}",
        )
        return [TransfuserCallback(self._config), checkpoint_callback]
|
|
|