VRIS_vip / my_datasets /image_to_seq_augmenter.py
dianecy's picture
Add files using upload-large-folder tool
3b5fc39 verified
raw
history blame
4.25 kB
# ------------------------------------------------------------------------
# Modified from SeqFormer (https://github.com/wjf5203/SeqFormer)
# ------------------------------------------------------------------------
# Modified from STEm-Seg (https://github.com/sabarim/STEm-Seg)
# ------------------------------------------------------------------------
import imgaug
import imgaug.augmenters as iaa
import numpy as np
from datetime import datetime
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
class ImageToSeqAugmenter(object):
    """Randomly augment a single image together with its instance masks.

    Two imgaug pipelines are built at construction time:

    * ``basic_augmenter`` -- photometric only: a brightness shift and/or a
      hue/saturation shift (at least one of the two is applied).
    * ``frame_shift_augmenter`` -- geometric (optional perspective and affine
      warps) followed by an optional motion blur.

    NOTE(review): judging by the names, this is presumably used to synthesise
    extra "frames" from a still image -- confirm against the dataset code.
    """

    def __init__(self, perspective=True, affine=True, motion_blur=True,
                 brightness_range=(-50, 50), hue_saturation_range=(-15, 15), perspective_magnitude=0.12,
                 scale_range=1.0, translate_range={"x": (-0.15, 0.15), "y": (-0.15, 0.15)}, rotation_range=(-20, 20),
                 motion_blur_kernel_sizes=(7, 9), motion_blur_prob=0.5):
        """Build the photometric and geometric augmentation pipelines.

        Args:
            perspective: if true, include ``iaa.PerspectiveTransform``.
            affine: if true, include ``iaa.Affine``.
            motion_blur: if true, append a randomly applied ``iaa.MotionBlur``.
            brightness_range: value forwarded to ``iaa.Add``.
            hue_saturation_range: value forwarded to
                ``iaa.AddToHueAndSaturation``.
            perspective_magnitude: value forwarded to
                ``iaa.PerspectiveTransform``.
            scale_range: ``scale`` argument of ``iaa.Affine``.
            translate_range: ``translate_percent`` argument of ``iaa.Affine``.
            rotation_range: ``rotate`` argument of ``iaa.Affine``.
            motion_blur_kernel_sizes: one ``iaa.MotionBlur`` is created per
                entry; a single one is picked whenever blur is applied.
            motion_blur_prob: probability handed to ``iaa.Sometimes`` for
                applying the motion blur.
        """
        self.basic_augmenter = iaa.SomeOf((1, None), [
            iaa.Add(brightness_range),
            iaa.AddToHueAndSaturation(hue_saturation_range),
        ])

        # The default for ``translate_range`` is a dict shared by every call of
        # this constructor; hand imgaug a private copy so that nothing
        # downstream can ever alter the shared default.
        if isinstance(translate_range, dict):
            translate_range = dict(translate_range)

        geometric = []
        if perspective:
            geometric.append(iaa.PerspectiveTransform(perspective_magnitude))
        if affine:
            geometric.append(iaa.Affine(scale=scale_range,
                                        translate_percent=translate_range,
                                        rotate=rotation_range,
                                        order=1,  # cv2.INTER_LINEAR
                                        backend='auto'))

        # The geometric warps are grouped into their own Sequential so that the
        # blur, when enabled, always runs after them.
        transforms = [iaa.Sequential(geometric)]
        if motion_blur:
            blur = iaa.Sometimes(motion_blur_prob, iaa.OneOf(
                [iaa.MotionBlur(ksize) for ksize in motion_blur_kernel_sizes]
            ))
            transforms.append(blur)

        self.frame_shift_augmenter = iaa.Sequential(transforms)

    @staticmethod
    def condense_masks(instance_masks):
        """Merge per-instance masks into one integer label map.

        Instance ``i`` (1-based position in ``instance_masks``) is written as
        the value ``i``; 0 is background. Where masks overlap, the later mask
        in the sequence wins.

        Args:
            instance_masks: non-empty sequence of equally shaped arrays whose
                truthy entries mark the instance's pixels.

        Returns:
            Label map shaped like ``instance_masks[0]``. The dtype is ``int8``
            when every instance id fits into it and ``int32`` otherwise.

        Raises:
            IndexError: if ``instance_masks`` is empty.
        """
        # int8 only holds ids up to 127; widen for larger instance counts
        # instead of overflowing.
        if len(instance_masks) <= np.iinfo(np.int8).max:
            label_dtype = np.int8
        else:
            label_dtype = np.int32

        condensed_mask = np.zeros_like(instance_masks[0], dtype=label_dtype)
        for instance_id, mask in enumerate(instance_masks, 1):
            # Casting the id keeps the result dtype fixed regardless of the
            # NumPy version's scalar promotion rules.
            condensed_mask = np.where(mask, label_dtype(instance_id), condensed_mask)
        return condensed_mask

    @staticmethod
    def expand_masks(condensed_mask, num_instances):
        """Split a label map back into ``num_instances`` binary masks.

        Args:
            condensed_mask: integer label map as produced by
                ``condense_masks``.
            num_instances: number of instance ids (1..num_instances) to
                extract.

        Returns:
            List of ``uint8`` arrays; entry ``i`` is 1 where the label map
            equals ``i + 1`` and 0 elsewhere.
        """
        return [(condensed_mask == instance_id).astype(np.uint8) for instance_id in range(1, num_instances + 1)]

    def __call__(self, image, masks=None, boxes=None):
        """Augment ``image`` and, when given, its instance masks.

        Args:
            image: image array; ``image.shape[:2]`` is used as the spatial
                shape of the segmentation maps.
            masks: optional iterable of ``np.ndarray`` instance masks. An empty
                iterable is allowed and yields an empty list of masks.
            boxes: accepted for interface compatibility; currently unused.

        Returns:
            With ``masks``: ``(aug_image, aug_masks)``, where ``aug_masks`` is a
            list of ``uint8`` masks in the same order as ``masks``. The image
            additionally goes through the photometric ``basic_augmenter``.

            Without ``masks``: ``(aug_image, invalid)``, where ``invalid`` is a
            boolean array that is True where a warped all-ones map is zero.
            Only the geometric/blur pipeline is applied in this case.

        Raises:
            ValueError: if an element of ``masks`` is not an ``np.ndarray``.
        """
        det_augmenter = self.frame_shift_augmenter.to_deterministic()
        spatial_shape = image.shape[:2]

        if masks is None:
            # Pass a single map (not a list), exactly like the masks branch
            # below, so that the augmenter returns a single map on which
            # ``get_arr`` can be called.
            validity_map = SegmentationMapsOnImage(np.ones(spatial_shape, bool), shape=spatial_shape)
            aug_image, warped_validity = det_augmenter(image=image, segmentation_maps=validity_map)
            return aug_image, warped_validity.get_arr() == 0

        masks_np = []
        for mask in masks:
            if not isinstance(mask, np.ndarray):
                raise ValueError("Invalid mask type: {}".format(type(mask)))
            # Builtin ``bool``: the ``np.bool`` alias does not exist in every
            # NumPy release.
            masks_np.append(mask.astype(bool))

        num_instances = len(masks_np)
        if masks_np:
            condensed = self.condense_masks(masks_np)
        else:
            # No instances at all: an all-background label map.
            condensed = np.zeros(spatial_shape, dtype=np.int8)
        segmap = SegmentationMapsOnImage(condensed, shape=spatial_shape)

        # Time-derived seed: results differ from call to call and are not
        # reproducible across runs.
        seed = int(datetime.now().strftime('%M%S%f')[-8:])
        imgaug.seed(seed)
        aug_image, aug_segmap = det_augmenter(image=self.basic_augmenter(image=image), segmentation_maps=segmap)

        imgaug.seed(seed)
        # NOTE(review): this second warp (of an all-ones image) is no longer
        # part of the return value; it is kept so that the sequence of imgaug
        # calls and reseeds stays exactly as before. Drop it once confirmed
        # that nothing depends on that.
        det_augmenter(image=np.ones(spatial_shape + (1,), np.uint8))

        aug_masks = self.expand_masks(aug_segmap.get_arr(), num_instances)
        return aug_image, aug_masks  # , aug_boxes.to_xyxy_array()