File size: 1,465 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""SHIFT depth estimation evaluator."""

from __future__ import annotations

from vis4d.common.typing import NDArrayNumber

from ..common import DepthEvaluator


def apply_crop(depth: NDArrayNumber) -> NDArrayNumber:
    """Restrict depth map(s) to the region scored by the SHIFT protocol.

    Only rows ``0`` through ``739`` of the second-to-last axis are kept; all
    leading (batch) axes and the last (width) axis are left untouched.

    Args:
        depth: Depth map(s) whose second-to-last axis is the image height.

    Returns:
        ``depth`` sliced to its first 740 rows (fewer if it is shorter).
    """
    kept_rows = slice(None, 740)
    return depth[..., kept_rows, :]


class SHIFTDepthEvaluator(DepthEvaluator):
    """SHIFT depth estimation evaluation class.

    Thin wrapper around ``DepthEvaluator`` that fixes SHIFT-specific depth
    limits and can optionally crop every prediction/groundtruth pair with
    ``apply_crop`` before handing it to the base evaluator.
    """

    def __init__(
        self,
        use_eval_crop: bool = True,
        *,
        min_depth: float = 0.01,
        max_depth: float = 80.0,
    ) -> None:
        """Initialize the evaluator.

        Args:
            use_eval_crop (bool): Whether to crop predictions and groundtruth
                with ``apply_crop`` before evaluation. Default: True.
            min_depth (float): Minimum depth forwarded to
                ``DepthEvaluator``. Default: 0.01.
            max_depth (float): Maximum depth forwarded to
                ``DepthEvaluator``. Default: 80.0.
        """
        super().__init__(min_depth=min_depth, max_depth=max_depth)
        self.use_eval_crop = use_eval_crop

    def __repr__(self) -> str:
        """Concise representation of the dataset evaluator."""
        return "SHIFT Depth Estimation Evaluator"

    def process_batch(  # type: ignore # pylint: disable=arguments-differ
        self, prediction: NDArrayNumber, groundtruth: NDArrayNumber
    ) -> None:
        """Forward a batch to ``DepthEvaluator.process_batch``.

        When ``use_eval_crop`` is set, both inputs are cropped with
        ``apply_crop`` first.

        Args:
            prediction: Predicted depth maps of shape (N, H, W).
            groundtruth: Groundtruth depth maps of shape (N, H, W).
        """
        if self.use_eval_crop:
            # Crop both maps identically so predicted and groundtruth pixels
            # stay aligned.
            prediction = apply_crop(prediction)
            groundtruth = apply_crop(groundtruth)
        super().process_batch(prediction, groundtruth)