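# NOTE(review): the three lines below are not Python; presumably residue from
# the hosting page. Commented out so the module parses.
# minchul's picture
# Upload directory
# ce4d684 verified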
from ..base import BaseAligner
from torchvision import transforms
from .dfa import get_landmark_predictor, get_preprocessor
from . import aligner_helper
import torch
import torch.nn.functional as F
import numpy as np
class DifferentiableFaceAligner(BaseAligner):
    """Align a single-face image to a canonical position, differentiably.

    The landmark network predicts five facial landmarks, a bounding box and a
    class score; the landmarks are turned into an affine transform that is
    applied with ``F.affine_grid`` / ``F.grid_sample``.

    The aligner is based on the following paper (check out the supplementary
    material for more details):

        @inproceedings{kim2024kprpe,
          title={KeyPoint Relative Position Encoding for Face Recognition},
          author={Kim, Minchul and Su, Yiyang and Liu, Feng and Liu, Xiaoming},
          booktitle={CVPR},
          year={2024}
        }
    """

    def __init__(self, net, prior_box, preprocessor, config):
        """Store the landmark network and its companions.

        Args:
            net: landmark predictor, called as ``net(x_bgr, prior_box)``.
            prior_box: object passed to ``net``; its ``image_size`` must match
                the spatial size of the preprocessed input.
            preprocessor: callable applied to the input batch; exposes a
                ``padding`` attribute (default padding ratio).
            config: configuration object; ``input_size`` and ``output_size``
                are read in ``forward``.
        """
        super(DifferentiableFaceAligner, self).__init__()
        self.net = net
        self.prior_box = prior_box
        self.preprocessor = preprocessor
        self.config = config

    @classmethod
    def from_config(cls, config):
        """Build an aligner (in eval mode) from a configuration object.

        Reads ``arch``, ``input_size``, ``input_padding_ratio``,
        ``input_padding_val`` and ``freeze`` from ``config``. When
        ``config.freeze`` is truthy, gradients are disabled for every
        parameter of the landmark network.
        """
        net, prior_box = get_landmark_predictor(
            network=config.arch,
            use_aggregator=True,
            input_size=config.input_size,
        )
        preprocessor = get_preprocessor(
            output_size=config.input_size,
            padding=config.input_padding_ratio,
            padding_val=config.input_padding_val,
        )

        if config.freeze:
            for param in net.parameters():
                param.requires_grad = False

        model = cls(net, prior_box, preprocessor, config)
        model.eval()
        return model

    def forward(self, x, padding_ratio_override=None):
        """Predict landmarks for ``x`` and warp it to the canonical pose.

        Args:
            x: 4-D ``torch.Tensor`` with 3 channels in dimension 1. The
                channel order is flipped before the network call (see the
                "BGR" note below), so RGB input is presumably expected.
            padding_ratio_override: optional padding ratio that replaces
                ``self.preprocessor.padding`` for this call.

        Returns:
            A 6-tuple ``(aligned_x, ldmks, aligned_ldmks, score, thetas,
            normalized_bbox)``:

            * ``aligned_x`` - the warped image batch.
            * ``ldmks`` - predicted landmarks reshaped to ``(-1, 5, 2)``; when
              the padding ratio is positive they are mapped back to the
              unpadded image frame and detached. ``None`` for non-square
              input.
            * ``aligned_ldmks`` - landmarks after applying ``thetas``.
            * ``score`` - softmax of the class output with column 0 dropped.
            * ``thetas`` - affine parameters fed to ``F.affine_grid``.
            * ``normalized_bbox`` - bbox (xmin, ymin, xmax, ymax) divided by
              the preprocessed image width/height. ``None`` for non-square
              input.
        """
        # Input checks. The type check comes first so that a non-tensor input
        # fails on this assertion rather than on an attribute lookup.
        assert isinstance(x, torch.Tensor)
        assert x.ndim == 4
        assert x.shape[1] == 3
        is_square = x.shape[2] == x.shape[3]

        x = self.preprocessor(x, padding_ratio_override=padding_ratio_override)
        assert self.prior_box.image_size == x.shape[2:]

        # make image into BGR
        x_bgr = x.flip(1)
        result = self.net(x_bgr, self.prior_box)
        orig_pred_ldmks, bbox, cls_logits = aligner_helper.split_network_output(result)
        score = F.softmax(cls_logits, dim=-1)[:, 1:]

        reference_ldmk = aligner_helper.reference_landmark()
        input_size = self.config.input_size
        output_size = self.config.output_size
        cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(
            orig_pred_ldmks, reference_ldmk, input_size, input_size)
        thetas = aligner_helper.cv2_param_to_torch_theta(
            cv2_tfms, input_size, input_size, output_size, output_size)
        thetas = thetas.to(orig_pred_ldmks.device)

        grid_shape = torch.Size((len(thetas), 3, output_size, output_size))
        grid = F.affine_grid(thetas, grid_shape, align_corners=True)
        # grid_sample fills out-of-range samples with 0. Shifting by +1 before
        # and -1 after makes those samples come out as -1, i.e. pixel value 0
        # under the mean=0.5 / std=0.5 normalization used by this class.
        aligned_x = F.grid_sample(x + 1, grid, align_corners=True) - 1

        orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2)
        aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks, thetas)

        # bbox (xmin, ymin, xmax, ymax)
        width, height = x_bgr.size(3), x_bgr.size(2)
        normalized_bbox = bbox / torch.tensor([[width, height] * 2], device=bbox.device)

        if padding_ratio_override is None:
            padding_ratio = self.preprocessor.padding
        else:
            padding_ratio = padding_ratio_override

        if padding_ratio > 0:
            # unpad the landmark so that it is in the original image coordinate
            ldmks_out = self._unpad_landmarks(orig_pred_ldmks, padding_ratio)
        else:
            ldmks_out = orig_pred_ldmks

        if not is_square:
            # cannot use these if the input is not square because the
            # preprocessor changes the input
            ldmks_out = None
            normalized_bbox = None

        return aligned_x, ldmks_out, aligned_ldmks, score, thetas, normalized_bbox

    @staticmethod
    def _unpad_landmarks(ldmks, padding_ratio):
        """Map ``(B, 5, 2)`` landmarks from the padded frame to the unpadded one.

        Coordinates are moved to the ``[-1, 1]`` range, scaled about the
        centre by ``1 + 2 * padding_ratio`` and moved back. The result is
        detached from the autograd graph.
        """
        device = ldmks.device
        batch = ldmks.size(0)

        scale = 1 / (1 + (2 * padding_ratio))
        inv_scale = float(1 / scale)
        pad_inv_theta = torch.tensor(
            [[inv_scale, 0, 0], [0, inv_scale, 0]],
            dtype=torch.float32,
            device=device,
        )
        pad_inv_theta = pad_inv_theta.unsqueeze(0).repeat(batch, 1, 1)

        # Homogeneous coordinates: the appended column of ones stays 1 after
        # the "* 2 - 1" step, and multiplies the (zero) translation column.
        homogeneous = torch.cat(
            [ldmks, torch.ones((batch, 5, 1), device=device)], dim=-1)
        unpadded = ((homogeneous * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5
        return unpadded.detach().view(-1, 5, 2)

    @staticmethod
    def _make_transform():
        """Return the ToTensor + Normalize(mean=0.5, std=0.5) pipeline."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def make_train_transform(self):
        """Return the input transform used for training."""
        return self._make_transform()

    def make_test_transform(self):
        """Return the input transform used for testing (same as training)."""
        return self._make_transform()