ameerazam08's picture
Upload folder using huggingface_hub
03da825 verified
raw
history blame
3.19 kB
import torch
import torchvision.transforms as transforms
import cv2
import face_alignment
import numpy as np
from PIL import Image
from data_preprocessing.lm_based_image_align import LandmarkBasedImageAlign
from data_preprocessing.pose_estimation import PoseEstimator
from data_preprocessing.crop_image import ImageCropper
from data_preprocessing.foreground_extractor import ForegroundExtractor
class DataPreprocessor:
    """Face-image preprocessing: landmark detection, alignment, pose estimation,
    cropping and optional background removal.

    The instance keeps exponential moving averages (EMAs) of landmarks across
    calls, so consecutive calls (e.g. video frames) get temporally smoothed
    crops. Call :meth:`reset_smoothing` before processing an unrelated image or
    a new sequence, otherwise landmarks from earlier calls bleed into the result.
    """

    def __init__(self, device, crop_smooth_alpha=0.9, align_smooth_alpha=0.8):
        """Build all sub-models.

        Args:
            device: Device handed to the pose estimator, the foreground
                extractor and the face_alignment detector.
            crop_smooth_alpha: EMA weight on the *previous* value for the
                un-aligned landmarks that drive the plain crop (higher means
                smoother but slower to react).
            align_smooth_alpha: EMA weight on the *previous* value for the
                aligned landmarks used to estimate ``crop_params``.
        """
        self.device = device
        self.crop_smooth_alpha = crop_smooth_alpha
        self.align_smooth_alpha = align_smooth_alpha
        self.face_alignment = LandmarkBasedImageAlign(output_size=1024, transform_size=1024)
        self.pose_estimator = PoseEstimator(device)
        self.cropper = ImageCropper()
        self.foreground_extractor = ForegroundExtractor(device)
        # `face_alignment` below is the third-party module, not the attribute above.
        self.fa = face_alignment.FaceAlignment(
            face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
        )
        self.transform = transforms.ToTensor()
        # Running landmark averages; None until the first call initialises them.
        self.lm_woAlign_ema = None
        self.aligned_lm_ema = None

    def reset_smoothing(self):
        """Forget the landmark EMAs so the next call starts a fresh sequence."""
        self.lm_woAlign_ema = None
        self.aligned_lm_ema = None

    @staticmethod
    def _ema_update(previous, current, alpha):
        """Return the updated moving average of ``current``.

        On the first update (``previous is None``) this is a copy of
        ``current`` as an ndarray; afterwards it is
        ``alpha * previous + (1 - alpha) * current``.
        """
        current = np.array(current)  # copies, and accepts lists as well as arrays
        if previous is None:
            return current
        return alpha * previous + (1.0 - alpha) * current

    def _detect_landmarks(self, img):
        """Return the 2D landmarks of the first face face_alignment finds in ``img``.

        If no face is detected, the whole image is used as the face box so
        landmarks are still produced. ``img`` is expected to be a PIL image
        (``img.size`` is read as ``(width, height)``).
        """
        img_np = np.array(img)
        faces = self.fa.get_landmarks(img_np)
        if faces is not None and len(faces) > 0:
            return faces[0]
        width, height = img.size
        whole_image_box = [0, 0, width, height]
        # face_alignment takes an ndarray/path/tensor, not a PIL image, so the
        # fallback must also receive the converted array.
        return self.fa.get_landmarks(img_np, detected_faces=[whole_image_box])[0]

    def __call__(self, img, keep_bg=False):
        """Preprocess one image.

        Args:
            img: Input PIL image.
            keep_bg: If False, the aligned crop is multiplied by the matte from
                the foreground extractor (background removal).

        Returns:
            Tuple ``(img, img_woAlign, intrinsics, pose, crop_params)``:
            the aligned + cropped image as an ndarray, the un-aligned crop as an
            ndarray, the camera intrinsics and pose from the pose estimator, and
            the 2x3 partial-affine matrix from ``cv2.estimateAffinePartial2D``
            mapping the smoothed aligned landmarks onto the un-aligned crop's
            landmarks (OpenCV returns None if it cannot estimate one).
        """
        lm = self._detect_landmarks(img)

        # Plain (non-aligned) crop, driven by temporally smoothed landmarks.
        self.lm_woAlign_ema = self._ema_update(
            self.lm_woAlign_ema, lm, self.crop_smooth_alpha
        )
        img_woAlign, lm_woAlign = self.cropper(img, self.lm_woAlign_ema)
        lm_woAlign = np.array(lm_woAlign)
        img_woAlign = np.array(img_woAlign)

        # Aligned branch: align on the raw landmarks, estimate pose, then crop.
        img, aligned_lm = self.face_alignment(img, lm)
        intrinsics, pose = self.pose_estimator.predict_pose(img, aligned_lm)
        img, aligned_lm = self.cropper(img, aligned_lm)
        img = np.array(img)
        if not keep_bg:
            # NOTE(review): matte presumably broadcasts against img (HxWx1 in
            # [0, 1]) — confirm against ForegroundExtractor.
            matte = self.foreground_extractor(img)
            img = (img * matte).astype(np.uint8)

        self.aligned_lm_ema = self._ema_update(
            self.aligned_lm_ema, aligned_lm, self.align_smooth_alpha
        )
        crop_params = cv2.estimateAffinePartial2D(self.aligned_lm_ema, lm_woAlign)[0]
        return img, img_woAlign, intrinsics, pose, crop_params

    def from_path(self, image_path, device=None, keep_bg=False):
        """Load an image file, preprocess it and package model-ready tensors.

        Args:
            image_path: Path readable by ``cv2.imread``.
            device: Target device for the returned tensors; defaults to the
                device given at construction.
            keep_bg: Forwarded to :meth:`__call__`.

        Returns:
            Dict with ``'image'`` (1xCxHxW float tensor scaled to [-1, 1]),
            ``'cam2world'`` and ``'intrinsics'`` (each with a leading batch
            dim of 1), and ``'crop_params'`` (left as returned by OpenCV).

        Raises:
            FileNotFoundError: If OpenCV cannot read ``image_path``.
        """
        if device is None:
            device = self.device
        bgr = cv2.imread(image_path)
        if bgr is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {image_path}")
        img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

        img, _, intrinsics, pose, crop_params = self(img, keep_bg=keep_bg)

        # HxWxC in [0, 255] -> 1xCxHxW in [-1, 1].
        img = np.transpose(img, (2, 0, 1))[None, :, :, :] / 255.0
        img = img * 2 - 1
        img = torch.from_numpy(img).float()
        pose = torch.from_numpy(pose).unsqueeze(0).float()
        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float()
        return {
            'image': img.to(device),
            'cam2world': pose.to(device),
            'intrinsics': intrinsics.to(device),
            'crop_params': crop_params
        }

    @staticmethod
    def realign(img, T):
        """Warp ``img`` with the 2x3 affine matrix ``T``, keeping its width and height."""
        height, width = img.shape[:2]
        return cv2.warpAffine(img, T, (width, height))