Spaces:
Running
on
L40S
Running
on
L40S
import torch | |
import os | |
import numpy as np | |
from PIL import Image | |
import glob | |
import insightface | |
import cv2 | |
import subprocess | |
import argparse | |
from decord import VideoReader | |
from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip | |
from facexlib.parsing import init_parsing_model | |
from facexlib.utils.face_restoration_helper import FaceRestoreHelper | |
from insightface.app import FaceAnalysis | |
from diffusers.models import AutoencoderKLCogVideoX | |
from diffusers.utils import export_to_video, load_image | |
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel | |
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor | |
from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel | |
from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline | |
from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor | |
from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor | |
from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d | |
import moviepy.editor as mp | |
from diffposetalk.diffposetalk import DiffPoseTalk | |
def crop_and_resize(image, height, width):
    """Center-crop or pad ``image`` to the aspect ratio ``width / height``, then resize.

    If the image is relatively wider than the target, equal amounts are cropped
    from its left and right edges. Otherwise it is padded with zero-valued
    (black) columns on both sides. The result is resized to exactly
    ``(width, height)``.

    Args:
        image: A PIL image or anything ``np.array`` accepts. Both 2-D (H, W)
            and channel-last (H, W, C) arrays are handled; the channel count
            and dtype are preserved when padding.
        height: Target height in pixels.
        width: Target width in pixels.

    Returns:
        A ``PIL.Image.Image`` of size ``(width, height)``.
    """
    image = np.array(image)
    # Only the spatial dims matter here; any trailing channel dim is carried through.
    image_height, image_width = image.shape[:2]

    if image_height / image_width < height / width:
        # Source is wider than the target ratio: crop the sides symmetrically.
        cropped_width = int(image_height / height * width)
        left = (image_width - cropped_width) // 2
        image = image[:, left:left + cropped_width]
        return Image.fromarray(image).resize((width, height))

    # Source is taller than (or equal to) the target ratio: pad the sides with zeros.
    pad = int((((width / height) * image_height) - image_width) / 2.0)
    # Match the input's channel layout and dtype so grayscale / RGBA / 16-bit
    # inputs are padded the same way the crop branch would pass them through.
    padded_image = np.zeros(
        (image_height, image_width + pad * 2) + image.shape[2:], dtype=image.dtype
    )
    padded_image[:, pad:pad + image_width] = image
    return Image.fromarray(padded_image).resize((width, height))
def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"):
    """Encode a sequence of frames into an MP4 file at ``video_path``.

    Args:
        video_path: Destination file path.
        samples: Frames accepted by moviepy's ``ImageSequenceClip``.
        fps: Output frame rate.
        audio_bitrate: Bitrate string forwarded to ``write_videofile``. The clip
            built here has no audio track, so the audio settings presumably
            have no effect (TODO confirm against moviepy).
    """
    # Extra encoder flags passed straight through to ffmpeg.
    encoder_flags = ["-crf", "18", "-preset", "slow"]
    frames_clip = ImageSequenceClip(samples, fps=fps)
    frames_clip.write_videofile(
        video_path,
        audio_codec="aac",
        audio_bitrate=audio_bitrate,
        ffmpeg_params=encoder_flags,
    )
