"""Image classification evaluator."""

from __future__ import annotations

import itertools

import numpy as np

from vis4d.common.array import array_to_numpy
from vis4d.common.typing import (
    ArrayLike,
    GenericFunc,
    MetricLogs,
    NDArrayI64,
    NDArrayNumber,
)
from vis4d.eval.base import Evaluator

from ..metrics.cls import accuracy


class ClassificationEvaluator(Evaluator):
    """Multi-class classification evaluator."""

    METRIC_CLASSIFICATION = "Cls"

    KEY_ACCURACY = "Acc@1"
    KEY_ACCURACY_TOP5 = "Acc@5"

    def __init__(self) -> None:
        """Initialize the classification evaluator."""
        super().__init__()
        self._metrics_list: list[dict[str, float]] = []

    @property
    def metrics(self) -> list[str]:
        """Supported metrics."""
        return [
            self.KEY_ACCURACY,
            self.KEY_ACCURACY_TOP5,
        ]

    def reset(self) -> None:
        """Reset evaluator for new round of evaluation."""
        self._metrics_list = []

    def _is_correct(
        self, pred: NDArrayNumber, target: NDArrayI64, top_k: int = 1
    ) -> bool:
        """Check if the prediction is correct for top-k.

        Args:
            pred (NDArrayNumber): Prediction logits, in shape (C, ).
            target (NDArrayI64): Target class index, in shape (1, ).
            top_k (int, optional): Top-k to check. Defaults to 1.

        Returns:
            bool: Whether the prediction is correct.
        """
        top_k = min(top_k, pred.shape[0])
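        # ``np.argsort`` sorts ascending, so the last ``top_k`` indices are
        # the classes with the highest logits.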
        top_k_idx = np.argsort(pred)[-top_k:]
        return bool(np.any(top_k_idx == target))

    def process_batch(  # type: ignore # pylint: disable=arguments-differ
        self, prediction: ArrayLike, groundtruth: ArrayLike
    ) -> None:
        """Process a batch of predictions and groundtruths.

        Args:
            prediction (ArrayLike): Prediction, in shape (N, C).
            groundtruth (ArrayLike): Groundtruth, in shape (N, ).
        """
        pred = array_to_numpy(prediction, n_dims=None, dtype=np.float32)
        gt = array_to_numpy(groundtruth, n_dims=None, dtype=np.int64)
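        # Score each sample individually so the per-sample results can be
        # gathered across processes and averaged later in evaluate().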
        for i in range(pred.shape[0]):
            self._metrics_list.append(
                {
                    "top1_correct": accuracy(pred[i], gt[i], top_k=1),
                    "top5_correct": accuracy(pred[i], gt[i], top_k=5),
                }
            )

    def gather(self, gather_func: GenericFunc) -> None:
        """Accumulate predictions across processes."""
        all_metrics = gather_func(self._metrics_list)
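        # ``gather_func`` is expected to return a list of per-process metric
        # lists (or ``None`` on non-main ranks); flatten it into one list.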
        if all_metrics is not None:
            self._metrics_list = list(itertools.chain(*all_metrics))

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Evaluate predictions.

        Returns a dict with the raw metric values and a short,
        human-readable description string.

        Args:
            metric (str): Metric to use. See the ``metrics`` property for
                supported values.

        Returns:
            tuple[MetricLogs, str]: The metric data (a dict mapping each
                metric name to its value) and a short, human-readable
                summary string.

        Raises:
            RuntimeError: if no data has been registered to be evaluated.
            ValueError: if the metric is not supported.
        """
        if len(self._metrics_list) == 0:
            raise RuntimeError(
                "No data registered to calculate metric. "
                "Register data using .process_batch() first!"
            )
        metric_data: MetricLogs = {}
        short_description = ""

        if metric == self.METRIC_CLASSIFICATION:
            # Top1 accuracy
            top1_correct = np.array(
                [m["top1_correct"] for m in self._metrics_list]
            )
            top1_acc = float(np.mean(top1_correct))
            metric_data[self.KEY_ACCURACY] = top1_acc
            short_description += f"Top1 Accuracy: {top1_acc:.4f}\n"

            # Top5 accuracy
            top5_correct = np.array(
                [m["top5_correct"] for m in self._metrics_list]
            )
            top5_acc = float(np.mean(top5_correct))
            metric_data[self.KEY_ACCURACY_TOP5] = top5_acc
            short_description += f"Top5 Accuracy: {top5_acc:.4f}\n"

        else:
            raise ValueError(
                f"Unsupported metric: {metric}"
            )  # pragma: no cover

        return metric_data, short_description
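

# Minimal usage sketch (illustration only, not part of the evaluator): the
# random logits and labels below are made-up stand-ins; a real pipeline would
# feed model outputs and dataset labels instead.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(8, 10)).astype(np.float32)  # (N, C) predictions
    labels = rng.integers(0, 10, size=(8,))  # (N,) groundtruth class indices

    evaluator = ClassificationEvaluator()
    evaluator.process_batch(logits, labels)
    logs, summary = evaluator.evaluate(
        ClassificationEvaluator.METRIC_CLASSIFICATION
    )
    print(logs)  # {"Acc@1": ..., "Acc@5": ...}
    print(summary)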