|
from abc import abstractmethod, ABC
|
|
from typing import Dict, Union, List
|
|
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
|
|
from navsim.common.dataclasses import AgentInput, Trajectory, SensorConfig
|
|
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
|
|
|
|
|
|
class AbstractAgent(torch.nn.Module, ABC):
|
|
def __init__(
|
|
self,
|
|
requires_scene: bool = False,
|
|
):
|
|
super().__init__()
|
|
self.requires_scene = requires_scene
|
|
|
|
@abstractmethod
|
|
def name(self) -> str:
|
|
"""
|
|
:return: string describing name of this agent.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_sensor_config(self) -> SensorConfig:
|
|
"""
|
|
:return: Dataclass defining the sensor configuration for lidar and cameras.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def initialize(self) -> None:
|
|
"""
|
|
Initialize agent
|
|
:param initialization: Initialization class.
|
|
"""
|
|
pass
|
|
|
|
def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
"""
|
|
Forward pass of the agent.
|
|
:param features: Dictionary of features.
|
|
:return: Dictionary of predictions.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
|
|
"""
|
|
:return: List of target builders.
|
|
"""
|
|
raise NotImplementedError("No feature builders. Agent does not support training.")
|
|
|
|
def get_target_builders(self) -> List[AbstractTargetBuilder]:
|
|
"""
|
|
:return: List of feature builders.
|
|
"""
|
|
raise NotImplementedError("No target builders. Agent does not support training.")
|
|
|
|
def compute_trajectory(self, agent_input: AgentInput) -> Trajectory:
|
|
"""
|
|
Computes the ego vehicle trajectory.
|
|
:param current_input: Dataclass with agent inputs.
|
|
:return: Trajectory representing the predicted ego's position in future
|
|
"""
|
|
self.eval()
|
|
features: Dict[str, torch.Tensor] = {}
|
|
|
|
for builder in self.get_feature_builders():
|
|
features.update(builder.compute_features(agent_input))
|
|
|
|
|
|
features = {k: v.unsqueeze(0) for k, v in features.items()}
|
|
|
|
|
|
with torch.no_grad():
|
|
predictions = self.forward(features)
|
|
poses = predictions["trajectory"].squeeze(0).numpy()
|
|
|
|
|
|
return Trajectory(poses)
|
|
|
|
def compute_loss(
|
|
self,
|
|
features: Dict[str, torch.Tensor],
|
|
targets: Dict[str, torch.Tensor],
|
|
predictions: Dict[str, torch.Tensor],
|
|
tokens=None
|
|
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
|
"""
|
|
Computes the loss used for backpropagation based on the features, targets and model predictions.
|
|
"""
|
|
raise NotImplementedError("No loss. Agent does not support training.")
|
|
|
|
def get_optimizers(
|
|
self
|
|
) -> Union[
|
|
torch.optim.Optimizer,
|
|
Dict[str, Union[
|
|
torch.optim.Optimizer,
|
|
torch.optim.lr_scheduler.LRScheduler]
|
|
]
|
|
]:
|
|
"""
|
|
Returns the optimizers that are used by thy pytorch-lightning trainer.
|
|
Has to be either a single optimizer or a dict of optimizer and lr scheduler.
|
|
"""
|
|
raise NotImplementedError("No optimizers. Agent does not support training.")
|
|
|
|
def get_training_callbacks(
|
|
self
|
|
) -> List[pl.Callback]:
|
|
"""
|
|
Returns a list of pytorch-lightning callbacks that are used during training.
|
|
See navsim.planning.training.callbacks for examples.
|
|
"""
|
|
return []
|
|
|