import torch
import torchvision.transforms as transforms
import cv2
import face_alignment
import numpy as np
from PIL import Image

from data_preprocessing.lm_based_image_align import LandmarkBasedImageAlign
from data_preprocessing.pose_estimation import PoseEstimator
from data_preprocessing.crop_image import ImageCropper
from data_preprocessing.foreground_extractor import ForegroundExtractor


class DataPreprocessor:
    """Detect landmarks, align, crop and (optionally) matte a face image.

    The instance keeps two exponential moving averages of landmarks
    (``lm_woAlign_ema`` and ``aligned_lm_ema``) that persist across calls.
    NOTE(review): because of that state, successive calls influence each
    other; this looks intended for consecutive video frames -- confirm
    before reusing one instance on unrelated images.
    """

    def __init__(self, device, crop_smooth_alpha=0.9):
        """Build the processing stages.

        Args:
            device: Device handed to the pose estimator, the foreground
                extractor and the ``face_alignment`` landmark detector.
            crop_smooth_alpha: Weight given to the previous estimate when
                smoothing the raw (un-aligned) landmarks used for cropping;
                the newest detection gets ``1 - crop_smooth_alpha``.
        """
        self.device = device
        self.crop_smooth_alpha = crop_smooth_alpha

        self.face_alignment = LandmarkBasedImageAlign(output_size=1024, transform_size=1024)
        self.pose_estimator = PoseEstimator(device)
        self.cropper = ImageCropper()
        self.foreground_extractor = ForegroundExtractor(device)
        self.fa = face_alignment.FaceAlignment(
            face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
        )
        # Not used by the methods of this class; kept because other code may
        # read the attribute.
        self.transform = transforms.ToTensor()

        # Smoothed landmark state, initialised lazily on the first call.
        self.lm_woAlign_ema = None
        self.aligned_lm_ema = None

    def __call__(self, img, keep_bg=False):
        """Process one image.

        Args:
            img: Input image. It must be convertible with ``np.array`` and
                expose ``.size`` as ``(width, height)``; ``from_path`` passes
                a ``PIL.Image``.
            keep_bg: When false, the aligned crop is multiplied by the matte
                returned by the foreground extractor.

        Returns:
            A tuple ``(img, img_woAlign, intrinsics, pose, crop_params)``:
            the aligned and cropped image as an array, the crop of the
            un-aligned input as an array, the two outputs of
            ``PoseEstimator.predict_pose``, and the matrix estimated by
            ``cv2.estimateAffinePartial2D`` from the smoothed aligned
            landmarks to the un-aligned crop landmarks.
        """
        # Convert once so that the detector receives the same ndarray on both
        # the normal path and the fallback path (the fallback used to be
        # handed the raw, unconverted image object).
        img_np = np.array(img)

        detections = self.fa.get_landmarks(img_np)
        if detections is None:
            # Nothing was detected: retry, treating the whole frame as the
            # face bounding box.
            whole_frame = [0, 0, img.size[0], img.size[1]]
            lm = self.fa.get_landmarks(img_np, detected_faces=[whole_frame])[0]
        else:
            lm = detections[0]

        # Smooth the raw landmarks with the configured weight. On the first
        # call the average is seeded with the detection itself, so the update
        # below leaves it (numerically) unchanged.
        current_lm = np.array(lm)
        if self.lm_woAlign_ema is None:
            self.lm_woAlign_ema = current_lm
        keep = self.crop_smooth_alpha
        self.lm_woAlign_ema = self.lm_woAlign_ema * keep + current_lm * (1.0 - keep)

        # Crop of the un-aligned input, driven by the smoothed landmarks.
        img_woAlign, lm_woAlign = self.cropper(img, self.lm_woAlign_ema)
        lm_woAlign = np.array(lm_woAlign)
        img_woAlign = np.array(img_woAlign)

        # Alignment uses the unsmoothed detection; the pose is predicted on
        # the aligned image before it is cropped.
        img, aligned_lm = self.face_alignment(img, lm)
        intrinsics, pose = self.pose_estimator.predict_pose(img, aligned_lm)
        img, aligned_lm = self.cropper(img, aligned_lm)
        img = np.array(img)

        if not keep_bg:
            matte = self.foreground_extractor(img)
            img = (img * matte).astype(np.uint8)

        # Separate moving average (fixed 0.8 / 0.2 weights) for the aligned
        # landmarks, seeded with the first observation.
        if self.aligned_lm_ema is None:
            self.aligned_lm_ema = aligned_lm
        self.aligned_lm_ema = 0.8 * self.aligned_lm_ema + aligned_lm * 0.2

        # estimateAffinePartial2D returns (matrix, inliers); keep the matrix.
        crop_params = cv2.estimateAffinePartial2D(self.aligned_lm_ema, lm_woAlign)[0]

        return img, img_woAlign, intrinsics, pose, crop_params

    def from_path(self, image_path, device, keep_bg=False):
        """Load an image file, preprocess it and pack the result as tensors.

        Args:
            image_path: Path of the image, read with ``cv2.imread``.
            device: Device the returned tensors are moved to.
            keep_bg: Forwarded to ``__call__``.

        Returns:
            A dict with keys ``'image'`` (float tensor with a leading batch
            axis and channels first, values mapped by ``x / 255 * 2 - 1``),
            ``'cam2world'`` and ``'intrinsics'`` (float tensors with a
            leading batch axis) and ``'crop_params'`` (returned as produced
            by ``__call__``, not moved to ``device``).

        Raises:
            OSError: If ``cv2.imread`` could not load ``image_path``.
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image (missing or unreadable): {image_path}")

        # OpenCV loads BGR; the rest of the pipeline is fed RGB.
        pil_img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        img, _, intrinsics, pose, crop_params = self(pil_img, keep_bg=keep_bg)

        # HWC -> 1xCxHxW, then rescale.
        img = np.transpose(img, (2, 0, 1))[None, :, :, :] / 255.
        img = (img * 2 - 1)
        img = torch.from_numpy(img).float()

        pose = torch.from_numpy(pose).unsqueeze(0).float()
        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float()

        return {
            'image': img.to(device),
            'cam2world': pose.to(device),
            'intrinsics': intrinsics.to(device),
            'crop_params': crop_params
        }

    @staticmethod
    def realign(img, T):
        """Warp ``img`` with the affine matrix ``T``, keeping the input size.

        Args:
            img: Image array; its first two axes give the output height and
                width.
            T: Affine matrix passed to ``cv2.warpAffine``.

        Returns:
            The warped image.
        """
        out = cv2.warpAffine(img, T, (img.shape[1], img.shape[0]))
        return out