Spaces:
Running
on
Zero
Running
on
Zero
"""BDD100K tracking evaluator.""" | |
from __future__ import annotations | |
from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE | |
from vis4d.common.typing import MetricLogs | |
from vis4d.data.datasets.bdd100k import bdd100k_track_map | |
from ..scalabel.track import ScalabelTrackEvaluator | |
if SCALABEL_AVAILABLE and BDD100K_AVAILABLE: | |
from bdd100k.common.utils import load_bdd100k_config | |
from bdd100k.label.to_scalabel import bdd100k_to_scalabel | |
from scalabel.eval.detect import evaluate_det | |
from scalabel.eval.mot import acc_single_video_mot, evaluate_track | |
from scalabel.label.io import group_and_sort | |
else: | |
raise ImportError("scalabel or bdd100k is not installed.") | |
class BDD100KTrackEvaluator(ScalabelTrackEvaluator):
    """Evaluator for BDD100K 2D multi-object tracking.

    Builds on the Scalabel tracking evaluator, adding BDD100K config
    loading and ground-truth conversion. It exposes a detection metric
    (``METRICS_DET``) next to the tracking metric (``METRICS_TRACK``).
    """

    METRICS_DET = "Det"
    METRICS_TRACK = "Track"

    def __init__(
        self,
        annotation_path: str,
        config_path: str = "box_track",
        mask_threshold: float = 0.0,
    ) -> None:
        """Create the evaluator.

        Args:
            annotation_path: Ground-truth annotation location, forwarded
                unchanged to the Scalabel base evaluator.
            config_path: Config identifier handed to
                ``load_bdd100k_config``. Defaults to ``"box_track"``.
            mask_threshold: Forwarded unchanged to the base evaluator.
                Defaults to ``0.0``.
        """
        bdd100k_config = load_bdd100k_config(config_path)
        # The base class only receives the Scalabel part of the config.
        super().__init__(
            annotation_path=annotation_path,
            config=bdd100k_config.scalabel,
            mask_threshold=mask_threshold,
        )
        # NOTE(review): self.gt_frames is assumed to be populated by the base
        # constructor. It is passed through bdd100k_to_scalabel together with
        # the *full* BDD100K config, not just its Scalabel part.
        self.gt_frames = bdd100k_to_scalabel(self.gt_frames, bdd100k_config)
        # Reverse lookup of bdd100k_track_map (value -> key); presumably
        # class index -> category name — confirm against the dataset module.
        self.inverse_cat_map = dict(
            zip(bdd100k_track_map.values(), bdd100k_track_map.keys())
        )

    def __repr__(self) -> str:
        """Return a short, human-readable name for this evaluator."""
        return "BDD100K Tracking Evaluator"

    def metrics(self) -> list[str]:
        """List the metric names accepted by ``evaluate``.

        Returns:
            A fresh list with the detection metric name first and the
            tracking metric name second.
        """
        return list((self.METRICS_DET, self.METRICS_TRACK))

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Score the collected predictions against the ground truth.

        Args:
            metric: ``METRICS_DET`` runs ``evaluate_det`` on the frame lists
                with ``nproc=0``. ``METRICS_TRACK`` runs ``evaluate_track``
                with ``acc_single_video_mot`` on frames grouped by
                ``group_and_sort``, with ``nproc=1``. Any other value runs
                no evaluation.

        Returns:
            A pair ``(metrics_log, short_description)``:
            ``metrics_log`` holds the ``summary()`` entries of the result
            that was produced. ``short_description`` is that result's
            ``str()`` followed by a newline. An unrecognised metric yields
            an empty dict and an empty string.

        Raises:
            AssertionError: If ``self.config`` is ``None``.
        """
        assert self.config is not None, "BDD100K config is not loaded."
        metrics_log: MetricLogs = {}
        report_chunks: list[str] = []

        if metric == self.METRICS_DET:
            det_outcome = evaluate_det(
                self.gt_frames,
                self.frames,
                config=self.config,
                nproc=0,
            )
            # Merge the summary pairs; a repeated key keeps its last value.
            metrics_log.update(det_outcome.summary().items())
            report_chunks.append(str(det_outcome))

        if metric == self.METRICS_TRACK:
            # Tracking is scored on per-video sequences, so both sides are
            # grouped first (ground truth before predictions).
            gt_videos = group_and_sort(self.gt_frames)
            pred_videos = group_and_sort(self.frames)
            track_outcome = evaluate_track(
                acc_single_video_mot,
                gts=gt_videos,
                results=pred_videos,
                config=self.config,
                nproc=1,
            )
            metrics_log.update(track_outcome.summary().items())
            report_chunks.append(str(track_outcome))

        short_description = "".join(f"{chunk}\n" for chunk in report_chunks)
        return metrics_log, short_description