File size: 5,489 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import numpy as np
from mmcv.transforms import BaseTransform

from mmdet.registry import TRANSFORMS


@TRANSFORMS.register_module()
class InstaBoost(BaseTransform):
    r"""Data augmentation method in `InstaBoost: Boosting Instance
    Segmentation Via Probability Map Guided Copy-Pasting
    <https://arxiv.org/abs/1908.07801>`_.

    Refer to https://github.com/GothicAi/Instaboost for implementation details.

    Required Keys:

    - img (np.uint8; other dtypes are cast to uint8 for augmentation and
      cast back afterwards, which may lose precision)
    - instances (each a dict with ``bbox`` as ``[x1, y1, x2, y2]``,
      ``bbox_label``, ``mask`` and ``ignore_flag``)

    Modified Keys:

    - img (keeps its original dtype)
    - instances

    If ``instances`` is missing or empty, or the augmentation is not
    sampled (see ``aug_ratio``), ``results`` is returned unchanged.

    Args:
        action_candidate (tuple): Action candidates. "normal", "horizontal",
            "vertical", "skip" are supported. Defaults to ('normal',
            'horizontal', 'skip').
        action_prob (tuple): Corresponding action probabilities. Should be
            the same length as action_candidate. Defaults to (1, 0, 0).
        scale (tuple): (min scale, max scale). Defaults to (0.8, 1.2).
        dx (int): The maximum x-axis shift will be (instance width) / dx.
            Defaults to 15.
        dy (int): The maximum y-axis shift will be (instance height) / dy.
            Defaults to 15.
        theta (tuple): (min rotation degree, max rotation degree).
            Defaults to (-1, 1).
        color_prob (float): Probability of color augmentation, forwarded
            to ``InstaBoostConfig``. Defaults to 0.5.
        hflag (bool): Whether to use heatmap guidance. Defaults to False.
        aug_ratio (float): Probability of applying this transformation.
            Defaults to 0.5.
    """

    def __init__(self,
                 action_candidate: tuple = ('normal', 'horizontal', 'skip'),
                 action_prob: tuple = (1, 0, 0),
                 scale: tuple = (0.8, 1.2),
                 dx: int = 15,
                 dy: int = 15,
                 theta: tuple = (-1, 1),
                 color_prob: float = 0.5,
                 hflag: bool = False,
                 aug_ratio: float = 0.5) -> None:

        import matplotlib
        import matplotlib.pyplot as plt
        default_backend = plt.get_backend()

        try:
            import instaboostfast as instaboost
        except ImportError as err:
            raise ImportError(
                'Please run "pip install instaboostfast" '
                'to install instaboostfast first for instaboost augmentation.'
            ) from err

        # instaboost will modify the default backend
        # and cause visualization to fail.
        matplotlib.use(default_backend)

        self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
                                               scale, dx, dy, theta,
                                               color_prob, hflag)
        self.aug_ratio = aug_ratio

    def _load_anns(self, results: dict) -> Tuple[list, list]:
        """Split ``results['instances']`` into instaboost inputs and
        pass-through instances.

        Args:
            results (dict): Result dict with an ``instances`` list.

        Returns:
            tuple[list, list]: ``anns`` holds the non-ignored instances,
            converted to dicts with ``category_id``, ``segmentation`` and an
            ``[x, y, w, h]`` ``bbox`` (the input format handed to
            ``instaboost.get_new_data``). ``ignore_anns`` holds the instances
            with a non-zero ``ignore_flag``, kept as-is.
        """
        anns = []
        ignore_anns = []
        for instance in results['instances']:
            if instance['ignore_flag'] != 0:
                # Ignored instances bypass augmentation and are kept as-is.
                ignore_anns.append(instance)
                continue
            x1, y1, x2, y2 = instance['bbox']
            anns.append({
                'category_id': instance['bbox_label'],
                'segmentation': instance['mask'],
                # corner format -> [x, y, width, height]
                'bbox': [x1, y1, x2 - x1, y2 - y1],
            })
        return anns, ignore_anns

    def _parse_anns(self, results: dict, anns: list, ignore_anns: list,
                    img: np.ndarray) -> dict:
        """Write instaboost output back into ``results`` in instance format.

        Args:
            results (dict): Result dict, updated in place.
            anns (list): Annotations returned by instaboost, each with
                ``category_id``, ``segmentation`` and an ``[x, y, w, h]``
                ``bbox``.
            ignore_anns (list): Ignored instances, appended unchanged after
                the augmented ones.
            img (np.ndarray): The augmented image.

        Returns:
            dict: ``results`` with ``img`` and ``instances`` replaced.
        """
        instances = []
        for ann in anns:
            x1, y1, w, h = ann['bbox']
            # TODO: instaboost can apparently emit degenerate boxes; the root
            # cause lies in instaboost itself, so they are dropped here.
            if w <= 0 or h <= 0:
                continue
            instances.append(
                dict(
                    bbox=[x1, y1, x1 + w, y1 + h],
                    bbox_label=ann['category_id'],
                    mask=ann['segmentation'],
                    ignore_flag=0))

        instances.extend(ignore_anns)
        results['img'] = img
        results['instances'] = instances
        return results

    def transform(self, results: dict) -> dict:
        """Randomly apply InstaBoost to ``results``.

        Args:
            results (dict): Result dict containing ``img`` and
                ``instances``.

        Returns:
            dict: ``results`` with augmented ``img`` and ``instances``. The
            input is returned untouched if there are no instances or the
            augmentation was not sampled.
        """
        if 'instances' not in results or len(results['instances']) == 0:
            return results

        # Decide first, so the skip path is a true identity (no lossy
        # annotation round-trip, no dropped boxes, no image copy). This is
        # still exactly one draw, so the global RNG stream is unchanged.
        apply_aug = np.random.choice(
            [0, 1], p=[1 - self.aug_ratio, self.aug_ratio])
        if not apply_aug:
            return results

        # Imported locally rather than stored on ``self``: a module object
        # attribute would make the transform unpicklable.
        try:
            import instaboostfast as instaboost
        except ImportError as err:
            raise ImportError('Please run "pip install instaboostfast" '
                              'to install instaboostfast first.') from err

        img = results['img']
        ori_type = img.dtype
        anns, ignore_anns = self._load_anns(results)
        # instaboost works on uint8 images; the result is cast back below.
        anns, img = instaboost.get_new_data(
            anns, img.astype(np.uint8), self.cfg, background=None)

        return self._parse_anns(results, anns, ignore_anns,
                                img.astype(ori_type, copy=False))

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(aug_ratio={self.aug_ratio})'
        return repr_str