Spaces:
Running
on
Zero
Running
on
Zero
"""SHIFT result writer.""" | |
from __future__ import annotations | |
import io | |
import itertools | |
import json | |
import os | |
from collections import defaultdict | |
import numpy as np | |
from PIL import Image | |
from vis4d.common.array import array_to_numpy | |
from vis4d.common.imports import SCALABEL_AVAILABLE | |
from vis4d.common.typing import ( | |
ArrayLike, | |
GenericFunc, | |
MetricLogs, | |
NDArrayNumber, | |
) | |
from vis4d.data.datasets.shift import shift_det_map | |
from vis4d.data.io import DataBackend, ZipBackend | |
from vis4d.eval.base import Evaluator | |
# scalabel is a hard requirement of this module: abort the import right away
# when it is missing, otherwise pull in the label helpers used by the writer.
if not SCALABEL_AVAILABLE:
    raise ImportError("scalabel is not installed.")
from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d
from scalabel.label.typing import Dataset, Frame, Label
class SHIFTMultitaskWriter(Evaluator):
    """SHIFT result writer for online evaluation.

    Predictions are consumed batch by batch through ``process_batch``.
    Semantic masks and depth maps are encoded as PNG images and handed to
    the zip backend immediately, while 2D detection frames are accumulated
    in memory and serialized to ``det_2d.json`` by ``save``.
    """

    # Inverse of shift_det_map; looked up by integer class id to obtain the
    # category string stored in each label.
    inverse_cat_map = {v: k for k, v in shift_det_map.items()}

    def __init__(
        self,
        output_dir: str,
        submission_file: str = "submission.zip",
    ) -> None:
        """Creates a new writer.

        Args:
            output_dir (str): Output directory.
            submission_file (str): Submission file name. Defaults to
                "submission.zip".
        """
        super().__init__()
        assert submission_file.endswith(
            ".zip"
        ), "Submission file must be a zip file."
        self.backend: DataBackend = ZipBackend()
        self.output_path = os.path.join(output_dir, submission_file)
        self.frames_det_2d: list[Frame] = []
        self.frames_det_3d: list[Frame] = []
        # Number of processed samples per task, keyed by task name.
        self.sample_counts: defaultdict[str, int] = defaultdict(int)

    def _write_png(
        self,
        image_array: NDArrayNumber,
        task: str,
        sample_name: str,
        video_name: str,
    ) -> None:
        """Encode a 2D array as 8-bit grayscale PNG and store it.

        Args:
            image_array (NDArrayNumber): Image data, shape (H, W). It is cast
                to uint8 before encoding.
            task (str): Sub-directory inside the submission file, e.g.
                "semseg" or "depth".
            sample_name (str): File name of the sample inside the archive.
            video_name (str): Video name, used as directory of the sample.
        """
        image = Image.fromarray(image_array.astype("uint8"), mode="L")
        image_bytes = io.BytesIO()
        image.save(image_bytes, format="PNG")
        self.backend.set(
            f"{self.output_path}/{task}/{video_name}/{sample_name}",
            image_bytes.getvalue(),
            mode="w",
        )

    def _write_sem_mask(
        self, sem_mask: NDArrayNumber, sample_name: str, video_name: str
    ) -> None:
        """Write semantic mask.

        Args:
            sem_mask (NDArrayNumber): Predicted semantic mask, shape (H, W).
            sample_name (str): Sample name.
            video_name (str): Video name.
        """
        self._write_png(sem_mask, "semseg", sample_name, video_name)

    def _write_depth(
        self, depth_map: NDArrayNumber, sample_name: str, video_name: str
    ) -> None:
        """Write depth map.

        The depth values are rescaled so that 80.0 maps to 255 and clipped
        to [0, 255] before being stored as an 8-bit image; larger depths
        saturate at 255.

        Args:
            depth_map (NDArrayNumber): Predicted depth map, shape (H, W).
            sample_name (str): Sample name.
            video_name (str): Video name.
        """
        depth_map = np.clip(depth_map / 80.0 * 255.0, 0, 255)
        self._write_png(depth_map, "depth", sample_name, video_name)

    def _write_flow(
        self, flow: NDArrayNumber, sample_name: str, video_name: str
    ) -> None:
        """Write optical flow.

        Args:
            flow (NDArrayNumber): Predicted optical flow, shape (H, W, 2).
            sample_name (str): Sample name.
            video_name (str): Video name.

        Raises:
            NotImplementedError: Always, writing optical flow is not
                implemented.
        """
        raise NotImplementedError

    def _make_det2d_labels(
        self,
        boxes: ArrayLike,
        scores: ArrayLike,
        class_ids: ArrayLike,
        instance_masks: ArrayLike | None = None,
        track_ids: ArrayLike | None = None,
    ) -> list[Label]:
        """Convert the 2D detections of one frame into scalabel labels.

        Args:
            boxes (ArrayLike): Boxes in xyxy format, shape (N, 4).
            scores (ArrayLike): Box scores, shape (N,).
            class_ids (ArrayLike): Box class ids, shape (N,).
            instance_masks (ArrayLike, optional): Instance masks, shape
                (N, H, W), aligned with the boxes. Defaults to None.
            track_ids (ArrayLike, optional): Track ids, shape (N,), aligned
                with the boxes. Defaults to None.

        Returns:
            list[Label]: One label per detection.
        """
        boxes_np = array_to_numpy(boxes, n_dims=None, dtype=np.float32)
        scores_np = array_to_numpy(scores, n_dims=None, dtype=np.float32)
        class_ids_np = array_to_numpy(class_ids, n_dims=None, dtype=np.int64)
        masks_np = (
            None
            if instance_masks is None
            else array_to_numpy(instance_masks, n_dims=None, dtype=np.float32)
        )
        track_ids_np = (
            None
            if track_ids is None
            else array_to_numpy(track_ids, n_dims=None, dtype=np.int64)
        )

        labels = []
        for idx, (box, score, class_id) in enumerate(
            zip(boxes_np, scores_np, class_ids_np)
        ):
            # Masks and track ids are per detection, so they are selected
            # with the detection index (not the class id or a fixed index).
            if masks_np is not None:
                rle = mask_to_rle((masks_np[idx] > 0.0).astype(np.uint8))
            else:
                rle = None
            # NOTE(review): scalabel may require a non-null label id; confirm
            # that None is accepted when no track ids are provided.
            if track_ids_np is not None:
                track_id = str(int(track_ids_np[idx]))
            else:
                track_id = None
            if self.inverse_cat_map:
                category = self.inverse_cat_map[int(class_id)]
            else:
                category = str(int(class_id))
            labels.append(
                Label(
                    box2d=xyxy_to_box2d(*box.tolist()),
                    category=category,
                    score=float(score),
                    rle=rle,
                    id=track_id,
                )
            )
        return labels

    def process_batch(
        self,
        frame_ids: list[int],
        sample_names: list[str],
        sequence_names: list[str],
        pred_sem_mask: list[ArrayLike] | None = None,
        pred_depth: list[ArrayLike] | None = None,
        pred_flow: list[ArrayLike] | None = None,
        pred_boxes2d: list[ArrayLike] | None = None,
        pred_boxes2d_classes: list[ArrayLike] | None = None,
        pred_boxes2d_scores: list[ArrayLike] | None = None,
        pred_boxes2d_track_ids: list[ArrayLike] | None = None,
        pred_instance_masks: list[ArrayLike] | None = None,
    ) -> None:
        """Process SHIFT results.

        You can omit some of the predictions if they are not used.

        Args:
            frame_ids (list[int]): Frame IDs.
            sample_names (list[str]): Sample names.
            sequence_names (list[str]): Sequence names.
            pred_sem_mask (list[ArrayLike], optional): Predicted semantic
                masks, each in shape (C, H, W) or (H, W). Defaults to None.
            pred_depth (list[ArrayLike], optional): Predicted depth maps,
                each in shape (H, W), with meter unit. Defaults to None.
            pred_flow (list[ArrayLike], optional): Predicted optical flows,
                each in shape (H, W, 2). Defaults to None.
            pred_boxes2d (list[ArrayLike], optional): Predicted 2D boxes,
                each in shape (N, 4). Defaults to None.
            pred_boxes2d_classes (list[ArrayLike], optional): Predicted
                2D box classes, each in shape (N,). Defaults to None.
            pred_boxes2d_scores (list[ArrayLike], optional): Predicted
                2D box scores, each in shape (N,). Defaults to None.
            pred_boxes2d_track_ids (list[ArrayLike], optional): Predicted
                2D box track IDs, each in shape (N,). Defaults to None.
            pred_instance_masks (list[ArrayLike], optional): Predicted
                instance masks, each in shape (N, H, W). Defaults to None.
        """
        # 2D detections are only written when boxes, classes and scores are
        # all given; masks and track ids are optional extras.
        has_det_2d = (
            pred_boxes2d is not None
            and pred_boxes2d_classes is not None
            and pred_boxes2d_scores is not None
        )

        for i, (frame_id, sample_name, sequence_name) in enumerate(
            zip(frame_ids, sample_names, sequence_names)
        ):
            if pred_sem_mask is not None:
                sem_mask_ = array_to_numpy(
                    pred_sem_mask[i],
                    n_dims=None,
                    dtype=np.float32,
                )
                # A 3D input holds per-class scores: reduce it to class ids.
                if len(sem_mask_.shape) == 3:
                    sem_mask = sem_mask_.argmax(axis=0)
                else:
                    sem_mask = sem_mask_.astype(np.uint8)
                semseg_filename = sample_name.replace(".jpg", ".png").replace(
                    "img", "semseg"
                )
                self._write_sem_mask(sem_mask, semseg_filename, sequence_name)
                self.sample_counts["semseg"] += 1

            if pred_depth is not None:
                depth = array_to_numpy(
                    pred_depth[i], n_dims=None, dtype=np.float32
                )
                depth_filename = sample_name.replace(".jpg", ".png").replace(
                    "img", "depth"
                )
                self._write_depth(depth, depth_filename, sequence_name)
                self.sample_counts["depth"] += 1

            if pred_flow is not None:
                flow = array_to_numpy(
                    pred_flow[i], n_dims=None, dtype=np.float32
                )
                self._write_flow(flow, sample_name, sequence_name)
                self.sample_counts["flow"] += 1

            if has_det_2d:
                labels = self._make_det2d_labels(
                    pred_boxes2d[i],
                    pred_boxes2d_scores[i],
                    pred_boxes2d_classes[i],
                    pred_instance_masks[i] if pred_instance_masks else None,
                    (
                        pred_boxes2d_track_ids[i]
                        if pred_boxes2d_track_ids
                        else None
                    ),
                )
                frame = Frame(
                    name=sample_name,
                    videoName=sequence_name,
                    frameIndex=frame_id,
                    labels=labels,
                )
                self.frames_det_2d.append(frame)
                self.sample_counts["det_2d"] += 1

    def gather(self, gather_func: GenericFunc) -> None:  # pragma: no cover
        """Gather variables in case of distributed setting (if needed).

        Only the accumulated 2D detection frames are gathered.

        Args:
            gather_func (Callable[[Any], Any]): Gather function.
        """
        all_preds = gather_func(self.frames_det_2d)
        if all_preds is not None:
            self.frames_det_2d = list(itertools.chain(*all_preds))

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """No evaluation locally."""
        return {}, "No evaluation locally."

    def save(self, metric: str, output_dir: str) -> None:
        """Save scalabel output to zip file.

        Args:
            metric (str): Unused.
            output_dir (str): Unused, the path given at construction time is
                used instead.

        Raises:
            ValueError: If the number of samples in each category is not the
                same.
        """
        # Every task that received predictions must have seen the same
        # number of samples. The counters are compared with each other (not
        # with the number of detection frames) so that runs without 2D
        # detections are accepted as well.
        if len(set(self.sample_counts.values())) > 1:
            raise ValueError(
                "The number of samples in each category is not the same."
            )

        # Save the 2D detection results
        if self.frames_det_2d:
            ds = Dataset(frames=self.frames_det_2d, groups=None, config=None)
            ds_bytes = json.dumps(ds.dict()).encode("utf-8")
            self.backend.set(
                f"{self.output_path}/det_2d.json", ds_bytes, mode="w"
            )

        self.backend.close()
        print(f"Saved the submission file at {self.output_path}.")