File size: 4,245 Bytes
3b5fc39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# ------------------------------------------------------------------------
# Modified from SeqFormer (https://github.com/wjf5203/SeqFormer)
# ------------------------------------------------------------------------
# Modified from STEm-Seg (https://github.com/sabarim/STEm-Seg)
# ------------------------------------------------------------------------


import imgaug
import imgaug.augmenters as iaa
import numpy as np

from datetime import datetime

from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage


class ImageToSeqAugmenter(object):
    def __init__(self, perspective=True, affine=True, motion_blur=True,
                 brightness_range=(-50, 50), hue_saturation_range=(-15, 15), perspective_magnitude=0.12,
                 scale_range=1.0, translate_range={"x": (-0.15, 0.15), "y": (-0.15, 0.15)}, rotation_range=(-20, 20),
                 motion_blur_kernel_sizes=(7, 9), motion_blur_prob=0.5):

        self.basic_augmenter = iaa.SomeOf((1, None), [
                iaa.Add(brightness_range),
                iaa.AddToHueAndSaturation(hue_saturation_range)
            ]
        )

        transforms = []
        if perspective:
            transforms.append(iaa.PerspectiveTransform(perspective_magnitude))
        if affine:
            transforms.append(iaa.Affine(scale=scale_range,
                                         translate_percent=translate_range,
                                         rotate=rotation_range,
                                         order=1,  # cv2.INTER_LINEAR
                                         backend='auto'))
        transforms = iaa.Sequential(transforms)
        transforms = [transforms]

        if motion_blur:
            blur = iaa.Sometimes(motion_blur_prob, iaa.OneOf(
                [
                    iaa.MotionBlur(ksize)
                    for ksize in motion_blur_kernel_sizes
                ]
            ))
            transforms.append(blur)

        self.frame_shift_augmenter = iaa.Sequential(transforms)

    @staticmethod
    def condense_masks(instance_masks):
        condensed_mask = np.zeros_like(instance_masks[0], dtype=np.int8)
        for instance_id, mask in enumerate(instance_masks, 1):
            condensed_mask = np.where(mask, instance_id, condensed_mask)

        return condensed_mask

    @staticmethod
    def expand_masks(condensed_mask, num_instances):
        return [(condensed_mask == instance_id).astype(np.uint8) for instance_id in range(1, num_instances + 1)]

    def __call__(self, image, masks=None, boxes=None):
        det_augmenter = self.frame_shift_augmenter.to_deterministic()


        if masks is not None:
            masks_np, is_binary_mask = [], []
            boxs_np = []

            for mask in masks:
                
                if isinstance(mask, np.ndarray):
                    masks_np.append(mask.astype(np.bool))
                    is_binary_mask.append(False)
                else:
                    raise ValueError("Invalid mask type: {}".format(type(mask)))

            num_instances = len(masks_np)
            masks_np = SegmentationMapsOnImage(self.condense_masks(masks_np), shape=image.shape[:2])
            # boxs_np = BoundingBoxesOnImage(boxs_np, shape=image.shape[:2])

            seed = int(datetime.now().strftime('%M%S%f')[-8:])
            imgaug.seed(seed)
            aug_image, aug_masks = det_augmenter(image=self.basic_augmenter(image=image) , segmentation_maps=masks_np)
            imgaug.seed(seed)
            invalid_pts_mask = det_augmenter(image=np.ones(image.shape[:2] + (1,), np.uint8)).squeeze(2)
            aug_masks = self.expand_masks(aug_masks.get_arr(), num_instances)
            # aug_boxes = aug_boxes.remove_out_of_image().clip_out_of_image()
            aug_masks = [mask for mask, is_bm in zip(aug_masks, is_binary_mask)]
            return aug_image, aug_masks #, aug_boxes.to_xyxy_array()

        else:
            masks = [SegmentationMapsOnImage(np.ones(image.shape[:2], np.bool), shape=image.shape[:2])]
            aug_image, invalid_pts_mask = det_augmenter(image=image, segmentation_maps=masks)
            return aug_image, invalid_pts_mask.get_arr() == 0