import torch
import numpy as np
from argparse import Namespace

from additional_modules.deep3dfacerecon.util.load_mats import load_lm3d
from additional_modules.deep3dfacerecon.util.preprocess import align_img
from additional_modules.deep3dfacerecon.models.facerecon_model import FaceReconModel
from data_preprocessing.crop_image import CENTER_CROP_SIZE


class PoseEstimator:
    def __init__(self, device):
        self.opt = Namespace(**{
            'net_recon': 'resnet50',
            'phase': 'test',
            'init_path': None,
            'use_last_fc': False,
            'bfm_folder': 'additional_modules/deep3dfacerecon/BFM',
            'bfm_model': 'BFM_model_front.mat',
            'epoch': 20,
            'checkpoint_path': 'pretrained_models/deep3dfacerecon_epoch20.pth',
            'name': 'face_recon',
            'device': device,
            'camera_d': 10,
            'focal': 1015,
            'center': 112,
        })
        self.device = device
        self.lm3d_std = load_lm3d(self.opt.bfm_folder)

        self.pose_predictor = FaceReconModel(self.opt)
        self.pose_predictor.load_networks(self.opt.checkpoint_path)
        self.pose_predictor.eval()

    def _pose_est_process_data(self, im, lm):
        # Flip landmark y-coordinates from image convention (origin top-left)
        # to the convention expected by align_img (origin bottom-left).
        W, H = im.size
        lm = lm.copy().reshape([-1, 2])
        lm[:, -1] = H - 1 - lm[:, -1]
        _, im, lm, _, _, lm5p = align_img(im, lm, lm, self.lm3d_std)
        im = torch.tensor(np.array(im) / 255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)
        lm = torch.tensor(lm).to(self.device).unsqueeze(0)
        return im, lm

    @staticmethod
    def compute_rotation(angles):
        """
        Return:
            rot -- torch.tensor, size (B, 3, 3), meant to be applied as pts @ rot

        Parameters:
            angles -- torch.tensor, size (B, 3), Euler angles in radians
        """
        batch_size = angles.shape[0]
        # Allocate constants on the same device as the input angles.
        ones = torch.ones([batch_size, 1], device=angles.device)
        zeros = torch.zeros([batch_size, 1], device=angles.device)
        x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:]

        rot_x = torch.cat([
            ones, zeros, zeros,
            zeros, torch.cos(x), -torch.sin(x),
            zeros, torch.sin(x), torch.cos(x)
        ], dim=1).reshape([batch_size, 3, 3])

        rot_y = torch.cat([
            torch.cos(y), zeros, torch.sin(y),
            zeros, ones, zeros,
            -torch.sin(y), zeros, torch.cos(y)
        ], dim=1).reshape([batch_size, 3, 3])

        rot_z = torch.cat([
            torch.cos(z), -torch.sin(z), zeros,
            torch.sin(z), torch.cos(z), zeros,
            zeros, zeros, ones
        ], dim=1).reshape([batch_size, 3, 3])

        rot = rot_z @ rot_y @ rot_x
        return rot.permute(0, 2, 1)

    @torch.no_grad()
    def predict_pose(self, image, landmarks):
        img_proc, lm_proc = self._pose_est_process_data(image, landmarks)
        data = {
            'imgs': img_proc,
            'lms': lm_proc,
        }
        self.pose_predictor.set_input(data)
        self.pose_predictor.test()
        pred_coeffs_dict_tensor = self.pose_predictor.pred_coeffs_dict
        pred_coeffs_dict = {k: v.cpu().numpy() for k, v in pred_coeffs_dict_tensor.items()}

        angle = pred_coeffs_dict['angle']
        R = self.compute_rotation(torch.from_numpy(angle))[0].numpy()

        # Extrinsics: recover the camera center from the predicted head
        # rotation and translation.
        trans = pred_coeffs_dict['trans'][0]
        trans[2] -= 10  # undo the assumed camera distance (camera_d = 10)
        c = -np.dot(R, trans)
        c *= 0.27  # normalize camera radius
        c[1] += 0.006  # additional offset used in submission
        c[2] += 0.161  # additional offset used in submission
        radius = np.linalg.norm(c)
        c = c / radius * 2.7  # place the camera on a sphere of radius 2.7

        # Flip the y and z axes to match the target camera convention.
        Rot = np.diag([1., -1., -1.])
        R = np.dot(R, Rot)

        # Assemble the 4x4 camera-to-world pose.
        pose = np.eye(4)
        pose[:3, :3] = R
        pose[:3, 3] = c

        # Intrinsics, normalized by the crop resolution, with the principal
        # point at the image center.
        focal = 2985.29 / CENTER_CROP_SIZE
        cx, cy = 0.5, 0.5
        K = np.eye(3)
        K[0, 0] = focal
        K[1, 1] = focal
        K[0, 2] = cx
        K[1, 2] = cy

        return K, pose
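
# A minimal sketch, not part of the original module: 3D-aware generators in
# the EG3D family typically consume the camera as one flattened 25-dim label
# (16 pose values followed by 9 intrinsics). `camera_to_label` is a
# hypothetical helper illustrating that packing for the K / pose pair
# returned by PoseEstimator.predict_pose above.
def camera_to_label(K, pose):
    """Flatten a 3x3 intrinsics matrix and a 4x4 cam2world pose into a (25,) vector."""
    return np.concatenate([pose.reshape(-1), K.reshape(-1)]).astype(np.float32)
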
if __name__ == '__main__':
    from PIL import Image
    from face_detector import FaceDetector

    device = 'cuda'
    face_detector = FaceDetector(device)
    pose_estimator = PoseEstimator(device)

    img = Image.open('00000.png').convert('RGB')

    # Ground-truth landmarks could instead be read from the FFHQ metadata:
    # import json
    # with open('ffhq-dataset-v2.json', 'r') as f:
    #     gt_lm = np.array(json.load(f)['0']['in_the_wild']['face_landmarks'])

    lm, _ = face_detector(img)
    K, pose = pose_estimator.predict_pose(img, lm)
    print('K:\n', K)
    print('pose:\n', pose)
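
    # Hedged sketch: processing multiple frames via the single-image API.
    # The original module exposes no batched entry point, and 'frames/*.png'
    # is a hypothetical path used only for illustration.
    # import glob
    # for path in sorted(glob.glob('frames/*.png')):
    #     frame = Image.open(path).convert('RGB')
    #     frame_lm, _ = face_detector(frame)
    #     K, pose = pose_estimator.predict_pose(frame, frame_lm)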