|
|
|
import copy |
|
import logging |
|
import numpy as np |
|
from typing import List, Union |
|
import torch |
|
|
|
import detectron2.data.detection_utils as utils |
|
import detectron2.data.transforms as T |
|
from detectron2.config import configurable |
|
|
|
from .detection_utils import annotations_to_instances, transform_instance_annotations |
|
|
|
# Public API of this module: only the point-supervision dataset mapper is exported.
__all__ = ["PointSupDatasetMapper"]
|
|
|
|
|
class PointSupDatasetMapper:
    """
    Dataset mapper used for point-supervised instance segmentation.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        # Point-supervision specific option.
        sample_points: int = 0,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            sample_points: subsample points at each iteration; forwarded unchanged to
                ``annotations_to_instances(..., sample_points=...)``.
        """
        self.is_train = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.sample_points = sample_points

        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
        logger.info(f"Point Augmentations used in {mode}: sample {sample_points} points")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """
        Build the keyword arguments for ``__init__`` from a config node.

        Raises:
            ValueError: if crop augmentation is enabled for training, which this
                mapper does not support.
        """
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            raise ValueError("Crop augmentation not supported to point supervision.")

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "sample_points": cfg.INPUT.SAMPLE_POINTS,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept. It always
            contains "image" (a CHW tensor). In training, when "annotations" are
            present, they are replaced by "instances". In inference, "annotations"
            are dropped.
        """
        # Deep-copy because the code below mutates the dict and its annotations.
        dataset_dict = copy.deepcopy(dataset_dict)
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        aug_input = T.AugInput(image)
        transforms = self.augmentations(aug_input)
        image = aug_input.image

        image_shape = image.shape[:2]  # (height, width) of the augmented image
        # HWC -> CHW. transpose() returns a non-contiguous view, so make it
        # contiguous before wrapping it as a tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            for ann in dataset_dict["annotations"]:
                # `np.float` was removed in NumPy 1.24; it was an alias of the
                # builtin float, so float64 keeps the original dtype.
                point_coords_wrt_image = np.asarray(ann["point_coords"], dtype=np.float64)
                # Shift by half a pixel before applying the geometric transforms.
                # NOTE(review): presumably this maps discrete pixel indices to
                # continuous coordinates (c = d + 0.5) — confirm against the
                # dataset's point annotation convention.
                ann["point_coords"] = point_coords_wrt_image + 0.5

            # Crowd annotations are skipped entirely.
            annos = [
                transform_instance_annotations(
                    obj,
                    transforms,
                    image_shape,
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = annotations_to_instances(
                annos,
                image_shape,
                sample_points=self.sample_points,
            )

            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict
|
|