File size: 3,901 Bytes
da2e2ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from abc import abstractmethod, ABC
from typing import Dict, Union, List
import pytorch_lightning as pl
import torch
from navsim.common.dataclasses import AgentInput, Trajectory, SensorConfig
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
class AbstractAgent(torch.nn.Module, ABC):
    """
    Interface for driving agents.

    Every agent must implement :meth:`name`, :meth:`get_sensor_config` and :meth:`initialize`.
    Agents that support training additionally override :meth:`forward`,
    :meth:`get_feature_builders`, :meth:`get_target_builders`, :meth:`compute_loss` and
    :meth:`get_optimizers`; the defaults here raise ``NotImplementedError``.
    The default :meth:`compute_trajectory` relies on the feature builders and :meth:`forward`,
    so agents that do not provide those must override :meth:`compute_trajectory` themselves.
    """

    def __init__(
        self,
        requires_scene: bool = False,
    ):
        """
        :param requires_scene: flag stored as ``self.requires_scene``; presumably tells the caller
            whether the agent needs the full scene in addition to the agent input
            (NOTE(review): confirm against the callers).
        """
        super().__init__()
        self.requires_scene = requires_scene

    @abstractmethod
    def name(self) -> str:
        """
        :return: string describing name of this agent.
        """
        pass

    @abstractmethod
    def get_sensor_config(self) -> SensorConfig:
        """
        :return: Dataclass defining the sensor configuration for lidar and cameras.
        """
        pass

    @abstractmethod
    def initialize(self) -> None:
        """
        Initialize the agent (e.g. load weights or other resources) before it is used.
        Takes no arguments.
        """
        pass

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the agent.
        :param features: Dictionary of features.
        :return: Dictionary of predictions. :meth:`compute_trajectory` reads the "trajectory" key.
        :raises NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """
        :return: List of feature builders.
        :raises NotImplementedError: unless overridden; the agent then does not support training.
        """
        raise NotImplementedError("No feature builders. Agent does not support training.")

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """
        :return: List of target builders.
        :raises NotImplementedError: unless overridden; the agent then does not support training.
        """
        raise NotImplementedError("No target builders. Agent does not support training.")

    def compute_trajectory(self, agent_input: AgentInput) -> Trajectory:
        """
        Computes the ego vehicle trajectory.

        Switches the module to eval mode, builds the features with every feature builder,
        adds a batch dimension, runs :meth:`forward` without gradient tracking and wraps the
        "trajectory" prediction (batch dimension removed) in a :class:`Trajectory`.

        :param agent_input: Dataclass with agent inputs.
        :return: Trajectory representing the predicted ego's position in future
        """
        # Side effect: leaves the module in eval mode after the call.
        self.eval()
        features: Dict[str, torch.Tensor] = {}
        # build features; a key produced by a later builder overwrites an earlier one
        for builder in self.get_feature_builders():
            features.update(builder.compute_features(agent_input))
        # add batch dimension (assumes every feature value is a tensor)
        features = {k: v.unsqueeze(0) for k, v in features.items()}
        # forward pass
        with torch.no_grad():
            predictions = self.forward(features)
        # Tensor.numpy() only works for tensors in CPU memory; .cpu() makes this work for
        # agents running on a GPU and is a no-op for CPU tensors.
        poses = predictions["trajectory"].squeeze(0).cpu().numpy()
        # extract trajectory
        return Trajectory(poses)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Computes the loss used for backpropagation based on the features, targets and model predictions.
        :param features: Dictionary of features.
        :param targets: Dictionary of targets.
        :param predictions: Dictionary of predictions returned by :meth:`forward`.
        :param tokens: optional extra argument; unused by this base class
            (NOTE(review): presumably sample identifiers — confirm against callers).
        :return: a single loss tensor or a dictionary of loss tensors.
        :raises NotImplementedError: unless overridden; the agent then does not support training.
        """
        raise NotImplementedError("No loss. Agent does not support training.")

    def get_optimizers(
        self
    ) -> Union[
        torch.optim.Optimizer,
        Dict[str, Union[
            torch.optim.Optimizer,
            torch.optim.lr_scheduler.LRScheduler]
        ]
    ]:
        """
        Returns the optimizers that are used by the pytorch-lightning trainer.
        Has to be either a single optimizer or a dict of optimizer and lr scheduler.
        :raises NotImplementedError: unless overridden; the agent then does not support training.
        """
        raise NotImplementedError("No optimizers. Agent does not support training.")

    def get_training_callbacks(
        self
    ) -> List[pl.Callback]:
        """
        Returns a list of pytorch-lightning callbacks that are used during training.
        See navsim.planning.training.callbacks for examples.
        :return: an empty list unless overridden.
        """
        return []
|