File size: 3,510 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""NuScenes multi-sensor video dataset."""

from __future__ import annotations

import json

import numpy as np

from vis4d.common.typing import ArgsType, DictStrAny, NDArrayF32, NDArrayI64
from vis4d.data.typing import DictData

from .nuscenes import NuScenes, nuscenes_class_map


class NuScenesDetection(NuScenes):
    """NuScenes dataset that attaches pre-computed 3D detections.

    On top of the samples produced by ``NuScenes``, every sample gets the
    detections stored for its token in a JSON prediction file. They are
    written into the ``LIDAR_TOP`` entry of the sample under the keys
    ``pred_boxes3d``, ``pred_boxes3d_classes``, ``pred_boxes3d_scores`` and
    ``pred_boxes3d_velocities``.
    """

    def __init__(
        self,
        pure_detection: str,
        score_thres: float = 0.05,
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class.

        Args:
            pure_detection (str): Path to a JSON file with the detections.
                Its ``"results"`` entry is indexed by sample token and holds
                a list of prediction dicts per token.
            score_thres (float): Predictions whose score is less than or
                equal to this value are discarded. Defaults to 0.05.
            **kwargs (ArgsType): Forwarded unchanged to ``NuScenes``.
        """
        self.pure_detection = pure_detection
        self.score_thres = score_thres

        # NOTE(review): the attributes are set and the file is read before
        # the parent constructor runs; this ordering is kept as-is in case
        # the parent relies on them during construction -- confirm.
        with open(self.pure_detection, encoding="utf-8") as f:
            self.predictions = json.load(f)

        super().__init__(**kwargs)

    def __repr__(self) -> str:
        """Concise representation of the dataset."""
        return (
            f"NuScenesDetection {self.version} {self.split} using "
            + f"{self.pure_detection}"
        )

    def _load_pred(
        self, preds: list[DictStrAny]
    ) -> tuple[NDArrayF32, NDArrayI64, NDArrayF32, NDArrayF32]:
        """Load nuscenes format prediction.

        Predictions whose ``detection_name`` is not a key of
        ``nuscenes_class_map`` or whose ``detection_score`` does not exceed
        ``self.score_thres`` are skipped.

        Args:
            preds (list[DictStrAny]): Prediction dicts of one sample. Each
                one is read through the keys ``detection_name``,
                ``detection_score``, ``translation``, ``size``, ``rotation``
                and ``velocity``.

        Returns:
            tuple[NDArrayF32, NDArrayI64, NDArrayF32, NDArrayF32]: For the N
                kept predictions, boxes of shape (N, 10) built from
                translation, size and rotation in that order, class ids of
                shape (N,), scores of shape (N,) and velocities of shape
                (N, 3) where the last component is a padded zero.

        Raises:
            ValueError: If the collected rows do not match the expected
                widths of 10 (boxes) or 3 (velocities).
        """
        # Collect rows in plain lists and convert once at the end; growing
        # the arrays with np.concatenate per prediction copies everything
        # accumulated so far on each iteration.
        box_rows: list[list[float]] = []
        class_ids: list[int] = []
        scores: list[float] = []
        velocity_rows: list[list[float]] = []

        for pred in preds:
            name = pred["detection_name"]
            if name not in nuscenes_class_map:
                continue

            if float(pred["detection_score"]) <= self.score_thres:
                continue

            box_rows.append(
                [*pred["translation"], *pred["size"], *pred["rotation"]]
            )
            class_ids.append(nuscenes_class_map[name])
            scores.append(pred["detection_score"])
            velocity_rows.append([*pred["velocity"], 0])

        num_kept = len(box_rows)
        # The explicit target shapes keep the (0, 10) / (0, 3) layout when
        # nothing survives the filters and reject rows of the wrong width.
        boxes3d = np.array(box_rows, dtype=np.float32).reshape(num_kept, 10)
        boxes3d_classes = np.array(class_ids, dtype=np.int64)
        boxes3d_scores = np.array(scores, dtype=np.float32)
        boxes3d_velocities = np.array(
            velocity_rows, dtype=np.float32
        ).reshape(num_kept, 3)

        return boxes3d, boxes3d_classes, boxes3d_scores, boxes3d_velocities

    def __getitem__(self, idx: int) -> DictData:
        """Get single sample.

        Args:
            idx (int): Index of sample.

        Returns:
            DictData: sample at index in Vis4D input format, with the loaded
                predictions added to its ``LIDAR_TOP`` entry.

        Raises:
            KeyError: If the prediction file has no entry for the token of
                the sample.
        """
        data_dict = super().__getitem__(idx)

        # Predictions are stored per sample token.
        sample_preds = self.predictions["results"][data_dict["token"]]
        boxes3d, classes, scores, velocities = self._load_pred(sample_preds)

        lidar_data = data_dict["LIDAR_TOP"]
        lidar_data["pred_boxes3d"] = boxes3d
        lidar_data["pred_boxes3d_classes"] = classes
        lidar_data["pred_boxes3d_scores"] = scores
        lidar_data["pred_boxes3d_velocities"] = velocities

        return data_dict