File size: 3,359 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""BDD100K segmentation evaluator."""

from __future__ import annotations

import itertools
from collections.abc import Callable
from typing import Any

import numpy as np

from vis4d.common.array import array_to_numpy
from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE
from vis4d.common.typing import ArrayLike, MetricLogs
from vis4d.data.datasets.bdd100k import bdd100k_seg_map

from ..base import Evaluator

if SCALABEL_AVAILABLE and BDD100K_AVAILABLE:
    from bdd100k.common.utils import load_bdd100k_config
    from bdd100k.label.to_scalabel import bdd100k_to_scalabel
    from scalabel.eval.sem_seg import evaluate_sem_seg
    from scalabel.label.io import load
    from scalabel.label.transforms import mask_to_rle
    from scalabel.label.typing import Frame, Label
else:
    raise ImportError("scalabel or bdd100k is not installed.")


class BDD100KSegEvaluator(Evaluator):
    """BDD100K semantic segmentation evaluation class.

    Ground-truth annotations are loaded once at construction time and
    converted to Scalabel frames. Predictions are accumulated per image via
    ``process_batch`` as Scalabel frames holding one RLE-encoded label per
    class id present in the predicted mask. They are scored with Scalabel's
    ``evaluate_sem_seg`` in ``evaluate``.
    """

    # Maps integer class ids back to BDD100K category names by inverting
    # ``bdd100k_seg_map`` (category name -> class id).
    inverse_seg_map = {v: k for k, v in bdd100k_seg_map.items()}

    def __init__(self, annotation_path: str) -> None:
        """Initialize the evaluator and load the ground-truth annotations.

        Args:
            annotation_path (str): Path to the BDD100K annotations. It is
                read with Scalabel's ``load`` and converted to Scalabel
                format using the BDD100K ``"sem_seg"`` config.
        """
        super().__init__()
        self.annotation_path = annotation_path
        # Accumulated prediction frames; (re)initialized by ``reset``.
        self.frames: list[Frame] = []

        bdd100k_anns = load(annotation_path)
        frames = bdd100k_anns.frames
        self.config = load_bdd100k_config("sem_seg")
        self.gt_frames = bdd100k_to_scalabel(frames, self.config)

        self.reset()

    def __repr__(self) -> str:
        """Concise representation of the dataset evaluator."""
        return "BDD100K Segmentation Evaluator"

    @property
    def metrics(self) -> list[str]:
        """Supported metrics."""
        return ["sem_seg"]

    def gather(  # type: ignore # pragma: no cover
        self, gather_func: Callable[[Any], Any]
    ) -> None:
        """Gather variables in case of distributed setting (if needed).

        Args:
            gather_func (Callable[[Any], Any]): Gather function. It is called
                with this process's prediction frames. It is expected to
                return an iterable of per-process frame lists, or ``None``,
                in which case the local frames are kept unchanged.
        """
        all_preds = gather_func(self.frames)
        if all_preds is not None:
            # Flatten the per-process lists into a single list of frames.
            self.frames = list(itertools.chain.from_iterable(all_preds))

    def reset(self) -> None:
        """Reset the evaluator by discarding all accumulated predictions."""
        self.frames = []

    def process_batch(
        self, data_names: list[str], masks_list: list[ArrayLike]
    ) -> None:
        """Process a batch of semantic segmentation predictions.

        Each prediction mask is converted to a Scalabel ``Frame``. The frame
        holds one ``Label`` per distinct class id found in the mask, carrying
        the RLE encoding of that class's binary mask. The frame is appended
        to the accumulated predictions.

        Args:
            data_names (list[str]): Frame names, matched by position with
                ``masks_list`` (extra entries in the longer list are ignored).
            masks_list (list[ArrayLike]): Predicted masks, one per frame.
                NOTE(review): presumably (H, W) maps of integer class ids —
                confirm against the caller.

        Raises:
            KeyError: If a mask contains a class id that is not a value of
                ``bdd100k_seg_map``.
        """
        masks_numpy = [array_to_numpy(m, None) for m in masks_list]
        for data_name, masks in zip(data_names, masks_numpy):
            labels = [
                Label(
                    # Binary mask of this class, stored as RLE.
                    rle=mask_to_rle((masks == class_id).astype(np.uint8)),
                    category=self.inverse_seg_map[int(class_id)],
                    id=str(i),
                )
                for i, class_id in enumerate(np.unique(masks))
            ]
            self.frames.append(Frame(name=data_name, labels=labels))

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Evaluate the accumulated predictions against the ground truth.

        Args:
            metric (str): Metric to compute; only ``"sem_seg"`` is supported.

        Returns:
            tuple[MetricLogs, str]: An empty metric-log dict and the string
                representation of Scalabel's semantic segmentation result.

        Raises:
            NotImplementedError: If ``metric`` is not supported.
        """
        if metric != "sem_seg":
            raise NotImplementedError(
                f"Metric {metric} is not supported. "
                f"Supported metrics: {self.metrics}"
            )

        results = evaluate_sem_seg(
            ann_frames=self.gt_frames,
            pred_frames=self.frames,
            config=self.config.scalabel,
            nproc=0,
        )
        # NOTE(review): numeric metric logs are not extracted from
        # ``results``; only its string form is reported.
        return {}, str(results)