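"""Identity-consistency and FID evaluation for generated videos.

For each (reference image, generated video) pair, this script:
  1. detects and aligns the largest face with insightface FaceAnalysis,
  2. embeds the aligned face with a CurricularFace IR-101 backbone,
  3. scores per-frame cosine similarity (CurricularFace and ArcFace
     embeddings) against the reference face, and
  4. computes FID between InceptionV3 activations of the reference face
     and each frame's aligned face.
"""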
import os

import cv2
import numpy as np
import torch
from curricularface import get_model
from insightface.app import FaceAnalysis
from insightface.utils import face_align
from PIL import Image
from torchvision import models, transforms

def matrix_sqrt(matrix):
    # Square root of a symmetric PSD matrix via eigendecomposition.
    # torch.linalg.eigh assumes a symmetric input; eigenvalues are clamped
    # at zero to guard against small negative values from numerical error.
    eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
    sqrt_eigenvalues = torch.sqrt(torch.clamp(eigenvalues, min=0))
    sqrt_matrix = (eigenvectors * sqrt_eigenvalues).mm(eigenvectors.T)
    return sqrt_matrix
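
# Sanity check (sketch): for a symmetric PSD matrix A,
# matrix_sqrt(A) @ matrix_sqrt(A) should recover A:
#   a = torch.tensor([[2.0, 1.0], [1.0, 2.0]])
#   assert torch.allclose(matrix_sqrt(a) @ matrix_sqrt(a), a, atol=1e-5)
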
def sample_video_frames(video_path, num_frames=16):
    # Uniformly sample `num_frames` frames (BGR) across the whole video.
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
    cap.release()
    return frames

def get_face_keypoints(face_model, image_bgr):
    # Return the detected face with the largest bounding-box area, or None.
    face_info = face_model.get(image_bgr)
    if len(face_info) > 0:
        return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
    return None

def load_image(image):
    img = image.convert('RGB')
    img = transforms.Resize((299, 299))(img)  # Resize to the InceptionV3 input size
    img = transforms.ToTensor()(img)
    return img.unsqueeze(0)  # Add batch dimension

def calculate_fid(real_activations, fake_activations, device="cuda"):
    # Fréchet Inception Distance between two activation sets:
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))
    real_activations_tensor = torch.tensor(real_activations).to(device)
    fake_activations_tensor = torch.tensor(fake_activations).to(device)
    mu1 = real_activations_tensor.mean(dim=0)
    sigma1 = torch.cov(real_activations_tensor.T)
    mu2 = fake_activations_tensor.mean(dim=0)
    sigma2 = torch.cov(fake_activations_tensor.T)
    ssdiff = torch.sum((mu1 - mu2) ** 2)
    covmean = matrix_sqrt(sigma1.mm(sigma2))
    if torch.is_complex(covmean):
        covmean = covmean.real
    fid = ssdiff + torch.trace(sigma1 + sigma2 - 2 * covmean)
    return fid.item()

def batch_cosine_similarity(embedding_image, embedding_frames, device="cuda"):
    embedding_image = torch.tensor(embedding_image).to(device)
    embedding_frames = torch.tensor(embedding_frames).to(device)
    return torch.nn.functional.cosine_similarity(embedding_image, embedding_frames, dim=-1).cpu().numpy()

def get_activations(images, model, batch_size=16):
    model.eval()
    activations = []
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size]
            pred = model(batch)
            activations.append(pred)
    activations = torch.cat(activations, dim=0).cpu().numpy()
    # torch.cov needs at least two samples; duplicate a lone activation.
    if activations.shape[0] == 1:
        activations = np.repeat(activations, 2, axis=0)
    return activations

def pad_np_bgr_image(np_image, scale=1.25):
    # Pad the image with a gray border so faces near the edge can be detected.
    assert scale >= 1.0, "scale should be >= 1.0"
    pad_scale = scale - 1.0
    h, w = np_image.shape[:2]
    top = bottom = int(h * pad_scale)
    left = right = int(w * pad_scale)
    return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top)

def process_image(face_model, image_path):
    if isinstance(image_path, str):
        np_faceid_image = np.array(Image.open(image_path).convert("RGB"))
    elif isinstance(image_path, np.ndarray):
        np_faceid_image = image_path
    else:
        raise TypeError("image_path should be a file path (str) or an RGB numpy.ndarray")
    image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
    face_info = get_face_keypoints(face_model, image_bgr)
    if face_info is None:
        # Retry on a padded copy in case the face touches the image border.
        padded_image, sub_coord = pad_np_bgr_image(image_bgr)
        face_info = get_face_keypoints(face_model, padded_image)
        if face_info is None:
            print("Warning: No face detected in the image. Continuing processing...")
            return None, None
        face_kps = face_info['kps']
        face_kps -= np.array(sub_coord)  # Map keypoints back to the unpadded image
    else:
        face_kps = face_info['kps']
    arcface_embedding = face_info['embedding']
    norm_face = face_align.norm_crop(image_bgr, landmark=face_kps, image_size=224)
    align_face = cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB)
    return align_face, arcface_embedding

def inference(face_model, img, device):
    img = cv2.resize(img, (112, 112))
    img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    img = torch.from_numpy(img).unsqueeze(0).float().to(device)
    img.div_(255).sub_(0.5).div_(0.5)  # Normalize to [-1, 1]
    embedding = face_model(img).detach().cpu().numpy()[0]
    return embedding / np.linalg.norm(embedding)  # L2-normalize the embedding

def process_video(video_path, face_arc_model, face_cur_model, fid_model, arcface_image_embedding, cur_image_embedding, real_activations, device):
    video_frames = sample_video_frames(video_path, num_frames=16)
    # Per-frame score lists
    cur_scores = []
    arc_scores = []
    fid_face = []
    for frame in video_frames:
        # Convert to RGB once at the beginning
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Detect, align, and extract the ArcFace embedding for this frame
        align_face_frame, arcface_frame_embedding = process_image(face_arc_model, frame_rgb)
        # Skip frames where detection or alignment fails
        if align_face_frame is None:
            continue
        # Embed the aligned face with the CurricularFace model
        cur_embedding_frame = inference(face_cur_model, align_face_frame, device)
        # Cosine similarities against the reference image, clamped at 0
        cur_score = max(0.0, batch_cosine_similarity(cur_image_embedding, cur_embedding_frame, device=device).item())
        arc_score = max(0.0, batch_cosine_similarity(arcface_image_embedding, arcface_frame_embedding, device=device).item())
        # FID between the reference face and this frame's aligned face
        align_face_frame_pil = Image.fromarray(align_face_frame)
        fake_image = load_image(align_face_frame_pil).to(device)
        fake_activations = get_activations(fake_image, fid_model)
        fid_score = calculate_fid(real_activations, fake_activations, device)
        # Collect scores
        fid_face.append(fid_score)
        cur_scores.append(cur_score)
        arc_scores.append(arc_score)
    # Aggregate results with default values for empty lists
    avg_cur_score = np.mean(cur_scores) if cur_scores else 0.0
    avg_arc_score = np.mean(arc_scores) if arc_scores else 0.0
    avg_fid_score = np.mean(fid_face) if fid_face else 0.0
    return avg_cur_score, avg_arc_score, avg_fid_score

def main():
    device = "cuda"
    # data_path = "data/SkyActor"
    # data_path = "data/LivePotraits"
    # data_path = "data/Actor-One"
    data_path = "data/FollowYourEmoji"
    img_path = "/maindata/data/shared/public/rui.wang/act_review/ref_images"
    pre_tag = False
    mp4_list = os.listdir(data_path)
    print(mp4_list)
    img_list = []
    video_list = []
    # Pair each video with its reference image: the image ID is parsed from
    # the video filename (the naming convention depends on the dataset).
    for mp4 in mp4_list:
        if not mp4.endswith(".mp4"):
            continue
        if pre_tag:
            png_path = mp4.split('.')[0].split('-')[0] + ".png"
        else:
            if "-" in mp4:
                png_path = mp4.split('.')[0].split('-')[1] + ".png"
            else:
                png_path = mp4.split('.')[0].split('_')[1] + ".png"
        img_list.append(os.path.join(img_path, png_path))
        video_list.append(os.path.join(data_path, mp4))
    print(img_list)
    print(video_list[0])
    model_path = "eval"
    face_arc_path = os.path.join(model_path, "face_encoder")
    face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin")
    # Initialize the insightface FaceAnalysis model for face detection and ArcFace embeddings
    face_arc_model = FaceAnalysis(root=face_arc_path, providers=['CUDAExecutionProvider'])
    face_arc_model.prepare(ctx_id=0, det_size=(320, 320))
    # Load the CurricularFace recognition model
    face_cur_model = get_model('IR_101')([112, 112])
    face_cur_model.load_state_dict(torch.load(face_cur_path, map_location="cpu"))
    face_cur_model = face_cur_model.to(device)
    face_cur_model.eval()
    # Load InceptionV3 for FID calculation
    fid_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
    fid_model.fc = torch.nn.Identity()  # Remove the final classification layer
    fid_model.eval()
    fid_model = fid_model.to(device)
    # Process each video / reference-image pair and accumulate scores
    cur_list, arc_list, fid_list = [], [], []
    for i in range(len(img_list)):
        # Extract embeddings and InceptionV3 activations from the reference image
        align_face_image, arcface_image_embedding = process_image(face_arc_model, img_list[i])
        if align_face_image is None:
            print(f"Warning: no face found in reference image {img_list[i]}, skipping.")
            continue
        cur_image_embedding = inference(face_cur_model, align_face_image, device)
        align_face_image_pil = Image.fromarray(align_face_image)
        real_image = load_image(align_face_image_pil).to(device)
        real_activations = get_activations(real_image, fid_model)
        # Process the video and calculate scores
        cur_score, arc_score, fid_score = process_video(
            video_list[i], face_arc_model, face_cur_model, fid_model,
            arcface_image_embedding, cur_image_embedding, real_activations, device
        )
        print(cur_score, arc_score, fid_score)
        cur_list.append(cur_score)
        arc_list.append(arc_score)
        fid_list.append(fid_score)
    print("cur", sum(cur_list) / len(cur_list))
    print("arc", sum(arc_list) / len(arc_list))
    print("fid", sum(fid_list) / len(fid_list))
if __name__ == "__main__":
    main()