File size: 6,162 Bytes
03da825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import torch
import torch.nn.functional as F

import numpy as np
import kornia.geometry.transform as K
from kornia.augmentation import ColorJiggle, AugmentationSequential
from kornia.augmentation import RandomChannelShuffle


class FaceAugmentor:
    """Random image-space augmentations for batches of face crops.

    ``__call__`` runs the pipeline; each ``_random_*`` / ``_mask`` helper
    implements one augmentation. Positions, sizes and zoom choices are drawn
    from NumPy's global RNG (``np.random``); colour jitter comes from kornia.

    The helpers read the spatial size from ``faces.shape[-1]`` only, so square
    (B, C, N, N) batches are presumed. TODO(review): confirm callers never pass
    non-square crops.
    """

    def __init__(self):
        # kornia ColorJiggle(brightness, contrast, saturation, hue), always
        # applied (p=1.), followed by RandomChannelShuffle with its defaults.
        self.color_jitter = AugmentationSequential(
            ColorJiggle(0.25, 0.3, 0.3, 0.3, p=1.),
            RandomChannelShuffle(),
        )

    def _random_normal(self, size=(1,), trunc_val=2.5, rnd_state=None, device='cpu'):
        """Sample a truncated standard normal rescaled into [-1, 1].

        Each value is drawn from N(0, 1) by rejection sampling until it lies in
        [-trunc_val, trunc_val], then divided by ``trunc_val``.

        Args:
            size: Output shape.
            trunc_val: Truncation bound, in standard deviations.
            rnd_state: Object with a ``normal()`` method (e.g. a
                ``np.random.RandomState``); defaults to the global ``np.random``.
            device: Torch device of the returned tensor.

        Returns:
            A float32 tensor of shape ``size`` on ``device``.
        """
        if rnd_state is None:
            rnd_state = np.random
        count = int(np.prod(size))
        result = np.empty((count,), dtype=np.float32)

        # One draw at a time so the sequence of RNG calls stays deterministic
        # for a given seed.
        for i in range(count):
            while True:
                x = rnd_state.normal()
                if -trunc_val <= x <= trunc_val:
                    break
            result[i] = x / trunc_val

        return torch.from_numpy(result.reshape(size)).to(device)

    def _get_warp_params(self, batch_size, img_size, device):
        """Build per-sample random-warp sampling maps.

        For every sample a coarse grid with cell size ``img_size // 2``, ``// 4``
        or ``// 8`` (chosen at random) is jittered at its interior points,
        bilinearly upsampled to ``img_size + cell_size`` and centre-cropped to
        exactly ``img_size``.

        Args:
            batch_size: Number of maps to generate.
            img_size: Side length of the (square) output maps.
            device: Torch device for the returned tensors.

        Returns:
            ``(mapx, mapy)``: float tensors of shape
            ``(batch_size, img_size, img_size)`` holding coordinates in the
            range spanned by ``linspace(0, img_size)`` plus jitter.
        """
        img_size = int(img_size)
        # max(..., 1) keeps cell sizes positive for img_size < 8; for larger
        # images the candidate list is unchanged.
        cell_size_choices = [max(img_size // (2 ** i), 1) for i in range(1, 4)]
        batch_cell_size = np.random.choice(cell_size_choices, batch_size)

        batch_mapx_resized = []
        batch_mapy_resized = []

        for cell_size in batch_cell_size:
            cell_size = int(cell_size)
            cell_count = img_size // cell_size + 1
            half_cell_size = cell_size // 2

            grid_points = torch.linspace(0, img_size, cell_count, device=device)
            mapx = torch.broadcast_to(grid_points, (cell_count, cell_count)).clone()
            # NOTE(review): mapy is a transposed *view* of mapx, so the in-place
            # jitter below is shared between both maps (each receives the other's
            # noise, transposed). Kept to preserve the existing warp statistics;
            # use ``mapx.t().clone()`` if independent x/y jitter is intended.
            mapy = mapx.t()

            jitter_scale = cell_size * 0.24
            mapx[1:-1, 1:-1] = mapx[1:-1, 1:-1] + self._random_normal(
                size=(cell_count - 2, cell_count - 2), device=device
            ) * jitter_scale
            mapy[1:-1, 1:-1] = mapy[1:-1, 1:-1] + self._random_normal(
                size=(cell_count - 2, cell_count - 2), device=device
            ) * jitter_scale

            out_size = (img_size + cell_size,) * 2
            # Crop exactly img_size pixels. `[half:-half]` would keep
            # img_size + 1 for odd cell sizes (breaking torch.cat) and nothing
            # at all when half_cell_size == 0.
            crop = slice(half_cell_size, half_cell_size + img_size)
            mapx = F.interpolate(mapx[None, None], out_size, mode='bilinear')[:, 0, crop, crop]
            mapy = F.interpolate(mapy[None, None], out_size, mode='bilinear')[:, 0, crop, crop]

            batch_mapx_resized.append(mapx)
            batch_mapy_resized.append(mapy)

        return torch.cat(batch_mapx_resized), torch.cat(batch_mapy_resized)

    def _mask(self, faces, apply_rnd_mask, patch_size=64, num_patches=5):
        """Zero everything outside a central square and optionally punch holes.

        Notice that this masking function is designed specifically for EG3D
        canonical space (yaw and pitch equal to 0). If you change the
        coordinate system, change this too.

        A border of ``int(0.25 * N)`` pixels on each side is set to 0. When
        ``apply_rnd_mask`` is true, every sample also gets ``num_patches``
        square patches of side ``patch_size`` set to -1; each patch's top row
        is drawn from [N/4, N/2) and its left column from [N/4, 3N/4).
        Patches are clipped at the image edge.

        Args:
            faces: Image batch (B, C, H, W) with N = W.
            apply_rnd_mask: Whether to add the random -1 patches.
            patch_size: Side length of each random patch.
            num_patches: Number of random patches per sample.

        Returns:
            A new tensor; ``faces`` itself is not modified.
        """
        mask_percent = 0.25
        N = faces.shape[-1]
        mask_size = int(mask_percent * N)

        # The central window is identical for all samples and channels: build
        # one broadcastable mask and multiply out of place so the caller's
        # tensor is never written to.
        mask = torch.zeros_like(faces[:1, :1])
        mask[..., mask_size: N - mask_size, mask_size: N - mask_size] = 1
        faces = faces * mask

        if apply_rnd_mask:
            row_low, row_high = int(N * mask_percent), N // 2
            col_low, col_high = int(N * mask_percent), int(N * (1 - mask_percent))
            for i in range(faces.shape[0]):
                for _ in range(num_patches):
                    x = np.random.randint(row_low, row_high)
                    y = np.random.randint(col_low, col_high)
                    # Holes are filled with -1 (the border above is 0).
                    faces[i, :, x: x + patch_size, y: y + patch_size] = -1

        return faces

    def _random_zoom_in(self, faces):
        """Centre-crop a random window and resize it back to (size, size).

        The crop height and width are drawn independently as
        ``int(size * u)`` with ``u`` uniform in [0.7, 1.0), so the aspect
        ratio may change.
        """
        size = faces.shape[-1]
        zoom_size_h = int(size * (0.7 + np.random.rand() * 0.3))
        zoom_size_w = int(size * (0.7 + np.random.rand() * 0.3))
        faces = K.center_crop(faces, (zoom_size_h, zoom_size_w))
        faces = K.resize(faces, (size, size))

        return faces

    def _random_zoom_out(self, faces):
        """Shrink the content by zero-padding, then resize back to (size, size).

        ``pad_h`` rows are added above and below and ``pad_w`` columns left and
        right. Each pad is drawn uniformly from [0, int(0.2 * size)).
        """
        size = faces.shape[-1]
        # randint(0) raises, which would happen for size < 5.
        max_pad = max(int(size * 0.2), 1)
        pad_h = np.random.randint(max_pad)
        pad_w = np.random.randint(max_pad)
        # F.pad's tuple starts at the last dim: (left, right, top, bottom).
        # A 2-tuple would pad only the width dimension.
        faces = F.pad(faces, (pad_w, pad_w, pad_h, pad_h), mode='constant')
        faces = K.resize(faces, (size, size))

        return faces

    def _random_color_patch(self, faces, patch_size=128, num_patches=20):
        """Colour-jitter the whole batch, then re-jitter random square patches.

        Each patch is cut from the *un-augmented* input, jittered independently
        and pasted over the globally jittered image. Top-left corners are drawn
        from [N/4, 3N/4) on both axes; patches are clipped at the image edge.

        NOTE(review): kornia colour ops expect inputs in [0, 1], while ``_mask``
        fills holes with -1. Confirm the intended value range.
        """
        mask_percent = 0.25
        N = faces.shape[-1]
        low, high = int(N * mask_percent), int(N * (1 - mask_percent))

        aug_faces = self.color_jitter(faces)

        for i in range(faces.shape[0]):
            for _ in range(num_patches):
                x = np.random.randint(low, high)
                y = np.random.randint(low, high)

                patch = faces[i, :, x: x + patch_size, y: y + patch_size]
                aug_faces[i, :, x: x + patch_size, y: y + patch_size] = self.color_jitter(patch)
        return aug_faces

    @torch.no_grad()
    def __call__(self, faces, target_size, apply_color_aug=True, apply_rnd_mask=True, apply_rnd_zoom=True):
        """Run the augmentation pipeline on a batch of faces (no autograd).

        Order:
          1. nearest-neighbour resize to ``target_size`` (skipped if None);
          2. colour jitter (global plus random patches) if ``apply_color_aug``;
          3. if ``apply_rnd_zoom``, one of zoom-in / zoom-out / none, each
             chosen with probability 1/3;
          4. central mask, plus random -1 holes if ``apply_rnd_mask``.

        Args:
            faces: Image batch, presumably (B, C, N, N) float. TODO confirm.
            target_size: Size forwarded to ``F.interpolate``, or None.
            apply_color_aug: Enable step 2.
            apply_rnd_mask: Enable the random holes in step 4.
            apply_rnd_zoom: Enable step 3.

        Returns:
            The augmented batch.
        """
        if target_size is not None:
            faces = F.interpolate(faces, size=target_size)

        if apply_color_aug:
            faces = self._random_color_patch(faces)

        # 0 = zoom in, 1 = zoom out, 2 = no zoom.
        zoom_type = 2 if not apply_rnd_zoom else np.random.randint(3)
        if zoom_type == 0:
            faces = self._random_zoom_in(faces)
        elif zoom_type == 1:
            faces = self._random_zoom_out(faces)

        faces = self._mask(faces, apply_rnd_mask)

        return faces