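"""Identity-preservation evaluation for face-animation videos.

For each (reference image, generated video) pair, this script computes:
  * the mean CurricularFace cosine similarity between the reference face and
    the faces sampled from the video ("cur"),
  * the mean ArcFace cosine similarity ("arc"), and
  * an FID between InceptionV3 activations of the aligned face crops ("fid").
"""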
import os
import torch
from insightface.app import FaceAnalysis
from insightface.utils import face_align
from PIL import Image
from torchvision import models, transforms
from curricularface import get_model
import cv2
import numpy as np
def matrix_sqrt(matrix):
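    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""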
eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
sqrt_eigenvalues = torch.sqrt(torch.clamp(eigenvalues, min=0))
sqrt_matrix = (eigenvectors * sqrt_eigenvalues).mm(eigenvectors.T)
return sqrt_matrix
def sample_video_frames(video_path, num_frames=16):
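    """Sample `num_frames` frames, evenly spaced over the whole video."""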
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
frames = []
for idx in frame_indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if ret:
frames.append(frame)
cap.release()
return frames
def get_face_keypoints(face_model, image_bgr):
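    """Run face detection on a BGR image and return the largest face, or None."""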
face_info = face_model.get(image_bgr)
    if len(face_info) > 0:
        # Keep the largest detected face (by bounding-box area).
        return max(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))
    return None
def load_image(image):
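    """Resize a PIL image to 299x299 and return it as a 1x3xHxW float tensor."""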
img = image.convert('RGB')
img = transforms.Resize((299, 299))(img) # Resize to Inception input size
img = transforms.ToTensor()(img)
return img.unsqueeze(0) # Add batch dimension
def calculate_fid(real_activations, fake_activations, device="cuda"):
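    """Fréchet Inception Distance between two sets of InceptionV3 activations."""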
    real_activations_tensor = torch.as_tensor(real_activations).to(device)
    fake_activations_tensor = torch.as_tensor(fake_activations).to(device)
    mu1 = real_activations_tensor.mean(dim=0)
    sigma1 = torch.cov(real_activations_tensor.T)
    mu2 = fake_activations_tensor.mean(dim=0)
    sigma2 = torch.cov(fake_activations_tensor.T)
    ssdiff = torch.sum((mu1 - mu2) ** 2)
    # sigma1 @ sigma2 is not symmetric in general, so matrix_sqrt (which relies
    # on eigh) cannot be applied to it directly. Use the symmetrized product
    # sqrt(sigma1) @ sigma2 @ sqrt(sigma1), which is similar to sigma1 @ sigma2
    # and therefore has a square root with the same trace.
    sqrt_sigma1 = matrix_sqrt(sigma1)
    covmean = matrix_sqrt(sqrt_sigma1.mm(sigma2).mm(sqrt_sigma1))
    fid = ssdiff + torch.trace(sigma1 + sigma2 - 2 * covmean)
    return fid.item()
def batch_cosine_similarity(embedding_image, embedding_frames, device="cuda"):
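    """Cosine similarity between a reference embedding and frame embeddings."""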
embedding_image = torch.tensor(embedding_image).to(device)
embedding_frames = torch.tensor(embedding_frames).to(device)
return torch.nn.functional.cosine_similarity(embedding_image, embedding_frames, dim=-1).cpu().numpy()
def get_activations(images, model, batch_size=16):
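    """Run `model` over `images` in batches and return the stacked activations."""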
model.eval()
activations = []
with torch.no_grad():
for i in range(0, len(images), batch_size):
batch = images[i:i + batch_size]
pred = model(batch)
activations.append(pred)
activations = torch.cat(activations, dim=0).cpu().numpy()
    # torch.cov needs at least two samples, so duplicate a single activation row.
    if activations.shape[0] == 1:
        activations = np.repeat(activations, 2, axis=0)
return activations
def pad_np_bgr_image(np_image, scale=1.25):
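    """Pad a BGR image with a gray border; returns the padded image and the (left, top) offset."""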
assert scale >= 1.0, "scale should be >= 1.0"
pad_scale = scale - 1.0
h, w = np_image.shape[:2]
top = bottom = int(h * pad_scale)
left = right = int(w * pad_scale)
return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top)
def process_image(face_model, image_path):
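    """Detect, align (224x224), and embed the largest face in an image.

    Returns (aligned RGB face crop, ArcFace embedding), or (None, None) if no
    face is found even after padding the image.
    """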
if isinstance(image_path, str):
np_faceid_image = np.array(Image.open(image_path).convert("RGB"))
    elif isinstance(image_path, np.ndarray):
        np_faceid_image = image_path
    else:
        raise TypeError("image_path should be a file path string or a NumPy array")
image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
face_info = get_face_keypoints(face_model, image_bgr)
if face_info is None:
padded_image, sub_coord = pad_np_bgr_image(image_bgr)
face_info = get_face_keypoints(face_model, padded_image)
if face_info is None:
print("Warning: No face detected in the image. Continuing processing...")
return None, None
face_kps = face_info['kps']
face_kps -= np.array(sub_coord)
else:
face_kps = face_info['kps']
arcface_embedding = face_info['embedding']
norm_face = face_align.norm_crop(image_bgr, landmark=face_kps, image_size=224)
align_face = cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB)
return align_face, arcface_embedding
@torch.no_grad()
def inference(face_model, img, device):
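    """Embed a 112x112-resized RGB face crop with the CurricularFace model."""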
img = cv2.resize(img, (112, 112))
img = np.transpose(img, (2, 0, 1))
img = torch.from_numpy(img).unsqueeze(0).float().to(device)
    img.div_(255).sub_(0.5).div_(0.5)  # normalize to [-1, 1]
embedding = face_model(img).detach().cpu().numpy()[0]
return embedding / np.linalg.norm(embedding)
def process_video(video_path, face_arc_model, face_cur_model, fid_model, arcface_image_embedding, cur_image_embedding, real_activations, device):
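    """Score one video against its reference image.

    Returns the mean CurricularFace similarity, mean ArcFace similarity, and
    mean face-crop FID over the sampled frames.
    """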
video_frames = sample_video_frames(video_path, num_frames=16)
# Initialize lists to store the scores
cur_scores = []
arc_scores = []
fid_face = []
for frame in video_frames:
# Convert to RGB once at the beginning
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Process the frame for ArcFace embeddings
align_face_frame, arcface_frame_embedding = process_image(face_arc_model, frame_rgb)
# Skip if alignment fails
if align_face_frame is None:
continue
# Perform inference for current face model
cur_embedding_frame = inference(face_cur_model, align_face_frame, device)
        # Cosine similarities against the reference embeddings, clamped at zero
cur_score = max(0.0, batch_cosine_similarity(cur_image_embedding, cur_embedding_frame, device=device).item())
arc_score = max(0.0, batch_cosine_similarity(arcface_image_embedding, arcface_frame_embedding, device=device).item())
# Process FID score
align_face_frame_pil = Image.fromarray(align_face_frame)
fake_image = load_image(align_face_frame_pil).to(device)
fake_activations = get_activations(fake_image, fid_model)
fid_score = calculate_fid(real_activations, fake_activations, device)
# Collect scores
fid_face.append(fid_score)
cur_scores.append(cur_score)
arc_scores.append(arc_score)
# Aggregate results with default values for empty lists
avg_cur_score = np.mean(cur_scores) if cur_scores else 0.0
avg_arc_score = np.mean(arc_scores) if arc_scores else 0.0
avg_fid_score = np.mean(fid_face) if fid_face else 0.0
return avg_cur_score, avg_arc_score, avg_fid_score
def main():
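    """Evaluate every (reference image, video) pair under `data_path`."""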
device = "cuda"
# data_path = "data/SkyActor"
# data_path = "data/LivePotraits"
# data_path = "data/Actor-One"
data_path = "data/FollowYourEmoji"
img_path = "/maindata/data/shared/public/rui.wang/act_review/ref_images"
    # Filename layout flag: the reference-image ID comes before the '-' when
    # True, and after the '-' (or '_') when False.
    pre_tag = False
mp4_list = os.listdir(data_path)
print(mp4_list)
img_list = []
video_list = []
for mp4 in mp4_list:
if "mp4" not in mp4:
continue
if pre_tag:
png_path = mp4.split('.')[0].split('-')[0] + ".png"
else:
if "-" in mp4:
png_path = mp4.split('.')[0].split('-')[1] + ".png"
else:
png_path = mp4.split('.')[0].split('_')[1] + ".png"
img_list.append(os.path.join(img_path, png_path))
video_list.append(os.path.join(data_path, mp4))
print(img_list)
print(video_list[0])
model_path = "eval"
face_arc_path = os.path.join(model_path, "face_encoder")
face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin")
    # Initialize the insightface FaceAnalysis model for face detection and ArcFace embeddings
face_arc_model = FaceAnalysis(root=face_arc_path, providers=['CUDAExecutionProvider'])
face_arc_model.prepare(ctx_id=0, det_size=(320, 320))
    # Load the CurricularFace recognition model
face_cur_model = get_model('IR_101')([112, 112])
face_cur_model.load_state_dict(torch.load(face_cur_path, map_location="cpu"))
face_cur_model = face_cur_model.to(device)
face_cur_model.eval()
# Load InceptionV3 model for FID calculation
fid_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
fid_model.fc = torch.nn.Identity() # Remove final classification layer
fid_model.eval()
fid_model = fid_model.to(device)
# Process the single video and image pair
# Extract embeddings and features from the image
cur_list, arc_list, fid_list = [], [], []
for i in range(len(img_list)):
        align_face_image, arcface_image_embedding = process_image(face_arc_model, img_list[i])
        if align_face_image is None:
            print(f"Skipping {img_list[i]}: no face detected in the reference image")
            continue
        cur_image_embedding = inference(face_cur_model, align_face_image, device)
align_face_image_pil = Image.fromarray(align_face_image)
real_image = load_image(align_face_image_pil).to(device)
real_activations = get_activations(real_image, fid_model)
# Process the video and calculate scores
cur_score, arc_score, fid_score = process_video(
video_list[i], face_arc_model, face_cur_model, fid_model,
arcface_image_embedding, cur_image_embedding, real_activations, device
)
print(cur_score, arc_score, fid_score)
cur_list.append(cur_score)
arc_list.append(arc_score)
fid_list.append(fid_score)
print("cur", sum(cur_list)/ len(cur_list))
print("arc", sum(arc_list)/ len(arc_list))
print("fid", sum(fid_list)/ len(fid_list))
main()