|
import torch |
|
import torch.nn.functional as F |
|
|
|
import numpy as np |
|
import kornia.geometry.transform as K |
|
from kornia.augmentation import ColorJiggle, AugmentationSequential |
|
from kornia.augmentation import RandomChannelShuffle |
|
|
|
|
|
class FaceAugmentor: |
|
def __init__(self): |
|
self.color_jitter = AugmentationSequential( |
|
ColorJiggle(0.25, 0.3, 0.3, 0.3, p=1.), |
|
RandomChannelShuffle(), |
|
) |
|
|
|
def _random_normal(self, size=(1,), trunc_val=2.5, rnd_state=None, device='cpu'): |
|
if rnd_state is None: |
|
rnd_state = np.random |
|
len = np.array(size).prod() |
|
result = np.empty((len,), dtype=np.float32) |
|
|
|
for i in range(len): |
|
while True: |
|
x = rnd_state.normal() |
|
if x >= -trunc_val and x <= trunc_val: |
|
break |
|
result[i] = (x / trunc_val) |
|
|
|
return torch.from_numpy(result.reshape(size)).to(device) |
|
|
|
def _get_warp_params(self, batch_size, img_size, device): |
|
|
|
batch_cell_size = np.random.choice([img_size // (2**i) for i in range(1, 4)], batch_size) |
|
batch_cell_count = img_size // batch_cell_size + 1 |
|
|
|
batch_grid_points = [ |
|
torch.linspace(0, img_size, cell_count, device=device) for cell_count in batch_cell_count |
|
] |
|
batch_mapx = [ |
|
torch.broadcast_to(grid_points, (cell_count, cell_count)).clone() |
|
for grid_points, cell_count in zip(batch_grid_points, batch_cell_count) |
|
] |
|
batch_mapy = [x.t() for x in batch_mapx] |
|
|
|
batch_mapx_resized = [] |
|
batch_mapy_resized = [] |
|
|
|
for cell_size, cell_count, mapx, mapy in zip(batch_cell_size, batch_cell_count, batch_mapx, batch_mapy): |
|
half_cell_size = cell_size // 2 |
|
mapx[1:-1, 1:-1] = mapx[1:-1, 1:-1] +\ |
|
self._random_normal( |
|
size=(cell_count-2, cell_count-2), |
|
device=device |
|
) * (cell_size*0.24) |
|
mapy[1:-1, 1:-1] = mapy[1:-1, 1:-1] +\ |
|
self._random_normal( |
|
size=(cell_count-2, cell_count-2), |
|
device=device |
|
) * (cell_size*0.24) |
|
img_size = int(img_size) |
|
cell_size = int(cell_size) |
|
mapx = F.interpolate(mapx.unsqueeze(0).unsqueeze(0), (img_size + cell_size,) * 2, mode='bilinear')[ |
|
:, 0, half_cell_size: -half_cell_size, half_cell_size: -half_cell_size |
|
] |
|
mapy = F.interpolate(mapy.unsqueeze(0).unsqueeze(0), (img_size + cell_size,) * 2, mode='bilinear')[ |
|
:, 0, half_cell_size: -half_cell_size, half_cell_size: -half_cell_size |
|
] |
|
|
|
batch_mapx_resized.append(mapx) |
|
batch_mapy_resized.append(mapy) |
|
|
|
batch_mapx_resized = torch.cat(batch_mapx_resized) |
|
batch_mapy_resized = torch.cat(batch_mapy_resized) |
|
|
|
return batch_mapx_resized, batch_mapy_resized |
|
|
|
def _mask(self, faces, apply_rnd_mask): |
|
""" |
|
Notice that this masking function is designed specifically for EG3D canonical space (yaw and pitch equal to 0). |
|
If you change the coordinate system, change this too. |
|
""" |
|
B = faces.shape[0] |
|
N = faces.shape[-1] |
|
for i in range(B): |
|
mask_percent = 0.25 |
|
mask_size = int(mask_percent * N) |
|
mask = torch.zeros_like(faces[i: i + 1]) |
|
ones = torch.ones((1, 3, N - mask_size * 2, N - mask_size * 2), device=mask.device) |
|
mask[:, :, mask_size: N - mask_size, mask_size: N - mask_size] = ones |
|
faces[i: i + 1, ...] = faces[i: i + 1, ...] * mask |
|
|
|
if apply_rnd_mask: |
|
for _ in range(5): |
|
|
|
mask = torch.ones_like(faces[i: i + 1]) |
|
zeros = torch.zeros((1, 3, 64, 64), device=mask.device) |
|
x = np.random.randint(int(N * mask_percent), N // 2) |
|
y = np.random.randint(int(N * mask_percent), int(N * (1 - mask_percent))) |
|
mask[:, :, x: x + zeros.shape[-2], y: y + zeros.shape[-1]] = zeros |
|
faces[i: i + 1, ...] = faces[i: i + 1, ...] * mask - (1 - mask) |
|
|
|
return faces |
|
|
|
def _random_zoom_in(self, faces): |
|
size = faces.shape[-1] |
|
zoom_size_h = int(size * (0.7 + np.random.rand() * 0.3)) |
|
zoom_size_w = int(size * (0.7 + np.random.rand() * 0.3)) |
|
faces = K.center_crop(faces, (zoom_size_h, zoom_size_w)) |
|
faces = K.resize(faces, (size, size)) |
|
|
|
return faces |
|
|
|
def _random_zoom_out(self, faces): |
|
size = faces.shape[-1] |
|
pad_h = np.random.randint(int(size * 0.2)) |
|
pad_w = np.random.randint(int(size * 0.2)) |
|
faces = F.pad(faces, (pad_h, pad_w), mode='constant') |
|
faces = K.resize(faces, (size, size)) |
|
|
|
return faces |
|
|
|
def _random_color_patch(self, faces): |
|
mask_percent = 0.25 |
|
B = faces.shape[0] |
|
N = faces.shape[-1] |
|
|
|
aug_faces = self.color_jitter(faces) |
|
|
|
for i in range(B): |
|
for _ in range(20): |
|
x = np.random.randint(int(N * mask_percent), int(N * (1 - mask_percent))) |
|
y = np.random.randint(int(N * mask_percent), int(N * (1 - mask_percent))) |
|
|
|
subfaces = faces[i, :, x: x + 128, y: y + 128] |
|
subfaces = self.color_jitter(subfaces) |
|
|
|
aug_faces[i, :, x: x + 128, y: y + 128] = subfaces |
|
return aug_faces |
|
|
|
@torch.no_grad() |
|
def __call__(self, faces, target_size, apply_color_aug=True, apply_rnd_mask=True, apply_rnd_zoom=True): |
|
if target_size is not None: |
|
faces = F.interpolate(faces, size=target_size) |
|
|
|
if apply_color_aug: |
|
faces = self._random_color_patch(faces) |
|
|
|
zoom_type = 2 if not apply_rnd_zoom else np.random.randint(3) |
|
if zoom_type == 0: |
|
faces = self._random_zoom_in(faces) |
|
elif zoom_type == 1: |
|
faces = self._random_zoom_out(faces) |
|
|
|
faces = self._mask(faces, apply_rnd_mask) |
|
|
|
return faces |
|
|