def parse_video(driving_frames, max_frame_num, fps=25, target_fps=12):
    """Resample a driving sequence and fit it to ``max_frame_num - 1`` frames.

    Frames are picked at ``target_fps`` from a sequence recorded at ``fps``.
    The first picked frame is then repeated once at the front. The result is
    truncated, or padded by repeating the last picked frame, to exactly
    ``max(max_frame_num - 1, 0)`` frames. The ``__main__`` block prepends one
    reference-pose frame so the total reaches ``max_frame_num``.

    Args:
        driving_frames: Indexable sequence of frames (any element type).
        max_frame_num: Number of frames the pipeline generates.
        fps: Frame rate of ``driving_frames``.
        target_fps: Frame rate to resample to (defaults to the previously
            hard-coded 12).

    Returns:
        A list of ``max(max_frame_num - 1, 0)`` elements of ``driving_frames``.

    Raises:
        ValueError: If ``driving_frames`` is empty.
    """
    video_length = len(driving_frames)
    if video_length == 0:
        raise ValueError("driving_frames must contain at least one frame")
    # Clamp so a tiny max_frame_num can never produce a negative slice bound.
    target_len = max(max_frame_num - 1, 0)

    # Timestamps of the resampled frames, mapped back to source-frame indices.
    duration = video_length / fps
    target_times = np.arange(0, duration, 1 / target_fps)
    frame_indices = (target_times * fps).astype(np.int32)
    frame_indices = frame_indices[frame_indices < video_length]
    # t=0 is always in target_times, so at least one frame is sampled.
    sampled = [driving_frames[idx] for idx in frame_indices[:max(target_len, 1)]]

    # Repeat the first frame once, then cut / pad with the last sampled frame.
    # This yields the same sequence as before for >= 2 sampled frames, and now
    # also the same length (max_frame_num - 1) for a single sampled frame.
    frames = [sampled[0]] + sampled
    frames = frames[:target_len]
    frames.extend([sampled[-1]] * (target_len - len(frames)))
    return frames
def save_video_with_audio(video_path: str, audio_path: str, save_path: str, fps: int = 12):
    """Mux an audio track onto a video, write it to ``save_path``, then delete ``video_path``.

    If the audio is longer than the video, it is trimmed to the video's duration.

    Args:
        video_path: Input (silent) video; removed after a successful write.
        audio_path: Audio file to attach.
        save_path: Output path for the video with audio.
        fps: Frame rate passed to ``write_videofile`` (defaults to the
            previously hard-coded 12).
    """
    video_clip = mp.VideoFileClip(video_path)
    try:
        source_audio = mp.AudioFileClip(audio_path)
        try:
            audio_track = source_audio
            if audio_track.duration > video_clip.duration:
                audio_track = audio_track.subclip(0, video_clip.duration)
            video_with_audio = video_clip.set_audio(audio_track)
            video_with_audio.write_videofile(save_path, fps=fps)
        finally:
            source_audio.close()
    finally:
        video_clip.close()

    # Delete the silent input only after its reader is closed (an open file
    # cannot be removed on Windows) and never when it is also the output.
    if os.path.abspath(save_path) != os.path.abspath(video_path):
        os.remove(video_path)
