import argparse
import os
import random
from datetime import datetime
from pathlib import Path
from typing import List

import av
import cv2
import numpy as np
import torch
# Model initialization
import torchvision
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModel,
    CLIPVisionModelWithProjection,
)

import sys

from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_lmks2vid_long import Pose2VideoPipeline
from src.models.pose_guider import PoseGuider
from src.utils.util import get_fps, read_frames, save_videos_grid
from tools.facetracker_api import face_image


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        help="Path to the inference config",
        default="./configs/prompts/inference_reenact.yaml",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        help="Directory to save results",
        default="./output/stage2_infer",
    )
    parser.add_argument(
        "--source_image_path",
        type=str,
        help="Path to the source image",
        default="",
    )
    parser.add_argument(
        "--driving_video_path",
        type=str,
        help="Path to the driving video",
        default="",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=320,
        help="Number of driving frames processed per inference batch",
    )
    parser.add_argument("--mask_ratio", type=float, default=0.55)  # 0.55~0.6
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-L", type=int, default=24)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cfg", type=float, default=3.5)
    parser.add_argument("--steps", type=int, default=30)
    parser.add_argument("--fps", type=int, default=25)
    args = parser.parse_args()

    return args


def lmks_vis(img, lms):
    # Visualize the mouth, the nose, and the whole face from the tracked landmarks.
    h, w, c = img.shape
    lms = lms[:, :2]
    mouth = lms[48:66]
    nose = lms[27:36]
    color = (0, 255, 0)
    # Re-center the mouth and nose landmarks on the image center.
    x_c, y_c = np.mean(lms[:, 0]), np.mean(lms[:, 1])
    h_c, w_c = h // 2, w // 2
    img_face, img_mouth, img_nose = img.copy(), img.copy(), img.copy()
    for pt_num, (x, y) in enumerate(mouth):
        x = x - (x_c - w_c)
        y = y - (y_c - h_c)
        x = int(x + 0.5)
        y = int(y + 0.5)
        cv2.circle(img_mouth, (y, x), 1, color, -1)
    for pt_num, (x, y) in enumerate(nose):
        x = x - (x_c - w_c)
        y = y - (y_c - h_c)
        x = int(x + 0.5)
        y = int(y + 0.5)
        cv2.circle(img_nose, (y, x), 1, color, -1)
    for pt_num, (x, y) in enumerate(lms):
        x = int(x + 0.5)
        y = int(y + 0.5)
        if pt_num >= 66:
            color = (255, 255, 0)
        else:
            color = (0, 255, 0)
        cv2.circle(img_face, (y, x), 1, color, -1)

    return img_face, img_mouth, img_nose


def batch_rearrange(pose_len, batch_size=24):
    # Split the pose sequence indices into consecutive chunks of at most batch_size frames.
    batch_ind_list = []
    for i in range(0, pose_len, batch_size):
        batch_ind_list.append(list(range(i, min(i + batch_size, pose_len))))
    return batch_ind_list


def lmks_video_extract(video_path):
    # Extract the landmark sequence of a (single-face) video.
    video_stream = cv2.VideoCapture(video_path)
    lmks_list, frames = [], []
    while True:
        still_reading, frame = video_stream.read()
        if not still_reading:
            video_stream.release()
            break
        h, w, c = frame.shape
        lmk_img, lmks = face_image(frame)
        if lmks is not None:
            lmks_list.append(lmks)
            frames.append(frame)

    return frames, np.array(lmks_list), [h, w]
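
# Note on the helpers above: lmks_video_extract keeps a frame only when the tracker
# returns landmarks for it, so `frames` and the landmark array stay index-aligned, and
# the returned [h, w] is the size of the last decoded frame. batch_rearrange merely
# chunks frame indices, e.g. batch_rearrange(50, 24) -> [[0, ..., 23], [24, ..., 47], [48, 49]].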


def adjust_pose(src_lms_list, src_size, ref_lms, ref_size):
    # Align the center of the source (driving) landmarks with the reference landmarks.
    new_src_lms_list = []
    ref_lms = ref_lms[:, :2]
    src_lms = src_lms_list[0][:, :2]
    ref_lms[:, 0] = ref_lms[:, 0] / ref_size[1]
    ref_lms[:, 1] = ref_lms[:, 1] / ref_size[0]
    src_lms[:, 0] = src_lms[:, 0] / src_size[1]
    src_lms[:, 1] = src_lms[:, 1] / src_size[0]
    ref_cx, ref_cy = np.mean(ref_lms[:, 0]), np.mean(ref_lms[:, 1])
    src_cx, src_cy = np.mean(src_lms[:, 0]), np.mean(src_lms[:, 1])
    for item in src_lms_list:
        item = item[:, :2]
        item[:, 0] = item[:, 0] - int((src_cx - ref_cx) * src_size[1])
        item[:, 1] = item[:, 1] - int((src_cy - ref_cy) * src_size[0])
        new_src_lms_list.append(item)

    return np.array(new_src_lms_list)


def main():
    args = parse_args()
    infer_config = OmegaConf.load(args.config)

    # base_model_path = "./pretrained_weights/huggingface-models/sd-image-variations-diffusers/"
    base_model_path = infer_config.pretrained_base_model_path
    weight_dtype = torch.float16

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        # "./pretrained_weights/huggingface-models/sd-image-variations-diffusers/image_encoder"
        infer_config.image_encoder_path
    ).to(dtype=weight_dtype, device="cuda")
    vae = AutoencoderKL.from_pretrained(
        # "./pretrained_weights/huggingface-models/sd-vae-ft-mse"
        infer_config.pretrained_vae_path
    ).to("cuda", dtype=weight_dtype)

    # Initialize the reference UNet, denoising UNet, and pose guiders.
    reference_unet = UNet3DConditionModel.from_pretrained_2d(
        base_model_path,
        "",
        subfolder="unet",
        unet_additional_kwargs={
            "task_type": "reenact",
            "use_motion_module": False,
            "unet_use_temporal_attention": False,
            "mode": "write",
        },
    ).to(device="cuda", dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        base_model_path,
        "./pretrained_weights/mm_sd_v15_v2.ckpt",
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            infer_config.unet_additional_kwargs
        ),
        # mm_zero_proj_out=True,
    ).to(device="cuda")
    pose_guider1 = PoseGuider(
        conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
    ).to(device="cuda", dtype=weight_dtype)
    pose_guider2 = PoseGuider(
        conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
    ).to(device="cuda", dtype=weight_dtype)
    print("------------------ initialized all networks ------------------")

    # Load the pretrained checkpoints.
    denoising_unet.load_state_dict(
        torch.load(
            infer_config.denoising_unet_path,
            map_location="cpu",
        ),
        strict=True,
    )
    reference_unet.load_state_dict(
        torch.load(
            infer_config.reference_unet_path,
            map_location="cpu",
        )
    )
    pose_guider1.load_state_dict(
        torch.load(
            infer_config.pose_guider1_path,
            map_location="cpu",
        )
    )
    pose_guider2.load_state_dict(
        torch.load(
            infer_config.pose_guider2_path,
            map_location="cpu",
        )
    )
    print("--------- loaded pretrained denoising UNet, reference UNet and pose guiders ----------")

    # Scheduler
    enable_zero_snr = True
    sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
    if enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    scheduler = DDIMScheduler(**sched_kwargs)

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider1=pose_guider1,
        pose_guider2=pose_guider2,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    height, width, clip_length = args.H, args.W, args.L
    generator = torch.manual_seed(args.seed)

    date_str = datetime.now().strftime("%Y%m%d")
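    # Outputs are grouped by date: each clip is written to
    # <save_dir>/<YYYYMMDD>/<source-name>_<driving-name>_<HHMM>.mp4 (see save_video_path below).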
    save_dir = Path(f"{args.save_dir}/{date_str}")
    save_dir.mkdir(exist_ok=True, parents=True)

    ref_image_path, pose_video_path = args.source_image_path, args.driving_video_path
    ref_name = Path(ref_image_path).stem
    pose_name = Path(pose_video_path).stem

    ref_image_pil = Image.open(ref_image_path).convert("RGB")
    ref_image = cv2.imread(ref_image_path)
    ref_h, ref_w, c = ref_image.shape
    ref_pose, ref_pose_lms = face_image(ref_image)

    # Extract landmarks from the driving video and align them with the reference face.
    pose_frames, pose_lms_list, pose_size = lmks_video_extract(pose_video_path)
    pose_lms_list = adjust_pose(pose_lms_list, pose_size, ref_pose_lms, [ref_h, ref_w])
    pose_h, pose_w = int(pose_size[0]), int(pose_size[1])
    pose_len = pose_lms_list.shape[0]
    # Truncate the tail of the video to a multiple of 24 frames for a stable result.
    pose_len = pose_len // 24 * 24
    batch_index_list = batch_rearrange(pose_len, args.batch_size)

    pose_transform = transforms.Compose(
        [transforms.Resize((height, width)), transforms.ToTensor()]
    )

    videos = []
    zero_map = np.zeros_like(ref_pose)
    zero_map = cv2.resize(zero_map, (pose_w, pose_h))
    for batch_index in batch_index_list:
        pose_list, pose_up_list, pose_down_list = [], [], []
        pose_frame_list = []
        pose_tensor_list, pose_up_tensor_list, pose_down_tensor_list = [], [], []
        batch_len = len(batch_index)
        for pose_idx in batch_index:
            pose_lms = pose_lms_list[pose_idx]
            pose_frame = pose_frames[pose_idx][:, :, ::-1]  # BGR -> RGB
            pose_image, pose_mouth_image, _ = lmks_vis(zero_map, pose_lms)
            h, w, c = pose_image.shape
            pose_up_image = pose_image.copy()
            # Zero out the lower part of the map so only the upper-face landmarks remain.
            pose_up_image[int(h * args.mask_ratio):, :, :] = 0

            pose_image_pil = Image.fromarray(pose_image)
            pose_frame = Image.fromarray(pose_frame)
            pose_up_pil = Image.fromarray(pose_up_image)
            pose_mouth_pil = Image.fromarray(pose_mouth_image)

            pose_list.append(pose_image_pil)
            pose_up_list.append(pose_up_pil)
            pose_down_list.append(pose_mouth_pil)
            pose_tensor_list.append(pose_transform(pose_image_pil))
            pose_up_tensor_list.append(pose_transform(pose_up_pil))
            pose_down_tensor_list.append(pose_transform(pose_mouth_pil))
            pose_frame_list.append(pose_transform(pose_frame))

        pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
        pose_tensor = pose_tensor.transpose(0, 1)
        pose_tensor = pose_tensor.unsqueeze(0)

        pose_frames_tensor = torch.stack(pose_frame_list, dim=0)  # (f, c, h, w)
        pose_frames_tensor = pose_frames_tensor.transpose(0, 1)
        pose_frames_tensor = pose_frames_tensor.unsqueeze(0)

        ref_image_tensor = pose_transform(ref_image_pil)  # (c, h, w)
        ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
        ref_image_tensor = repeat(
            ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=batch_len
        )

        # Disentangle head-pose control (including eye blinks) from mouth-motion control.
        pipeline_output = pipe(
            ref_image_pil,
            pose_up_list,
            pose_down_list,
            width,
            height,
            batch_len,
            args.steps,
            args.cfg,
            generator=generator,
        )
        video = pipeline_output.videos
        video = torch.cat([ref_image_tensor, pose_frames_tensor, video], dim=0)
        videos.append(video)

    videos = torch.cat(videos, dim=2)
    time_str = datetime.now().strftime("%H%M")
    save_video_path = f"{save_dir}/{ref_name}_{pose_name}_{time_str}.mp4"
    save_videos_grid(
        videos,
        save_video_path,
        n_rows=3,
        fps=args.fps,
    )
    print(f"Inference result saved to: {save_video_path}")

    del pipe
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
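
# Example invocation (illustrative only; the script file name and the asset paths below
# are assumptions, adjust them to your checkout):
#   python inference_reenact.py \
#       --config ./configs/prompts/inference_reenact.yaml \
#       --source_image_path ./assets/source.jpg \
#       --driving_video_path ./assets/driving.mp4 \
#       --save_dir ./output/stage2_infer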