Spaces:
Runtime error
Runtime error
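'''
# author: Zhiyuan Yan
# email: [email protected]
# date: 2023-03-30
Blending dataset for Face X-ray training.
Each sample pairs a background frame with a synthetic forgery obtained by
warping and blending a nearest-neighbour face onto it, together with the
blending-boundary map used as the supervision mask.
'''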
import os | |
import sys | |
import json | |
import pickle | |
import time | |
import lmdb | |
import numpy as np | |
import albumentations as A | |
import cv2 | |
import random | |
from PIL import Image | |
from skimage.util import random_noise | |
from scipy import linalg | |
import heapq as hq | |
import lmdb | |
import torch | |
from torch.autograd import Variable | |
from torch.utils import data | |
from torchvision import transforms as T | |
import torchvision | |
from dataset.utils.face_blend import * | |
from dataset.utils.face_align import get_align_mat_new | |
from dataset.utils.color_transfer import color_transfer | |
from dataset.utils.faceswap_utils import blendImages as alpha_blend_fea | |
from dataset.utils.faceswap_utils import AlphaBlend as alpha_blend | |
from dataset.utils.face_aug import aug_one_im, change_res | |
from dataset.utils.image_ae import get_pretraiend_ae | |
from dataset.utils.warp import warp_mask | |
from dataset.utils import faceswap | |
from scipy.ndimage.filters import gaussian_filter | |
class RandomDownScale(A.core.transforms_interface.ImageOnlyTransform):
    """Image-only albumentations transform that degrades image resolution.

    The image is shrunk by a factor drawn from {2, 4} using nearest-neighbour
    sampling and then resized back to its original width and height with
    bilinear interpolation, so the output keeps the input's spatial size.
    """

    def apply(self, img, **params):
        """Albumentations hook; delegates to ``randomdownscale``."""
        return self.randomdownscale(img)

    def randomdownscale(self, img):
        """Return ``img`` after a random down-scale / up-scale round trip.

        Args:
            img: Array whose ``shape`` unpacks into (height, width, channels).
        """
        height, width, _ = img.shape
        factors = [2, 4]
        factor = factors[np.random.randint(len(factors))]
        reduced_size = (int(width / factor), int(height / factor))
        shrunk = cv2.resize(img, reduced_size, interpolation=cv2.INTER_NEAREST)
        # The input shape is always restored after shrinking.
        return cv2.resize(shrunk, (width, height), interpolation=cv2.INTER_LINEAR)
class FFBlendDataset(data.Dataset): | |
def __init__(self, config=None):
    """Load the pre-computed lookup tables and build the training image list.

    Args:
        config: Mapping with the dataset settings. ``lmdb`` (optional,
            defaults to False) selects LMDB-backed storage, in which case
            ``lmdb_dir`` is required. Other methods of this class read
            ``resolution`` and ``rgb_dir`` from the same mapping.

    Raises:
        ValueError: If ``config`` is not provided, or if either pickle under
            ``training/lib`` does not exist.
    """
    if config is None:
        # Previously this surfaced as an AttributeError on ``config.get``.
        raise ValueError("FFBlendDataset requires a config dictionary.")
    self.config = config

    self.lmdb = config.get('lmdb', False)
    if self.lmdb:
        lmdb_path = os.path.join(config['lmdb_dir'], "FaceForensics++_lmdb")
        self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False)

    # NOTE: pickle.load can execute arbitrary code; only point these paths at
    # files produced by the project's own preprocessing scripts.
    face_info_path = 'training/lib/nearest_face_info.pkl'
    if not os.path.exists(face_info_path):
        raise ValueError("Need to run the dataset/generate_xray_nearest.py before training the face xray.")
    with open(face_info_path, 'rb') as f:
        self.face_info = pickle.load(f)

    landmark_path = 'training/lib/landmark_dict_ffall.pkl'
    if not os.path.exists(landmark_path):
        # Without this check the code below failed with a NameError.
        raise ValueError(f"Landmark dictionary not found: {landmark_path}")
    with open(landmark_path, 'rb') as f:
        self.landmark_dict = pickle.load(f)

    self.imid_list = self.get_training_imglist()

    # Convert PIL images to tensors and map each channel from [0, 1] to [-1, 1].
    self.transforms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5],
                    std=[0.5, 0.5, 0.5]),
    ])
    self.data_dict = {
        'imid_list': self.imid_list
    }
# def data_aug(self, im): | |
# """ | |
# Apply data augmentation on the input image. | |
# """ | |
# transform = T.Compose([ | |
# T.ToPILImage(), | |
# T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), | |
# T.ColorJitter(hue=.05, saturation=.05), | |
# ]) | |
# # Apply transformations | |
# im_aug = transform(im) | |
# return im_aug | |
def blended_aug(self, im):
    """Run the colour-jitter / JPEG-compression pipeline on ``im``.

    Each of the four albumentations transforms is applied independently with
    its own probability. The pipeline is rebuilt on every call.

    Args:
        im: Image array passed to albumentations as ``image``.

    Returns:
        The augmented image taken from the pipeline's ``'image'`` entry.
    """
    pipeline = A.Compose([
        A.RGBShift((-20, 20), (-20, 20), (-20, 20), p=0.3),
        A.HueSaturationValue(
            hue_shift_limit=(-0.3, 0.3),
            sat_shift_limit=(-0.3, 0.3),
            val_shift_limit=(-0.3, 0.3),
            p=0.3,
        ),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.3, 0.3),
            contrast_limit=(-0.3, 0.3),
            p=0.3,
        ),
        A.ImageCompression(quality_lower=40, quality_upper=100, p=0.5),
    ])
    return pipeline(image=im)['image']
def data_aug(self, im):
    """Augment a source image with albumentations before blending.

    Two stages run in sequence: a colour stage (RGB shift, hue/saturation/
    value jitter, brightness/contrast jitter) followed by exactly one of
    ``RandomDownScale`` or ``Sharpen``.

    Args:
        im: Image array passed to albumentations as ``image``.

    Returns:
        The augmented image taken from the pipeline's ``'image'`` entry.
    """
    colour_stage = A.Compose([
        A.RGBShift((-20, 20), (-20, 20), (-20, 20), p=0.3),
        A.HueSaturationValue(
            hue_shift_limit=(-0.3, 0.3),
            sat_shift_limit=(-0.3, 0.3),
            val_shift_limit=(-0.3, 0.3),
            p=1,
        ),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=1,
        ),
    ], p=1)
    detail_stage = A.OneOf([
        RandomDownScale(p=1),
        A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1),
    ], p=1)
    pipeline = A.Compose([colour_stage, detail_stage], p=1.)
    return pipeline(image=im)['image']
def get_training_imglist(self):
    """Return the keys of ``self.landmark_dict`` in a fixed pseudo-random order.

    The global ``random`` generator is re-seeded with 1024 before shuffling,
    so the order is identical on every call (and the global RNG state is
    reset as a side effect).
    """
    random.seed(1024)
    shuffled_ids = [*self.landmark_dict]
    random.shuffle(shuffled_ids)
    return shuffled_ids
def load_rgb(self, file_path):
    """Load an image, convert it to RGB and resize it to the configured resolution.

    Args:
        file_path: Path of the image. In file mode a path that does not start
            with ``.`` is prefixed with ``./<rgb_dir>\\``. In LMDB mode a
            leading ``./datasets\\`` is stripped to obtain the key.

    Returns:
        A ``uint8`` numpy array of shape (resolution, resolution, 3) in RGB
        channel order.

    Raises:
        AssertionError: In file mode, if the resolved path does not exist.
        ValueError: If the image cannot be read or decoded, or if the key is
            missing from the LMDB database.
    """
    size = self.config['resolution']
    if self.lmdb:
        with self.env.begin(write=False) as txn:
            # Translate the on-disk path format into the LMDB key format.
            if file_path.startswith('.'):
                file_path = file_path.replace('./datasets\\', '')
            image_bin = txn.get(file_path.encode())
            if image_bin is None:
                # A missing key used to fail inside np.frombuffer with an
                # unrelated TypeError.
                raise ValueError('Loaded image is None: {}'.format(file_path))
            image_buf = np.frombuffer(image_bin, dtype=np.uint8)
            img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR)
    else:
        if not file_path.startswith('.'):
            file_path = f'./{self.config["rgb_dir"]}\\' + file_path
        # Kept as an assert so existing callers see the same exception type.
        assert os.path.exists(file_path), f"{file_path} does not exist"
        img = cv2.imread(file_path)
    if img is None:
        raise ValueError('Loaded image is None: {}'.format(file_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
    return np.array(img, dtype=np.uint8)
def load_mask(self, file_path):
    """Load a single-channel mask and resize it to the configured resolution.

    Args:
        file_path: Path of the mask file, or ``None`` when no mask exists.
            In file mode a path that does not start with ``.`` is prefixed
            with ``./<rgb_dir>\\`` (same rule as ``load_rgb``). In LMDB mode a
            leading ``./datasets\\`` is stripped to obtain the key.

    Returns:
        A ``float32`` numpy array of shape (resolution, resolution, 1) whose
        values are the stored pixel values divided by 255. An all-zero array
        of the same shape is returned when ``file_path`` is ``None`` or the
        mask cannot be found.
    """
    size = self.config['resolution']
    empty_mask = np.zeros((size, size, 1), dtype=np.float32)
    if file_path is None:
        # The old code indexed ``file_path[0]`` here and crashed on None.
        return empty_mask

    if not self.lmdb:
        if not file_path.startswith('.'):
            file_path = f'./{self.config["rgb_dir"]}\\' + file_path
        if not os.path.exists(file_path):
            return empty_mask
        mask = cv2.imread(file_path, 0)
    else:
        with self.env.begin(write=False) as txn:
            # Translate the on-disk path format into the LMDB key format.
            if file_path.startswith('.'):
                file_path = file_path.replace('./datasets\\', '')
            image_bin = txn.get(file_path.encode())
            if image_bin is None:
                return empty_mask
            image_buf = np.frombuffer(image_bin, dtype=np.uint8)
            # Decode as grayscale so both storage back-ends give a 2-D mask;
            # IMREAD_COLOR produced a 4-D result after expand_dims below.
            mask = cv2.imdecode(image_buf, cv2.IMREAD_GRAYSCALE)

    if mask is None:
        # Unreadable or corrupt data: fall back to an empty mask.
        mask = np.zeros((size, size))
    mask = cv2.resize(mask, (size, size)) / 255
    mask = np.expand_dims(mask, axis=2)
    return np.float32(mask)
def load_landmark(self, file_path):
    """Load an (81, 2) array of 2D facial landmarks.

    Args:
        file_path: Path of the landmark file, or ``None`` when no landmark
            exists. In file mode a path that does not start with ``.`` is
            prefixed with ``./<rgb_dir>\\`` and loaded with ``np.load``. In
            LMDB mode a leading ``./datasets\\`` is stripped to obtain the key
            and the value is read as ``uint32``.

    Returns:
        A ``float32`` numpy array. It is all zeros with shape (81, 2) when
        ``file_path`` is ``None``, the file does not exist, or the LMDB key
        is missing.
    """
    no_landmark = np.zeros((81, 2), dtype=np.float32)
    if file_path is None:
        return no_landmark

    if not self.lmdb:
        if not file_path.startswith('.'):
            file_path = f'./{self.config["rgb_dir"]}\\' + file_path
        if not os.path.exists(file_path):
            return no_landmark
        landmark = np.load(file_path)
    else:
        with self.env.begin(write=False) as txn:
            # Translate the on-disk path format into the LMDB key format.
            if file_path.startswith('.'):
                file_path = file_path.replace('./datasets\\', '')
            binary = txn.get(file_path.encode())
            if binary is None:
                # Mirror the missing-file behaviour instead of crashing in
                # np.frombuffer.
                return no_landmark
            landmark = np.frombuffer(binary, dtype=np.uint32).reshape((81, 2))
    return np.float32(landmark)
def preprocess_images(self, imid_fg, imid_bg):
    """Load and augment the foreground/background images and their landmarks.

    Args:
        imid_fg: Landmark path (key of ``self.landmark_dict``) of the
            foreground face.
        imid_bg: Landmark path (key of ``self.landmark_dict``) of the
            background face.

    Returns:
        ``(fg_im, fg_shape, bg_im, bg_shape)`` where the images come from
        ``load_rgb`` followed by ``data_aug`` and the shapes are ``int32``
        arrays built from ``self.landmark_dict``.

    Raises:
        KeyError: If either id is not in ``self.landmark_dict``.
    """
    def load_face(lmk_path):
        # The frame path is derived from the landmark path by substitution.
        frame_path = lmk_path.replace('landmarks', 'frames').replace('npy', 'png')
        image = np.array(self.data_aug(self.load_rgb(frame_path)))
        shape = np.array(self.landmark_dict[lmk_path], dtype=np.int32)
        return image, shape

    # The former ``is None`` fallbacks were unreachable: np.array() never
    # returns None and load_rgb raises on unreadable images.
    fg_im, fg_shape = load_face(imid_fg)
    bg_im, bg_shape = load_face(imid_bg)
    return fg_im, fg_shape, bg_im, bg_shape
def get_fg_bg(self, one_lmk_path):
    """Choose the foreground/background landmark paths for one sample.

    The given path is always the background. The foreground is drawn at
    random from ``self.face_info[one_lmk_path]``. When the path has no entry,
    or its entry is empty, the background doubles as the foreground.

    Args:
        one_lmk_path: Landmark path of the background face.

    Returns:
        ``(fg_lmk_path, bg_lmk_path)``.
    """
    bg_lmk_path = one_lmk_path
    # An empty neighbour list used to make random.choice raise IndexError.
    if bg_lmk_path in self.face_info and len(self.face_info[bg_lmk_path]) > 0:
        fg_lmk_path = random.choice(self.face_info[bg_lmk_path])
    else:
        fg_lmk_path = bg_lmk_path
    return fg_lmk_path, bg_lmk_path
def generate_masks(self, fg_im, fg_shape, bg_im, bg_shape):
    """Build the foreground mask and the post-processed background mask.

    ``get_mask`` is called without deformation for the foreground and with
    ``deform=True`` for the background; only the background mask is then
    passed through ``warp_mask`` (``std=20``).

    Returns:
        ``(fg_mask, bg_mask_postprocess)``.
    """
    foreground_mask = get_mask(fg_shape, fg_im, deform=False)
    raw_background_mask = get_mask(bg_shape, bg_im, deform=True)
    # Post-processing is applied to the background mask only.
    return foreground_mask, warp_mask(raw_background_mask, std=20)
def warp_images(self, fg_im, fg_shape, bg_im, bg_shape, fg_mask):
    """Warp the foreground face into the background frame.

    With probability 0.5 the face is warped with ``faceswap.warp_image_3d``
    using the first 48 landmarks of each shape, and the mask is rebuilt from
    the non-zero pixels of the warped face. Otherwise the image and the given
    mask are both warped with the affine matrix from ``get_align_mat_new``.

    Returns:
        ``(warped_face, fg_mask)`` where ``fg_mask`` is a boolean array.
    """
    height, width, _ = bg_im.shape

    if np.random.rand() < 0.5:
        # 3D warping branch: the mask is derived from the warped face itself.
        warped_face = faceswap.warp_image_3d(
            fg_im, np.array(fg_shape[:48]), np.array(bg_shape[:48]), (height, width)
        )
        return warped_face, np.mean(warped_face, axis=2) > 0

    # Affine branch: one 2x3 matrix is applied to both the image and the mask.
    affine = np.array(get_align_mat_new(fg_shape, bg_shape)).reshape(2, 3)
    out_size = (width, height)
    warped_face = cv2.warpAffine(
        fg_im, affine, out_size, flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT
    )
    warped_mask = cv2.warpAffine(
        fg_mask, affine, out_size, flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT
    )
    return warped_face, warped_mask > 0
def colorTransfer(self, src, dst, mask):
    """Match the per-channel mean/std of ``dst`` to ``src`` inside ``mask``.

    Args:
        src: Reference image; statistics are taken from its masked pixels.
        dst: Image to recolour; pixels outside the mask are left untouched.
        mask: Array whose non-zero entries (first two axes) select the
            pixels to process.

    Returns:
        A copy of ``dst`` whose masked pixels have been shifted and scaled to
        the statistics of ``src``, clipped to [0, 255] and stored as
        ``uint8``. When the mask is empty the copy is returned unchanged.
    """
    transferred = np.copy(dst)
    rows, cols = np.nonzero(mask)[:2]
    if rows.size == 0:
        # Nothing selected: the means below would be NaN.
        return transferred

    masked_src = src[rows, cols].astype(np.float32)
    masked_dst = dst[rows, cols].astype(np.float32)

    mean_src = np.mean(masked_src, axis=0)
    std_src = np.std(masked_src, axis=0)
    mean_dst = np.mean(masked_dst, axis=0)
    std_dst = np.std(masked_dst, axis=0)

    # A flat destination channel has zero std. Its centred values are all
    # zero, so a scale of 0 maps it to the source mean instead of producing
    # NaN from 0/0.
    scale = np.divide(std_src, std_dst, out=np.zeros_like(std_src), where=std_dst > 0)

    recoloured = (masked_dst - mean_dst) * scale + mean_src
    recoloured = np.clip(recoloured, 0, 255)
    transferred[rows, cols] = recoloured.astype(np.uint8)
    return transferred
def blend_images(self, color_corrected_fg, bg_im, bg_mask, featherAmount=0.2):
    """Alpha-blend the foreground onto the background inside ``bg_mask``.

    The per-pixel weight is the distance from the pixel to the convex hull of
    the masked region, divided by ``featherAmount`` times the larger side of
    the region's bounding box and clipped to [0, 1]. Pixels deep inside the
    region therefore come from the foreground, and pixels on the hull come
    from the background.

    Args:
        color_corrected_fg: Foreground image aligned with ``bg_im``.
        bg_im: Background image.
        bg_mask: Mask whose non-zero entries select the blended region.
        featherAmount: Feather width as a fraction of the region size.

    Returns:
        The blended ``uint8`` image. If the mask has no non-zero entry,
        ``color_corrected_fg`` is returned as-is.
    """
    mask = np.asarray(bg_mask)
    if mask.ndim > 2:
        mask = np.any(mask != 0, axis=tuple(range(2, mask.ndim)))
    # Work on the 2-D mask directly. Repeating it to three channels first
    # made every pixel appear three times in the slow per-point loop below.
    rows, cols = np.nonzero(mask)

    # FIXME: deal with the bugs of empty maskpts
    if rows.size == 0:
        print("No non-zero values found in bg_mask for blending. Skipping this image.")
        return color_corrected_fg  # or handle this situation differently according to the needs

    # OpenCV expects (x, y) points with a 32-bit dtype.
    mask_pts = np.stack((cols, rows), axis=1).astype(np.int32)
    face_size = np.max(mask_pts, axis=0) - np.min(mask_pts, axis=0)
    feather = featherAmount * np.max(face_size)

    hull = cv2.convexHull(mask_pts)
    dists = np.array(
        [cv2.pointPolygonTest(hull, (int(x), int(y)), True) for x, y in mask_pts],
        dtype=float,
    )
    if feather > 0:
        weights = np.clip(dists / feather, 0, 1)
    else:
        # Degenerate region or zero feather: hard edge instead of 0/0 -> NaN.
        weights = (dists > 0).astype(float)
    weights = weights[:, np.newaxis]

    fg = color_corrected_fg.astype(float)
    bg = bg_im.astype(float)
    blended_image = np.copy(bg)
    blended_image[rows, cols] = weights * fg[rows, cols] + (1 - weights) * bg[rows, cols]

    # Convert the blended image to 8-bit unsigned integers
    blended_image = np.clip(blended_image, 0, 255)
    return blended_image.astype(np.uint8)
def process_images(self, imid_fg, imid_bg, index):
    """Run the blended-image (BI) generation pipeline for one fg/bg pair.

    Terminology:
        Foreground (fg) image: the image whose face is blended onto the
            background image.
        Background (bg) image: the image that receives the foreground face.

    Steps: load and augment both images, build the masks, warp the
    foreground onto the background, restrict the background mask to the
    warped foreground region, colour-correct the warped face, blend, and
    compute the boundary map with ``get_boundary``.

    Args:
        imid_fg: Landmark path of the foreground face.
        imid_bg: Landmark path of the background face.
        index: Sample index; currently unused, kept for the debug dump.

    Returns:
        ``(blended_image, boundary, bg_im)``.
    """
    fg_im, fg_shape, bg_im, bg_shape = self.preprocess_images(imid_fg, imid_bg)
    fg_mask, bg_mask = self.generate_masks(fg_im, fg_shape, bg_im, bg_shape)
    warped_face, fg_mask = self.warp_images(fg_im, fg_shape, bg_im, bg_shape, fg_mask)

    try:
        # Keep the bg_mask strictly within the fg_mask.
        bg_mask[fg_mask == 0] = 0
    except (IndexError, ValueError, TypeError):
        # Best effort, as before: if the masks cannot be combined (e.g.
        # mismatched shapes) continue with the unrestricted mask. The former
        # bare ``except`` also re-ran the calls below, which could only fail
        # again in the same way.
        pass

    color_corrected_fg = self.colorTransfer(bg_im, warped_face, bg_mask)
    blended_image = self.blend_images(color_corrected_fg, bg_im, bg_mask)
    boundary = get_boundary(bg_mask)

    # Debug dump (disabled): stack the intermediate images into one picture.
    # images = [fg_im, np.where(fg_mask>0, 255, 0), bg_im, bg_mask, color_corrected_fg, blended_image, np.where(boundary>0, 255, 0)]
    # titles = ["Fg Image", "Fg Mask", "Bg Image",
    #           "Bg Mask", "Blended Region",
    #           "Blended Image", "Boundary"]
    # os.makedirs('facexray_examples_3', exist_ok=True)
    # self.save_combined_image(images, titles, index, f'facexray_examples_3/combined_image_{index}.png')
    return blended_image, boundary, bg_im
def post_proc(self, img):
    """Augment an image and convert it to a normalised tensor.

    Args:
        img: Image array as produced by ``process_images``.

    Returns:
        The result of ``self.transforms`` applied to the augmented image.
    """
    # NOTE(review): load_rgb already converts BGR -> RGB, so this second swap
    # appears to reverse the channels again. Left unchanged because existing
    # checkpoints may have been trained with this order - confirm.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    im_aug = self.blended_aug(img)
    # Use the augmented image; previously ``img`` was converted here, which
    # silently discarded the output of blended_aug.
    im_aug = Image.fromarray(np.uint8(im_aug))
    im_aug = self.transforms(im_aug)
    return im_aug
@staticmethod
def save_combined_image(images, titles, index, save_path):
    """Stack images vertically, label each one, and write the result to disk.

    Declared as a static method because it takes no ``self``; this makes both
    ``FFBlendDataset.save_combined_image(...)`` and
    ``self.save_combined_image(...)`` bind the arguments correctly.

    Args:
        images (List[np.ndarray]): Images to be combined.
        titles (List[str]): Title drawn on top of each image.
        index (int): Index of the sample (not used when drawing).
        save_path (str): Path of the output file.
    """
    # The canvas reserves one slot of the tallest height per image.
    max_height = max(image.shape[0] for image in images)
    max_width = max(image.shape[1] for image in images)
    canvas = np.zeros((max_height * len(images), max_width, 3), dtype=np.uint8)

    current_height = 0
    for image, title in zip(images, titles):
        height, width = image.shape[:2]
        if image.ndim == 2:
            # Replicate single-channel images across the three colour channels.
            image = np.tile(image[..., None], (1, 1, 3))
        canvas[current_height:current_height + height, :width] = image
        cv2.putText(
            canvas, title, (10, current_height + 30),
            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2
        )
        current_height += height

    cv2.imwrite(save_path, canvas)
def __getitem__(self, index):
    """Build one fake/real training pair.

    The landmark path at ``index`` provides the background (real) image, and
    a blended forgery is synthesised on top of it.

    Args:
        index: Position in ``self.imid_list``.

    Returns:
        ``(fake_data_tuple, real_data_tuple)`` where each tuple is
        ``(image_tensor, boundary_tensor, label)``. The fake label is always
        1. The real tuple has an all-zero boundary and label 1 only when the
        source path contains a ``manipulated_sequences`` directory.
    """
    one_lmk_path = self.imid_list[index]
    # Accept both '/' and '\\' separators and do not depend on how deep the
    # dataset root is (the old check read path component number 6).
    path_parts = one_lmk_path.replace('\\', '/').split('/')
    label = 1 if 'manipulated_sequences' in path_parts else 0

    imid_fg, imid_bg = self.get_fg_bg(one_lmk_path)
    manipulate_img, boundary, bg_img = self.process_images(imid_fg, imid_bg, index)

    manipulate_img = self.post_proc(manipulate_img)
    bg_img = self.post_proc(bg_img)
    # Add a leading channel axis to the boundary map.
    boundary = torch.from_numpy(boundary)
    boundary = boundary.unsqueeze(2).permute(2, 0, 1)

    fake_data_tuple = (manipulate_img, boundary, 1)
    real_data_tuple = (bg_img, torch.zeros_like(boundary), label)
    return fake_data_tuple, real_data_tuple
@staticmethod
def collate_fn(batch):
    """Merge fake/real pairs into one shuffled batch dictionary.

    Declared as a static method because it takes no ``self``, so it can be
    passed to a DataLoader as either ``FFBlendDataset.collate_fn`` or
    ``dataset.collate_fn``.

    Args:
        batch: Sequence of ``(fake_data_tuple, real_data_tuple)`` items as
            returned by ``__getitem__``.

    Returns:
        A dict with ``'image'``, ``'label'`` and ``'mask'`` tensors whose
        first dimension is twice the batch length, in the same random order
        for all three, plus ``'landmark': None``.
    """
    fake_data, real_data = zip(*batch)
    fake_images, fake_boundaries, fake_labels = zip(*fake_data)
    real_images, real_boundaries, real_labels = zip(*real_data)

    # Fakes first, then reals.
    images = torch.stack(fake_images + real_images)
    boundaries = torch.stack(fake_boundaries + real_boundaries)
    labels = torch.tensor(fake_labels + real_labels)

    # Shuffle one index list and apply it to every tensor. This yields the
    # same permutation as shuffling the zipped samples, without unstacking
    # and restacking them.
    order = list(range(len(labels)))
    random.shuffle(order)

    data_dict = {
        'image': images[order],
        'label': labels[order],
        'mask': boundaries[order],  # the blending boundaries serve as masks
        'landmark': None  # no landmark data is produced by this dataset
    }
    return data_dict
def __len__(self):
    """Return how many landmark paths are in ``self.imid_list``."""
    sample_count = len(self.imid_list)
    return sample_count
if __name__ == "__main__":
    # Smoke test: build the dataset and dump the first synthetic forgeries
    # together with their blending boundaries.
    # NOTE(review): the values below are illustrative; replace them with the
    # settings of the training configuration actually in use.
    demo_config = {'lmdb': False, 'rgb_dir': 'datasets', 'resolution': 256}
    dataset = FFBlendDataset(config=demo_config)
    print('dataset lenth: ', len(dataset))

    def tensor2bgr(im):
        """Map a normalised (-1..1) CHW tensor back to a 0..255 BGR image."""
        img = im.squeeze().cpu().numpy().transpose(1, 2, 0)
        img = (img + 1)/2 * 255
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img

    def tensor2gray(im):
        """Scale a single-channel tensor by 255 for saving as an image."""
        img = im.squeeze().cpu().numpy()
        img = img * 255
        return img

    # Each item is (fake_data_tuple, real_data_tuple); the old loop referred
    # to the undefined names `label`, `use_mouth`, `im` and `mouth`.
    for i, (fake_data, _real_data) in enumerate(dataset):
        if i > 20:
            break
        img, boudary, _label = fake_data
        cv2.imwrite('{}_whole.png'.format(i), tensor2bgr(img))
        cv2.imwrite('{}_boudnary.png'.format(i), tensor2gray(boudary))