if __name__ == "__main__":
    # ---- Command-line arguments ------------------------------------------------
    parser = argparse.ArgumentParser(description="Process video and image for face animation.")
    parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
    parser.add_argument('--driving_audio_path', type=str, default="assets/driving_audio/1.wav", help='Path to the driving audio.')
    parser.add_argument('--output_path', type=str, default="outputs_audio", help='Path to save the output video.')
    args = parser.parse_args()

    # ---- Sampling configuration ------------------------------------------------
    guidance_scale = 3.0
    seed = 43
    num_inference_steps = 10
    sample_size = [480, 720]  # [height, width] of the generated video
    max_frame_num = 49
    # parse_video yields max_frame_num - 1 driving frames; one reference-pose
    # frame is prepended below so the control video has max_frame_num frames.
    num_driving_frames = max_frame_num - 1
    fps = 12
    weight_dtype = torch.bfloat16
    save_path = args.output_path
    generator = torch.Generator(device="cuda").manual_seed(seed)
    model_name = "pretrained_models/SkyReels-A1-5B/"
    siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"

    # ---- Face helpers: landmarks, 3D face processing, drawing, alignment -------
    lmk_extractor = LMKExtractor()
    processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
    vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,)
    face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda",)

    # siglip visual encoder
    siglip = SiglipVisionModel.from_pretrained(siglip_name)
    siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)

    # diffposetalk (driven by the audio file below)
    diffposetalk = DiffPoseTalk()

    # skyreels a1 model
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_name,
        subfolder="transformer",
    ).to(weight_dtype)
    vae = AutoencoderKLCogVideoX.from_pretrained(
        model_name,
        subfolder="vae",
    ).to(weight_dtype)
    lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
        model_name,
        subfolder="pose_guider",
    ).to(weight_dtype)
    pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
        model_name,
        transformer=transformer,
        vae=vae,
        lmk_encoder=lmk_encoder,
        image_encoder=siglip,
        feature_extractor=siglip_normalize,
        torch_dtype=weight_dtype,
    )
    pipe.to("cuda")
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()

    # ---- Reference image -------------------------------------------------------
    image = load_image(image=args.image_path)
    image = processor.crop_and_resize(image, sample_size[0], sample_size[1])

    # ref image crop face; (x1, y1) is where the face crop sits inside `image`
    ref_image, x1, y1 = processor.face_crop(np.array(image))
    face_h, face_w, _ = ref_image.shape
    source_outputs, source_tform, image_original = processor.process_source_image(ref_image)

    # ---- Audio-driven motion ---------------------------------------------------
    driving_outputs = diffposetalk.infer_from_file(
        args.driving_audio_path,
        source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy(),
    )
    out_frames = processor.preprocess_lmk3d_from_coef(source_outputs, source_tform, image_original.shape, driving_outputs)
    out_frames = parse_video(out_frames, max_frame_num)

    # Paste each rendered face frame into a full-size black canvas at the crop location.
    rescale_motions = np.zeros_like(np.array(image))[np.newaxis, :].repeat(num_driving_frames, axis=0)
    for ii in range(rescale_motions.shape[0]):
        rescale_motions[ii][y1:y1 + face_h, x1:x1 + face_w] = out_frames[ii]

    # First control frame: landmarks of the reference face itself.
    ref_image = cv2.resize(ref_image, (512, 512))
    ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
    ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
    first_motion = np.zeros_like(np.array(image))
    first_motion[y1:y1 + face_h, x1:x1 + face_w] = ref_img
    first_motion = first_motion[np.newaxis, :]
    motions = np.concatenate([first_motion, rescale_motions])
    input_video = motions[:max_frame_num]

    # ---- Aligned face crop for the pipeline's `image_face` input ---------------
    face_helper.clean_all()
    face_helper.read_image(np.array(image)[:, :, ::-1])
    face_helper.get_face_landmarks_5(only_center_face=True)
    face_helper.align_warp_face()
    if len(face_helper.cropped_faces) == 0:
        raise RuntimeError(f"No face detected in reference image: {args.image_path}")
    align_face = face_helper.cropped_faces[0]
    image_face = align_face[:, :, ::-1]

    # [F, H, W, C] -> [1, C, F, H, W], scaled to [0, 1]
    input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
    input_video = input_video / 255

    # ---- Generation ------------------------------------------------------------
    with torch.no_grad():
        sample = pipe(
            image=image,
            image_face=image_face,
            control_video=input_video,
            prompt="",
            negative_prompt="",
            height=sample_size[0],
            width=sample_size[1],
            num_frames=max_frame_num,
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )
    # Control frames 0 and 1 are the reference pose and the duplicated first
    # driving frame (parse_video repeats its first frame); drop their outputs.
    out_samples = list(sample.frames[0])[2:]

    # ---- Outputs ---------------------------------------------------------------
    # NOTE: save_path_name already ends in ".mp4", so the paths below get double
    # extensions (e.g. "1-1.mp4.output.mp4"); kept as-is for compatibility.
    save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_audio_path).split(".")[0] + ".mp4"
    os.makedirs(save_path, exist_ok=True)
    video_path = os.path.join(save_path, save_path_name + ".output.mp4")
    export_to_video(out_samples, video_path, fps=fps)

    # Side-by-side video: reference image on the left, generated frame on the right.
    target_h, target_w = sample_size[0], sample_size[1]
    final_images = []
    for frame in out_samples:
        generated = Image.fromarray(np.array(frame)).convert("RGB")
        result = Image.new('RGB', (target_w * 2, target_h))
        result.paste(image, (0, 0))
        result.paste(generated, (target_w, 0))
        final_images.append(np.array(result))
    video_out_path = os.path.join(save_path, save_path_name)
    write_mp4(video_out_path, final_images, fps=fps)

    # Attach the driving audio to both videos (the silent versions are removed).
    save_video_with_audio(video_out_path, args.driving_audio_path, video_out_path + ".audio.mp4")
    save_video_with_audio(video_path, args.driving_audio_path, video_path + ".audio.mp4")