"""SHIFT result writer."""
from __future__ import annotations
import io
import itertools
import json
import os
from collections import defaultdict
import numpy as np
from PIL import Image
from vis4d.common.array import array_to_numpy
from vis4d.common.imports import SCALABEL_AVAILABLE
from vis4d.common.typing import (
ArrayLike,
GenericFunc,
MetricLogs,
NDArrayNumber,
)
from vis4d.data.datasets.shift import shift_det_map
from vis4d.data.io import DataBackend, ZipBackend
from vis4d.eval.base import Evaluator
if SCALABEL_AVAILABLE:
from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d
from scalabel.label.typing import Dataset, Frame, Label
else:
raise ImportError("scalabel is not installed.")
class SHIFTMultitaskWriter(Evaluator):
"""SHIFT result writer for online evaluation."""
inverse_cat_map = {v: k for k, v in shift_det_map.items()}
def __init__(
self,
output_dir: str,
submission_file: str = "submission.zip",
) -> None:
"""Creates a new writer.
Args:
output_dir (str): Output directory.
submission_file (str): Submission file name. Defaults to
"submission.zip".
"""
super().__init__()
assert submission_file.endswith(
".zip"
), "Submission file must be a zip file."
self.backend: DataBackend = ZipBackend()
self.output_path = os.path.join(output_dir, submission_file)
self.frames_det_2d: list[Frame] = []
self.frames_det_3d: list[Frame] = []
self.sample_counts: defaultdict[str, int] = defaultdict(int)
def _write_sem_mask(
self, sem_mask: NDArrayNumber, sample_name: str, video_name: str
) -> None:
"""Write semantic mask.
Args:
sem_mask (NDArrayNumber): Predicted semantic mask, shape (H, W).
sample_name (str): Sample name.
video_name (str): Video name.
"""
image = Image.fromarray(sem_mask.astype("uint8"), mode="L")
image_bytes = io.BytesIO()
image.save(image_bytes, format="PNG")
self.backend.set(
f"{self.output_path}/semseg/{video_name}/{sample_name}",
image_bytes.getvalue(),
mode="w",
)
def _write_depth(
self, depth_map: NDArrayNumber, sample_name: str, video_name: str
) -> None:
"""Write depth map.
Args:
depth_map (NDArrayNumber): Predicted depth map, shape (H, W).
sample_name (str): Sample name.
video_name (str): Video name.
"""
depth_map = np.clip(depth_map / 80.0 * 255.0, 0, 255)
image = Image.fromarray(depth_map.astype("uint8"), mode="L")
image_bytes = io.BytesIO()
image.save(image_bytes, format="PNG")
self.backend.set(
f"{self.output_path}/depth/{video_name}/{sample_name}",
image_bytes.getvalue(),
mode="w",
)
def _write_flow(
self, flow: NDArrayNumber, sample_name: str, video_name: str
) -> None:
"""Write semantic mask.
Args:
flow (NDArrayNumber): Predicted optical flow, shape (H, W, 2).
sample_name (str): Sample name.
video_name (str): Video name.
"""
raise NotImplementedError
def process_batch(
self,
frame_ids: list[int],
sample_names: list[str],
sequence_names: list[str],
pred_sem_mask: list[ArrayLike] | None = None,
pred_depth: list[ArrayLike] | None = None,
pred_flow: list[ArrayLike] | None = None,
pred_boxes2d: list[ArrayLike] | None = None,
pred_boxes2d_classes: list[ArrayLike] | None = None,
pred_boxes2d_scores: list[ArrayLike] | None = None,
pred_boxes2d_track_ids: list[ArrayLike] | None = None,
pred_instance_masks: list[ArrayLike] | None = None,
) -> None:
"""Process SHIFT results.
You can omit some of the predictions if they are not used.
Args:
frame_ids (list[int]): Frame IDs.
sample_names (list[str]): Sample names.
sequence_names (list[str]): Sequence names.
pred_sem_mask (list[ArrayLike], optional): Predicted semantic
masks, each in shape (C, H, W) or (H, W). Defaults to None.
pred_depth (list[ArrayLike], optional): Predicted depth maps,
each in shape (H, W), with meter unit. Defaults to None.
pred_flow (list[ArrayLike], optional): Predicted optical flows,
each in shape (H, W, 2). Defaults to None.
pred_boxes2d (list[ArrayLike], optional): Predicted 2D boxes,
each in shape (N, 4). Defaults to None.
pred_boxes2d_classes (list[ArrayLike], optional): Predicted
2D box classes, each in shape (N,). Defaults to None.
pred_boxes2d_scores (list[ArrayLike], optional): Predicted
2D box scores, each in shape (N,). Defaults to None.
pred_boxes2d_track_ids (list[ArrayLike], optional): Predicted
2D box track IDs, each in shape (N,). Defaults to None.
pred_instance_masks (list[ArrayLike], optional): Predicted
instance masks, each in shape (N, H, W). Defaults to None.
"""
for i, (frame_id, sample_name, sequence_name) in enumerate(
zip(frame_ids, sample_names, sequence_names)
):
if pred_sem_mask is not None:
sem_mask_ = array_to_numpy(
pred_sem_mask[i],
n_dims=None,
dtype=np.float32,
)
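                # Treat a 3D prediction as per-class scores and reduce it to
                # class ids via argmax; a 2D prediction is assumed to already
                # hold class ids.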
if len(sem_mask_.shape) == 3:
sem_mask = sem_mask_.argmax(axis=0)
else:
sem_mask = sem_mask_.astype(np.uint8)
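                # Derive the mask filename from the image name by swapping
                # the "img" tag for "semseg" and the ".jpg" extension for
                # ".png".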
semseg_filename = sample_name.replace(".jpg", ".png").replace(
"img", "semseg"
)
self._write_sem_mask(sem_mask, semseg_filename, sequence_name)
self.sample_counts["semseg"] += 1
if pred_depth is not None:
depth = array_to_numpy(
pred_depth[i], n_dims=None, dtype=np.float32
)
depth_filename = sample_name.replace(".jpg", ".png").replace(
"img", "depth"
)
self._write_depth(depth, depth_filename, sequence_name)
self.sample_counts["depth"] += 1
if pred_flow is not None:
flow = array_to_numpy(
pred_flow[i], n_dims=None, dtype=np.float32
)
self._write_flow(flow, sample_name, sequence_name)
self.sample_counts["flow"] += 1
if (
pred_boxes2d is not None
and pred_boxes2d_classes is not None
and pred_boxes2d_scores is not None
):
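                # Convert the raw detections of this sample into Scalabel
                # labels; they are serialized to det_2d.json in `save`.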
labels = []
if pred_instance_masks:
masks = array_to_numpy(
pred_instance_masks[i], n_dims=None, dtype=np.float32
)
if pred_boxes2d_track_ids:
track_ids = array_to_numpy(
pred_boxes2d_track_ids[i],
n_dims=None,
dtype=np.int64,
)
                for box_idx, (box, score, class_id) in enumerate(
                    zip(
                        pred_boxes2d[i],
                        pred_boxes2d_scores[i],
                        pred_boxes2d_classes[i],
                    )
                ):
                    box2d = xyxy_to_box2d(*box.tolist())
                    if pred_instance_masks:
                        # Instance masks are aligned with the predicted boxes,
                        # so index them by the box index, not the class id.
                        rle = mask_to_rle(
                            (masks[box_idx] > 0.0).astype(np.uint8)
                        )
                    else:
                        rle = None
                    if pred_boxes2d_track_ids:
                        # Each box carries its own track id.
                        track_id = str(int(track_ids[box_idx]))
                    else:
                        track_id = None
label = Label(
box2d=box2d,
category=(
self.inverse_cat_map[int(class_id)]
if self.inverse_cat_map != {}
else str(class_id)
),
score=float(score),
rle=rle,
id=track_id,
)
labels.append(label)
frame = Frame(
name=sample_name,
videoName=sequence_name,
frameIndex=frame_id,
labels=labels,
)
self.frames_det_2d.append(frame)
self.sample_counts["det_2d"] += 1
def gather(self, gather_func: GenericFunc) -> None: # pragma: no cover
"""Gather variables in case of distributed setting (if needed).
Args:
gather_func (Callable[[Any], Any]): Gather function.
"""
all_preds = gather_func(self.frames_det_2d)
if all_preds is not None:
self.frames_det_2d = list(itertools.chain(*all_preds))
def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
"""No evaluation locally."""
return {}, "No evaluation locally."
    def save(self, metric: str, output_dir: str) -> None:
        """Save the Scalabel output to the submission zip file.

        Args:
            metric (str): Metric name (unused).
            output_dir (str): Output directory (unused; the path given at
                construction time is used instead).

        Raises:
            ValueError: If the number of samples in each category is not the
                same.
        """
# Check if the sample counts are correct
equal_size = True
for key in self.sample_counts:
if self.sample_counts[key] != len(self.frames_det_2d):
equal_size = False
break
if not equal_size:
raise ValueError(
"The number of samples in each category is not the same."
)
# Save the 2D detection results
if len(self.frames_det_2d) > 0:
ds = Dataset(frames=self.frames_det_2d, groups=None, config=None)
ds_bytes = json.dumps(ds.dict()).encode("utf-8")
self.backend.set(
f"{self.output_path}/det_2d.json", ds_bytes, mode="w"
)
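        # Closing the backend finalizes the submission zip on disk.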
self.backend.close()
print(f"Saved the submission file at {self.output_path}.")