from __future__ import annotations

from typing import Any, List, Dict

import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import (
    AbstractFeatureBuilder,
    AbstractTargetBuilder,
)


class DetTargetBuilder(AbstractTargetBuilder):
    """Target builder producing per-frame 3D detection targets for a scene."""

    def __init__(self, pipelines):
        """
        :param pipelines: data pipelines; stored on the builder but not read by
            ``compute_targets``.
        """
        super().__init__()
        self.pipelines = pipelines

    def compute_targets(self, scene: Scene) -> Dict[str, torch.Tensor]:
        """Build detection targets for every frame in ``scene``.

        :param scene: scene whose ``frames`` each expose ``annotations`` with
            ``boxes``, ``names`` and ``velocity_3d`` (numpy arrays).
        :return: dict with
            - ``"dets"``: list with one tensor per frame, built by appending the
              first two velocity columns to each box row. Per the original
              author the resulting layout is ``[x, y, z, l, w, h, theta, vx, vy]``
              (assumes ``boxes`` has 7 columns -- TODO confirm).
            - ``"labels"``: list with the raw ``annotations.names`` of each
              frame; these are not converted to tensors.
        """
        dets = []
        labels = []
        for frame in scene.frames:
            annotations = frame.annotations
            # Only the first two velocity components (vx, vy) are kept.
            planar_velocity = annotations.velocity_3d[:, :2]
            dets.append(
                torch.from_numpy(
                    np.concatenate([annotations.boxes, planar_velocity], axis=-1)
                )
            )
            labels.append(annotations.names)
        return {"dets": dets, "labels": labels}


class DetAgent(AbstractAgent):
    """Detection agent wrapping ``model``.

    NOTE(review): ``forward`` is still a placeholder returning
    ``{"dets": None}``, so ``compute_loss`` cannot run until a real forward
    pass is implemented.
    """

    def __init__(
        self,
        model,
        pipelines,
        lr: float,
        checkpoint_path: str | None = None,
        **kwargs,
    ):
        """
        :param model: network whose parameters are optimized by ``get_optimizers``.
        :param pipelines: data pipelines handed to the feature and target builders.
        :param lr: learning rate for the Adam optimizer.
        :param checkpoint_path: checkpoint file read by ``initialize``; only
            required if ``initialize`` is called.
        :param kwargs: accepted and ignored (presumably so configs may pass
            extra keys -- TODO confirm).
        """
        super().__init__()
        # TODO: evaluate everything
        self.model = model
        self.pipelines = pipelines
        self._checkpoint_path = checkpoint_path
        self._lr = lr

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads ``["state_dict"]`` from the checkpoint, strips a leading
        ``"agent."`` from each key, and loads the result into this module.

        :raises ValueError: if no ``checkpoint_path`` was configured.
        """
        if self._checkpoint_path is None:
            raise ValueError(
                f"{self.name()} has no checkpoint_path; cannot initialize weights."
            )
        # Map to CPU so GPU-saved checkpoints also load on CPU-only hosts;
        # load_state_dict copies the values onto the parameters' own devices.
        checkpoint = torch.load(self._checkpoint_path, map_location="cpu")
        state_dict: Dict[str, Any] = checkpoint["state_dict"]
        prefix = "agent."
        # Strip only a *leading* prefix: str.replace would also rewrite any
        # "agent." appearing later in a key (e.g. inside "subagent.").
        self.load_state_dict(
            {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        )

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_all_sensors(True)

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the single detection target builder, sharing ``pipelines``."""
        return [
            DetTargetBuilder(self.pipelines),
        ]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the LiDAR/camera feature builder, sharing ``pipelines``."""
        return [
            LiDARCameraFeatureBuilder(self.pipelines),
        ]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Placeholder forward pass.

        NOTE(review): ``features`` and ``self.model`` are currently ignored and
        ``"dets"`` is ``None``; implement before training.
        """
        return {"dets": None}

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """L1 loss between predicted and target detections.

        NOTE(review): ``DetTargetBuilder`` produces ``targets["dets"]`` as a
        list of per-frame tensors and ``forward`` returns ``None`` predictions;
        ``l1_loss`` needs tensors, so this is not runnable as-is.
        """
        return torch.nn.functional.l1_loss(predictions["dets"], targets["dets"])

    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        """Inherited, see superclass.

        :return: Adam optimizer over ``self.model``'s parameters with the
            configured learning rate.
        """
        return torch.optim.Adam(self.model.parameters(), lr=self._lr)