# VOODOO3D-unofficial / models / utils / face_augmentor.py
# Uploaded by ameerazam08 via huggingface_hub (commit 03da825, verified); 6.16 kB.
import torch
import torch.nn.functional as F
import numpy as np
import kornia.geometry.transform as K
from kornia.augmentation import ColorJiggle, AugmentationSequential
from kornia.augmentation import RandomChannelShuffle
class FaceAugmentor:
    """Random augmentations for batches of face images.

    Calling an instance optionally resizes the batch, applies patch-wise
    color jitter, applies a random zoom (in, out, or none), and finally masks
    the image border plus, optionally, a few random square patches.

    Inputs are indexed as 4-D tensors ``(batch, channels, height, width)``.
    Several steps use the last dimension (width) as "the" image size, so the
    images are presumably square -- TODO confirm against callers.
    """

    # Fraction of the image width that `_mask` blanks on every side; also the
    # lower/upper bound used when drawing random patch offsets.
    _BORDER_FRACTION = 0.25
    # Side length and number of the random square patches drawn by `_mask`.
    _MASK_PATCH_SIZE = 64
    _MASK_PATCH_COUNT = 5
    # Side length and number of the re-jittered patches in `_random_color_patch`.
    _COLOR_PATCH_SIZE = 128
    _COLOR_PATCH_COUNT = 20

    def __init__(self):
        """Build the color pipeline: an always-applied ColorJiggle + channel shuffle."""
        self.color_jitter = AugmentationSequential(
            ColorJiggle(0.25, 0.3, 0.3, 0.3, p=1.),
            RandomChannelShuffle(),
        )

    def _random_normal(self, size=(1,), trunc_val=2.5, rnd_state=None, device='cpu'):
        """Sample a truncated standard normal, rescaled to ``[-1, 1]``.

        Values are drawn one at a time from ``rnd_state.normal()``; draws
        outside ``[-trunc_val, trunc_val]`` are rejected and redrawn, and the
        accepted value is divided by ``trunc_val``.

        Args:
            size: Shape of the returned tensor.
            trunc_val: Truncation bound (in standard deviations).
            rnd_state: Object exposing ``normal()``; defaults to ``np.random``.
            device: Device the resulting tensor is moved to.

        Returns:
            A ``float32`` tensor of shape ``size`` on ``device``.
        """
        if rnd_state is None:
            rnd_state = np.random

        count = int(np.prod(size))
        samples = np.empty((count,), dtype=np.float32)
        # Sequential rejection sampling: kept element-by-element so a seeded
        # `rnd_state` yields the same stream of values as before.
        for idx in range(count):
            value = rnd_state.normal()
            while not (-trunc_val <= value <= trunc_val):
                value = rnd_state.normal()
            samples[idx] = value / trunc_val

        return torch.from_numpy(samples.reshape(size)).to(device)

    def _get_warp_params(self, batch_size, img_size, device):
        """Generate random per-sample warp maps.

        For every sample a cell size is picked from ``img_size // 2``,
        ``img_size // 4`` and ``img_size // 8``. A coarse grid with
        ``img_size // cell_size + 1`` points per side is built, its interior
        points are perturbed by truncated-normal noise scaled by
        ``0.24 * cell_size``, and the grid is bilinearly upsampled to
        ``img_size + cell_size`` and cropped by half a cell on each side.

        Args:
            batch_size: Number of map pairs to generate.
            img_size: Image side length in pixels.
            device: Device on which the maps are created.

        Returns:
            ``(mapx, mapy)``, each the concatenation of ``batch_size`` maps
            along dim 0 (``(batch_size, img_size, img_size)`` when the cell
            size is even).
        """
        img_size = int(img_size)
        cell_size_choices = [img_size // (2 ** i) for i in range(1, 4)]
        batch_cell_size = np.random.choice(cell_size_choices, batch_size)

        batch_mapx_resized = []
        batch_mapy_resized = []
        for cell_size in batch_cell_size:
            # Plain Python ints keep numpy scalar types out of the torch calls.
            cell_size = int(cell_size)
            cell_count = img_size // cell_size + 1
            half_cell_size = cell_size // 2

            grid_points = torch.linspace(0, img_size, cell_count, device=device)
            mapx = torch.broadcast_to(grid_points, (cell_count, cell_count)).clone()
            # `mapy` is a transposed *view* of `mapx` (shared storage), so each
            # in-place perturbation below shows up in both maps.
            # NOTE(review): confirm whether independent noise per axis was
            # intended; the aliasing is kept to preserve the existing
            # augmentation distribution.
            mapy = mapx.t()

            noise_shape = (cell_count - 2, cell_count - 2)
            for grid in (mapx, mapy):
                noise = self._random_normal(size=noise_shape, device=device)
                # Only interior points move; the outer ring stays fixed.
                grid[1:-1, 1:-1] = grid[1:-1, 1:-1] + noise * (cell_size * 0.24)

            upsampled_size = (img_size + cell_size,) * 2
            crop = slice(half_cell_size, -half_cell_size)
            batch_mapx_resized.append(
                F.interpolate(mapx[None, None], upsampled_size, mode='bilinear')[:, 0, crop, crop]
            )
            batch_mapy_resized.append(
                F.interpolate(mapy[None, None], upsampled_size, mode='bilinear')[:, 0, crop, crop]
            )

        return torch.cat(batch_mapx_resized), torch.cat(batch_mapy_resized)

    def _mask(self, faces, apply_rnd_mask):
        """Blank the image border and, optionally, random square patches.

        Notice that this masking function is designed specifically for EG3D
        canonical space (yaw and pitch equal to 0). If you change the
        coordinate system, change this too.

        A band of ``_BORDER_FRACTION * width`` pixels on every side is
        multiplied by 0. When ``apply_rnd_mask`` is true,
        ``_MASK_PATCH_COUNT`` patches of ``_MASK_PATCH_SIZE`` pixels per side
        are additionally filled with ``-1`` for every sample (presumably
        "black" for images normalized to ``[-1, 1]`` -- TODO confirm).
        Patches are clipped at the image edge and may overwrite the border.

        Args:
            faces: Tensor ``(B, C, H, W)``; **modified in place**.
            apply_rnd_mask: Whether to add the random ``-1`` patches.

        Returns:
            The same ``faces`` tensor, after masking.
        """
        batch_size = faces.shape[0]
        width = faces.shape[-1]
        border = int(self._BORDER_FRACTION * width)

        # One keep-mask broadcast over batch and channels: works for any
        # channel count instead of assuming RGB.
        keep = faces.new_zeros((1, 1) + tuple(faces.shape[-2:]))
        keep[:, :, border: width - border, border: width - border] = 1
        faces.mul_(keep)

        if apply_rnd_mask:
            patch = self._MASK_PATCH_SIZE
            col_upper = int(width * (1 - self._BORDER_FRACTION))
            for i in range(batch_size):
                for _ in range(self._MASK_PATCH_COUNT):
                    # 64x64 patch masking; `row` indexes dim -2, `col` dim -1.
                    row = np.random.randint(border, width // 2)
                    col = np.random.randint(border, col_upper)
                    # Scalar fill: slices clipped at the image edge no longer
                    # raise a shape mismatch on images narrower than 256.
                    faces[i, :, row: row + patch, col: col + patch] = -1
        return faces

    def _random_zoom_in(self, faces):
        """Center-crop to a random 70-100% of the size, then resize back.

        Height and width crop factors are drawn independently, so the aspect
        ratio may change slightly.

        Args:
            faces: Tensor ``(B, C, H, W)``.

        Returns:
            Tensor resized to ``(size, size)`` where ``size = faces.shape[-1]``.
        """
        size = faces.shape[-1]
        zoom_size_h = int(size * (0.7 + np.random.rand() * 0.3))
        zoom_size_w = int(size * (0.7 + np.random.rand() * 0.3))
        faces = K.center_crop(faces, (zoom_size_h, zoom_size_w))
        return K.resize(faces, (size, size))

    def _random_zoom_out(self, faces):
        """Zero-pad the images by a random margin, then resize back.

        The vertical and horizontal margins are drawn independently from
        ``[0, 20% of size)`` and applied symmetrically so the face stays
        centered.

        Args:
            faces: Tensor ``(B, C, H, W)``.

        Returns:
            Tensor resized to ``(size, size)`` where ``size = faces.shape[-1]``.
        """
        size = faces.shape[-1]
        # max(1, ...) keeps randint valid for tiny images (the pad is then 0).
        max_pad = max(1, int(size * 0.2))
        pad_h = np.random.randint(max_pad)
        pad_w = np.random.randint(max_pad)
        # F.pad order for the last two dims is (left, right, top, bottom).
        # A 2-tuple would pad the width only, squeezing the image horizontally
        # and shifting the face off-center instead of zooming out.
        faces = F.pad(faces, (pad_w, pad_w, pad_h, pad_h), mode='constant')
        return K.resize(faces, (size, size))

    def _random_color_patch(self, faces):
        """Color-jitter the whole batch, then re-jitter random patches.

        The full batch is passed through ``self.color_jitter`` once. Then, per
        sample, ``_COLOR_PATCH_COUNT`` patches of ``_COLOR_PATCH_SIZE`` pixels
        per side are cut from the *original* ``faces``, jittered on their own,
        and pasted into the result (later patches overwrite earlier ones).

        Args:
            faces: Tensor ``(B, C, H, W)``; not modified.

        Returns:
            The augmented tensor produced by ``self.color_jitter``.
        """
        batch_size = faces.shape[0]
        width = faces.shape[-1]
        low = int(width * self._BORDER_FRACTION)
        high = int(width * (1 - self._BORDER_FRACTION))
        patch = self._COLOR_PATCH_SIZE

        aug_faces = self.color_jitter(faces)
        for i in range(batch_size):
            for _ in range(self._COLOR_PATCH_COUNT):
                row = np.random.randint(low, high)
                col = np.random.randint(low, high)
                subfaces = faces[i, :, row: row + patch, col: col + patch]
                aug_faces[i, :, row: row + patch, col: col + patch] = self.color_jitter(subfaces)
        return aug_faces

    @torch.no_grad()
    def __call__(self, faces, target_size, apply_color_aug=True, apply_rnd_mask=True, apply_rnd_zoom=True):
        """Run the augmentation pipeline on a batch of faces.

        Args:
            faces: Tensor ``(B, C, H, W)``.
            target_size: Size passed to ``F.interpolate`` before augmenting,
                or ``None`` to keep the input resolution.
            apply_color_aug: Apply patch-wise color jitter.
            apply_rnd_mask: Add random ``-1`` patches on top of the border mask.
            apply_rnd_zoom: Pick uniformly between zoom-in, zoom-out and no zoom.

        Returns:
            The augmented batch. When no earlier step creates a new tensor
            (``target_size`` is ``None``, no color augmentation, no zoom) the
            input tensor itself is masked in place and returned.
        """
        if target_size is not None:
            faces = F.interpolate(faces, size=target_size)

        if apply_color_aug:
            faces = self._random_color_patch(faces)

        # 0 -> zoom in, 1 -> zoom out, 2 -> leave the scale untouched.
        zoom_type = np.random.randint(3) if apply_rnd_zoom else 2
        if zoom_type == 0:
            faces = self._random_zoom_in(faces)
        elif zoom_type == 1:
            faces = self._random_zoom_out(faces)

        return self._mask(faces, apply_rnd_mask)