File size: 3,901 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from abc import abstractmethod, ABC
from typing import Dict, Union, List

import pytorch_lightning as pl
import torch

from navsim.common.dataclasses import AgentInput, Trajectory, SensorConfig
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder


class AbstractAgent(torch.nn.Module, ABC):
    """
    Base class for driving agents.

    Subclasses must implement :meth:`name`, :meth:`get_sensor_config` and :meth:`initialize`.
    Agents that support training additionally override :meth:`forward`,
    :meth:`get_feature_builders`, :meth:`get_target_builders`, :meth:`compute_loss` and
    :meth:`get_optimizers`; the default implementations of those raise ``NotImplementedError``.
    """

    def __init__(
        self,
        requires_scene: bool = False,
    ):
        """
        :param requires_scene: flag stored as ``self.requires_scene``.
            NOTE(review): presumably signals to callers that this agent needs the full scene
            rather than only the agent input — its meaning is defined outside this file; confirm.
        """
        super().__init__()
        self.requires_scene = requires_scene

    @abstractmethod
    def name(self) -> str:
        """
        :return: string describing name of this agent.
        """
        pass

    @abstractmethod
    def get_sensor_config(self) -> SensorConfig:
        """
        :return: Dataclass defining the sensor configuration for lidar and cameras.
        """
        pass

    @abstractmethod
    def initialize(self) -> None:
        """
        Initialize agent (e.g. load weights or set up internal state).
        """
        pass

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the agent.

        :param features: Dictionary of (batched) feature tensors.
        :return: Dictionary of predictions; :meth:`compute_trajectory` reads the "trajectory" key.
        :raises NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """
        :return: List of feature builders.
        :raises NotImplementedError: unless overridden; agents without builders cannot be trained.
        """
        raise NotImplementedError("No feature builders. Agent does not support training.")

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """
        :return: List of target builders.
        :raises NotImplementedError: unless overridden; agents without builders cannot be trained.
        """
        raise NotImplementedError("No target builders. Agent does not support training.")

    def compute_trajectory(self, agent_input: AgentInput) -> Trajectory:
        """
        Computes the ego vehicle trajectory for a single (unbatched) agent input.

        Switches the module to eval mode (it is not switched back), builds features with every
        builder returned by :meth:`get_feature_builders`, adds a leading batch dimension of
        size 1, runs :meth:`forward` without gradient tracking and converts
        ``predictions["trajectory"]`` (with the batch dimension removed) into a Trajectory.

        :param agent_input: Dataclass with agent inputs.
        :return: Trajectory representing the predicted ego's position in future.
        """
        self.eval()
        features: Dict[str, torch.Tensor] = {}
        # build features; a later builder overwrites keys produced by an earlier one
        for builder in self.get_feature_builders():
            features.update(builder.compute_features(agent_input))

        # add batch dimension
        features = {k: v.unsqueeze(0) for k, v in features.items()}

        # forward pass
        with torch.no_grad():
            predictions = self.forward(features)
            # Tensor.numpy() rejects non-CPU tensors; .cpu() is a no-op when already on CPU,
            # so this also works for agents whose predictions live on an accelerator.
            poses = predictions["trajectory"].squeeze(0).cpu().numpy()

        # extract trajectory
        return Trajectory(poses)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Computes the loss used for backpropagation based on the features, targets and model predictions.

        :param features: Dictionary of feature tensors.
        :param targets: Dictionary of target tensors.
        :param predictions: Dictionary of model predictions.
        :param tokens: optional extra argument, unused by this base implementation.
            NOTE(review): presumably sample identifiers supplied by the training loop — confirm.
        :return: a single loss tensor or a dictionary of loss tensors.
        :raises NotImplementedError: unless overridden; the agent then does not support training.
        """
        raise NotImplementedError("No loss. Agent does not support training.")

    def get_optimizers(
        self,
    ) -> Union[
        torch.optim.Optimizer,
        Dict[str, Union[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]],
    ]:
        """
        Returns the optimizers that are used by the pytorch-lightning trainer.

        Has to be either a single optimizer or a dict of optimizer and lr scheduler.

        :raises NotImplementedError: unless overridden; the agent then does not support training.
        """
        raise NotImplementedError("No optimizers. Agent does not support training.")

    def get_training_callbacks(
        self,
    ) -> List[pl.Callback]:
        """
        Returns a list of pytorch-lightning callbacks that are used during training.

        See navsim.planning.training.callbacks for examples.

        :return: the default implementation returns an empty list.
        """
        return []