import os
import sys
import shutil
import uuid
import subprocess
from glob import glob
from huggingface_hub import snapshot_download
import tempfile
from moviepy.editor import VideoFileClip
from pydub import AudioSegment
import argparse
from omegaconf import OmegaConf
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed
from latentsync.whisper.audio2feature import Audio2Feature
import random
import requests

# --- Model download ---
print("Making sure the model weights are downloaded...")
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(
    repo_id="ByteDance/LatentSync",
    local_dir="./checkpoints",
    ignore_patterns=["*.pth", "*.pt", "*.bin"],  # Download all files on the first run only; afterwards these can be skipped to speed up startup
)
print("Model files are ready.")


# --- Helper functions ---
def download_file(url, save_dir="downloads"):
    """Download a file from a URL and return its local path."""
    os.makedirs(save_dir, exist_ok=True)
    filename = url.split("/")[-1].split("?")[0]  # Strip any query string from the filename
    local_path = os.path.join(save_dir, filename)
    print(f"Downloading {url} to {local_path}...")
    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        print("Download complete.")
        return local_path
    except requests.exceptions.RequestException as e:
        print(f"Error while downloading file: {e}")
        sys.exit(1)


def process_video(input_video_path, temp_dir):
    """Trim the video to 10 seconds if it is longer than that."""
    os.makedirs(temp_dir, exist_ok=True)
    video = VideoFileClip(input_video_path)
    input_file_name = os.path.basename(input_video_path)
    output_video_path = os.path.join(temp_dir, f"processed_{input_file_name}")
    if video.duration > 10:
        print(f"Video duration ({video.duration:.2f}s) exceeds 10 seconds, trimming...")
        video = video.subclip(0, 10)
    video.write_videofile(output_video_path, codec="libx264", audio_codec="aac", logger=None)
    print(f"Video processed and saved to: {output_video_path}")
    return output_video_path


def process_audio(file_path, temp_dir):
    """Trim the audio to 8 seconds if it is longer than that."""
    audio = AudioSegment.from_file(file_path)
    max_duration = 8 * 1000  # 8 seconds, in milliseconds
    if len(audio) > max_duration:
        print(f"Audio duration ({len(audio) / 1000:.2f}s) exceeds 8 seconds, trimming...")
        audio = audio[:max_duration]
    output_path = os.path.join(temp_dir, "processed_audio.wav")
    audio.export(output_path, format="wav")
    print(f"Audio processed and saved to: {output_path}")
    return output_path


# --- Core inference function ---
def run_inference(video_path, audio_path):
    """Run the lip-sync inference pipeline and return the output video path."""
    inference_ckpt_path = "checkpoints/latentsync_unet.pt"
    unet_config_path = "configs/unet/second_stage.yaml"
    config = OmegaConf.load(unet_config_path)

    print(f"Input video path: {video_path}")
    print(f"Input audio path: {audio_path}")

    # Create a temporary directory for the pre-processed audio/video files
    temp_dir = tempfile.mkdtemp()
    try:
        # Pre-process the inputs so they meet the model's requirements
        processed_video_path = process_video(video_path, temp_dir)
        processed_audio_path = process_audio(audio_path, temp_dir)

        scheduler = DDIMScheduler.from_pretrained("configs")

        # Pick the Whisper model that matches the UNet's cross-attention dimension
        if config.model.cross_attention_dim == 768:
            whisper_model_path = "checkpoints/whisper/small.pt"
        elif config.model.cross_attention_dim == 384:
            whisper_model_path = "checkpoints/whisper/tiny.pt"
        else:
            raise NotImplementedError("cross_attention_dim must be 768 or 384")

        audio_encoder = Audio2Feature(model_path=whisper_model_path, device="cuda", num_frames=config.data.num_frames)

        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

        unet, _ = UNet3DConditionModel.from_pretrained(
            OmegaConf.to_container(config.model),
            inference_ckpt_path,
            device="cpu",
        )
        unet = unet.to(dtype=torch.float16)

        if is_xformers_available():
            print("Enabling xformers memory-efficient attention.")
            unet.enable_xformers_memory_efficient_attention()

        pipeline = LipsyncPipeline(
            vae=vae,
            audio_encoder=audio_encoder,
            unet=unet,
            scheduler=scheduler,
        ).to("cuda")

        # A seed of -1 means "use a random seed"
        seed = -1
        if seed != -1:
            set_seed(seed)
        else:
            torch.seed()
        print(f"Initial seed: {torch.initial_seed()}")

        # Generate a unique output filename
        unique_id = str(uuid.uuid4())[:8]
        video_out_path = f"output_{unique_id}.mp4"

        print("🚀 Starting video generation...")
        pipeline(
            video_path=processed_video_path,
            audio_path=processed_audio_path,
            video_out_path=video_out_path,
            video_mask_path=video_out_path.replace(".mp4", "_mask.mp4"),
            num_frames=config.data.num_frames,
            num_inference_steps=config.run.inference_steps,
            guidance_scale=1.0,
            weight_dtype=torch.float16,
            width=config.data.resolution,
            height=config.data.resolution,
        )

        return video_out_path
    finally:
        # Clean up the temporary directory
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print(f"Temporary directory {temp_dir} removed.")


# --- Main entry point ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LatentSync: command-line lip-sync tool")
    parser.add_argument("--input_video", type=str, help="Path or URL of the input video.")
    parser.add_argument("--input_audio", type=str, help="Path or URL of the input audio.")
    args = parser.parse_args()

    final_video_path = args.input_video
    final_audio_path = args.input_audio

    # If no inputs were provided, pick a random example pair from the assets folder
    if not final_video_path or not final_audio_path:
        print("⚠️ No input files specified. Picking a random example from the 'assets' folder.")
        asset_videos = glob("assets/demo*_video.mp4")
        if not asset_videos:
            print("Error: no example videos (e.g. demo1_video.mp4) found in the 'assets' folder.")
            sys.exit(1)

        selected_video = random.choice(asset_videos)
        base_name = os.path.basename(selected_video).replace("_video.mp4", "")
        selected_audio = os.path.join("assets", f"{base_name}_audio.wav")

        if not os.path.exists(selected_audio):
            print(f"Error: could not find the audio file {selected_audio} matching {selected_video}.")
            sys.exit(1)

        final_video_path = selected_video
        final_audio_path = selected_audio
        print(f"Randomly selected: Video='{final_video_path}', Audio='{final_audio_path}'")
    else:
        # If inputs were provided, download them first when they are URLs
        if final_video_path.startswith(("http://", "https://")):
            final_video_path = download_file(final_video_path)
        if final_audio_path.startswith(("http://", "https://")):
            final_audio_path = download_file(final_audio_path)

    # Make sure both files exist
    if not os.path.exists(final_video_path):
        print(f"Error: video file does not exist at '{final_video_path}'")
        sys.exit(1)
    if not os.path.exists(final_audio_path):
        print(f"Error: audio file does not exist at '{final_audio_path}'")
        sys.exit(1)

    # Run inference
    output_file = run_inference(final_video_path, final_audio_path)
    print(f"✅ Video generated successfully! Output file: {output_file}")
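
# Example invocations, assuming this script is saved as `latentsync_cli.py`
# (the filename is not fixed anywhere above, so substitute your own):
#
#   python latentsync_cli.py
#       -> picks a random demoN_video.mp4 / demoN_audio.wav pair from ./assets
#
#   python latentsync_cli.py --input_video my_face.mp4 --input_audio my_voice.wav
#
#   python latentsync_cli.py --input_video https://example.com/clip.mp4 \
#                            --input_audio https://example.com/speech.wav
#       -> URL inputs are downloaded to ./downloads before inference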