File size: 3,098 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Dict

import cv2
import numpy as np
import torch
from torchvision import transforms

from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.common.dataclasses import AgentInput, Scene
from navsim.common.dataclasses import Cameras
from navsim.planning.training.abstract_feature_target_builder import (
    AbstractFeatureBuilder,
    AbstractTargetBuilder,
)


def _crop_margins(image: np.ndarray, rows: int, cols: int) -> np.ndarray:
    """Drop ``rows`` rows from top and bottom and ``cols`` columns from left and right of ``image``."""
    # Explicit end indices (instead of ``-rows`` / ``-cols``) make a margin of 0
    # mean "no crop"; ``image[0:-0]`` would otherwise produce an empty array.
    height, width = image.shape[:2]
    return image[rows:height - rows, cols:width - cols]


def cat_flr_imgs(
    camera: Cameras,
    config: HydraDreamerConfig,
    vertical_crop: int = 28,
    side_crop: int = 416,
):
    """
    Stitch the left, front and right views of one camera frame into a single image tensor.

    Every view loses ``vertical_crop`` rows at the top and at the bottom. The left and
    right views additionally lose ``side_crop`` columns on both edges; the front view
    keeps its full width. The three crops are concatenated side by side
    (left | front | right), resized to ``config.camera_width`` x ``config.camera_height``
    and converted with ``torchvision.transforms.ToTensor``.

    :param camera: camera bundle of a single frame; ``cam_l0``, ``cam_f0`` and ``cam_r0``
        images are used (assumed to be H x W x C arrays of equal height — TODO confirm)
    :param config: provides ``camera_width`` and ``camera_height`` of the output image
    :param vertical_crop: rows removed from both the top and the bottom of every view
    :param side_crop: columns removed from both edges of the left and right views
    :return: the resized panorama as produced by ``ToTensor`` (C x H x W)
    :raises ValueError: if a crop margin is negative
    """
    if vertical_crop < 0 or side_crop < 0:
        raise ValueError(
            f"crop margins must be non-negative, got vertical_crop={vertical_crop}, side_crop={side_crop}"
        )

    left = _crop_margins(camera.cam_l0.image, vertical_crop, side_crop)
    front = _crop_margins(camera.cam_f0.image, vertical_crop, 0)
    right = _crop_margins(camera.cam_r0.image, vertical_crop, side_crop)

    # Join along the column axis; this requires the three crops to share a height.
    stitched_image = np.concatenate([left, front, right], axis=1)
    # cv2.resize expects dsize as (width, height).
    resized_image = cv2.resize(stitched_image, (config.camera_width, config.camera_height))
    return transforms.ToTensor()(resized_image)


class HydraDreamerWmFeatureBuilder(AbstractFeatureBuilder):
    """
    Feature builder for the Hydra-Dreamer world-model inputs.

    Produces a flattened ego-status vector built from the most recent
    ``config.num_ego_status`` ego statuses, plus three stitched
    left/front/right images taken from the first three camera frames.
    """

    def __init__(self, config: HydraDreamerConfig):
        super().__init__()
        # Supplies num_ego_status and the output image size used by cat_flr_imgs.
        self._config = config

    def get_unique_name(self) -> str:
        """Inherited, see superclass."""
        return "hydra_dreamer_wm_feature"

    def _get_camera_feature(self, agent_input: AgentInput):
        """
        Stitch the left/front/right views of the first three camera frames.
        :param agent_input: input dataclass
        :return: list of stitched image tensors, one per frame, in the same order
            as ``agent_input.cameras[:3]``
        """
        return [cat_flr_imgs(frame, self._config) for frame in agent_input.cameras[:3]]

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
        """Inherited, see superclass."""
        history = agent_input.ego_statuses
        pieces = []
        # Walk backwards from the newest entry: history[-1], history[-2], ...
        for offset in range(1, self._config.num_ego_status + 1):
            status = history[-offset]
            for value in (status.driving_command, status.ego_velocity, status.ego_acceleration):
                pieces.append(torch.tensor(value, dtype=torch.float32))

        features = {"status_feature": torch.concatenate(pieces)}

        frames = self._get_camera_feature(agent_input)
        # Reverse numbering: frame index 0 -> img_3, index 1 -> img_2, index 2 -> img_1.
        # NOTE(review): presumably cameras are ordered oldest -> newest; verify.
        features["img_3"] = frames[0]
        features["img_2"] = frames[1]
        features["img_1"] = frames[2]
        # TODO: perspective boxes, map and camera features are not built yet.

        return features


class HydraDreamerWmTargetBuilder(AbstractTargetBuilder):
    """Target builder providing the ground-truth stitched image for the world model."""

    def __init__(self, config: HydraDreamerConfig):
        super().__init__()
        # Supplies the output image size used by cat_flr_imgs.
        self._config = config

    def get_unique_name(self) -> str:
        """Inherited, see superclass."""
        return "hydra_dreamer_wm_target"

    def compute_targets(self, scene: Scene) -> Dict[str, torch.Tensor]:
        """
        Build the world-model target from the scene.
        :param scene: scene dataclass
        :return: dict with ``img_gt``, the stitched left/front/right image of the
            last camera frame in ``scene.get_agent_input().cameras``
        """
        last_frame_cameras = scene.get_agent_input().cameras[-1]
        target_image = cat_flr_imgs(last_frame_cameras, self._config)
        return {'img_gt': target_image}