|
from __future__ import annotations
|
|
|
|
from typing import Any, List, Dict
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import LRScheduler
|
|
|
|
from det_map.data.datasets.dataclasses import SensorConfig, Scene
|
|
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
|
|
from navsim.agents.abstract_agent import AbstractAgent
|
|
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
|
|
|
|
|
|
class DetTargetBuilder(AbstractTargetBuilder):
    """Build per-frame detection targets (boxes + velocity columns, class names) for a scene."""

    def __init__(self, pipelines):
        """
        :param pipelines: data pipelines shared with the agent. They are stored on the
            builder but not used by :meth:`compute_targets`.
        """
        super().__init__()
        self.pipelines = pipelines

    def get_unique_name(self) -> str:
        """Return the identifier of the targets produced by this builder.

        NOTE(review): navsim's target-builder interface declares this hook; confirm the
        name does not collide with other builders used in the same run.
        """
        return "det_target"

    def compute_targets(self, scene: Scene) -> Dict[str, List[Any]]:
        """Collect detection targets for every frame in ``scene.frames``.

        :param scene: scene whose frames expose ``annotations.boxes``,
            ``annotations.names`` and ``annotations.velocity_3d``.
        :return: dict with
            - ``"dets"``: one tensor per frame, built by concatenating each frame's
              ``boxes`` with the first two columns of its ``velocity_3d`` along the
              last axis;
            - ``"labels"``: each frame's ``annotations.names``, unconverted.
            Both values are Python lists (one entry per frame), not stacked tensors.
        """
        dets: List[torch.Tensor] = []
        labels: List[Any] = []
        for frame in scene.frames:
            annotations = frame.annotations
            # Keep only the first two velocity components (presumably vx, vy — confirm).
            velocity_xy = annotations.velocity_3d[:, :2]
            dets.append(
                torch.from_numpy(np.concatenate([annotations.boxes, velocity_xy], axis=-1))
            )
            labels.append(annotations.names)
        return {"dets": dets, "labels": labels}
|
|
|
|
|
|
class DetAgent(AbstractAgent):
    """Detection agent that wraps a model and shares one set of data pipelines.

    The same ``pipelines`` object is handed to both the feature builder
    (``LiDARCameraFeatureBuilder``) and the target builder (``DetTargetBuilder``).
    """

    def __init__(
        self,
        model,
        pipelines,
        lr: float,
        checkpoint_path: str | None = None,
        **kwargs,
    ):
        """
        :param model: network whose ``parameters()`` are optimized in
            :meth:`get_optimizers` (presumably a ``torch.nn.Module`` — confirm).
        :param pipelines: data pipelines forwarded to the feature and target builders.
        :param lr: learning rate for the Adam optimizer.
        :param checkpoint_path: checkpoint file read by :meth:`initialize`; it must contain
            a ``"state_dict"`` entry whose keys are prefixed with ``"agent."``.
        :param kwargs: accepted and ignored.
        """
        super().__init__()

        self.model = model
        self.pipelines = pipelines
        self._checkpoint_path = checkpoint_path
        self._lr = lr

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads ``checkpoint["state_dict"]`` from ``checkpoint_path`` into this agent after
        removing the leading ``"agent."`` prefix from every key.

        :raises ValueError: if no checkpoint path was configured.
        """
        if not self._checkpoint_path:
            raise ValueError(f"{self.name()} requires a checkpoint_path to initialize.")
        # Map to CPU so GPU-saved checkpoints also load on CPU-only hosts;
        # load_state_dict copies the values into the parameters wherever they live.
        checkpoint = torch.load(self._checkpoint_path, map_location=torch.device("cpu"))
        state_dict: Dict[str, Any] = checkpoint["state_dict"]
        prefix = "agent."
        # Strip only the leading prefix: str.replace would also mangle keys that
        # contain "agent." further in (e.g. "model.sub_agent.weight").
        self.load_state_dict(
            {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        )

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_all_sensors(True)

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the detection target builder, configured with this agent's pipelines."""
        return [
            DetTargetBuilder(self.pipelines),
        ]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the LiDAR/camera feature builder, configured with this agent's pipelines."""
        return [
            LiDARCameraFeatureBuilder(self.pipelines),
        ]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the agent on ``features``.

        NOTE(review): placeholder — it never calls ``self.model`` and always returns
        ``{"dets": None}``, so :meth:`compute_loss` cannot succeed on its output.
        """
        return {"dets": None}

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Return the L1 loss between predicted and target ``"dets"``.

        NOTE(review): ``DetTargetBuilder`` produces ``targets["dets"]`` as a list of
        per-frame tensors and :meth:`forward` predicts ``None``; both must become
        tensors of matching shape before this loss is usable.
        """
        return torch.nn.functional.l1_loss(predictions["dets"], targets["dets"])

    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self._lr)
|
|
|