import os
import shutil
import sys
import subprocess
import asyncio
import uuid
import random
import tempfile
from typing import Sequence, Mapping, Any, Union

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces

# --- 1. Model Download and Setup ---


def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    """Download a file from the Hugging Face Hub and symlink it into ``local_dir``.

    Args:
        repo_id: Hub repository id.
        filename: Path of the file inside the repository. Only its basename is
            used as the name of the link.
        local_dir: Directory that receives the symlink (created if missing).
        **kwargs: Forwarded unchanged to ``hf_hub_download``.

    Returns:
        Path of the created symlink, ``os.path.join(local_dir, basename(filename))``.
    """
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    target_path = os.path.join(local_dir, os.path.basename(filename))
    # exists() is False for a dangling symlink, so islink() is checked as well;
    # any stale entry is removed so os.symlink() does not fail.
    if os.path.exists(target_path) or os.path.islink(target_path):
        os.remove(target_path)
    os.symlink(downloaded_path, target_path)
    return target_path


# Files to fetch, in download order: (repo_id, filename inside the repo, local directory).
_MODEL_DOWNLOADS = [
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged",
     "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
     "models/text_encoders"),
    ("Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
     "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
     "models/unet"),
    ("Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
     "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
     "models/unet"),
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged",
     "split_files/vae/wan_2.1_vae.safetensors",
     "models/vae"),
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged",
     "split_files/clip_vision/clip_vision_h.safetensors",
     "models/clip_vision"),
    ("Kijai/WanVideo_comfy",
     "Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
     "models/loras"),
    ("Kijai/WanVideo_comfy",
     "Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
     "models/loras"),
]

print("Downloading models from Hugging Face Hub...")
for _repo_id, _filename, _local_dir in _MODEL_DOWNLOADS:
    hf_hub_download_local(repo_id=_repo_id, filename=_filename, local_dir=_local_dir)
print("Downloads complete.")

# --- 2. ComfyUI Backend Initialization ---


def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Find an entry called ``name`` in ``path`` or the nearest ancestor containing it.

    Args:
        name: Directory entry to look for.
        path: Directory to start from; defaults to the current working directory.

    Returns:
        ``os.path.join(<dir>, name)`` for the first directory (walking upwards)
        whose listing contains ``name``, or ``None`` once the filesystem root has
        been checked without a match.
    """
    current = os.getcwd() if path is None else path
    while True:
        if name in os.listdir(current):
            return os.path.join(current, name)
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            return None
        current = parent


def add_comfyui_directory_to_sys_path() -> None:
    """Append the nearest ``ComfyUI`` directory (located via ``find_path``) to ``sys.path``."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
    else:
        # Not necessarily fatal (the ComfyUI modules may already be importable),
        # but report it instead of failing silently: otherwise the first sign of
        # trouble is an opaque import error further down.
        print("Warning: no 'ComfyUI' directory found; relying on the existing sys.path.")


def add_extra_model_paths() -> None:
    """Call ComfyUI's ``main.apply_custom_paths()``.

    Presumably this registers custom model/input/output locations with
    ``folder_paths``; confirm against the ComfyUI version in use.
    """
    from main import apply_custom_paths

    apply_custom_paths()


def import_custom_nodes() -> None:
    """Initialize ComfyUI's extra/custom nodes by running ``nodes.init_extra_nodes``.

    A fresh event loop is created and installed as the current loop, then used to
    run the ``init_extra_nodes`` coroutine to completion. The loop is left open and
    installed (NOTE(review): presumably so later ComfyUI code can reuse it).
    """
    import nodes

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(nodes.init_extra_nodes(init_custom_nodes=True))


print("Setting up ComfyUI paths and nodes...")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
import_custom_nodes()
print("ComfyUI setup complete.")

# --- 3. Global Model & Node Loading and Patching ---
# These imports come after the setup above because they need ComfyUI on sys.path.
from nodes import NODE_CLASS_MAPPINGS
import folder_paths
from comfy import model_management

# Forcing VRAM mode to HIGH would keep models from being offloaded from the GPU after use (currently disabled).
# model_management.vram_state = model_management.VRAMState.HIGH_VRAM

# Registry filled once at start-up with loaded/patched model outputs and reusable
# node instances; generate_video() looks everything up here by name.
MODELS_AND_NODES = {}


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Fetch element ``index`` from a ComfyUI node's return value.

    Plain indexing is tried first (node methods commonly return tuples). If that
    raises ``KeyError`` or ``TypeError`` and ``obj`` is a mapping with a
    ``"result"`` entry, the element is read from ``obj["result"]`` instead;
    otherwise the original exception is re-raised.
    """
    try:
        return obj[index]
    except (KeyError, TypeError):
        has_result_entry = isinstance(obj, Mapping) and "result" in obj
        if not has_result_entry:
            raise
        return obj["result"][index]


print("Loading models and instantiating nodes into memory. This may take a few minutes...")

# Node instances used during start-up to load and patch the models.
cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
modelsamplingsd3 = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
pathchsageattentionkj = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()

# Base weights. Explicit GPU placement happens in load_models_gpu() at the end of
# this section (presumably they stay off the GPU until then).
MODELS_AND_NODES["clip"] = cliploader.load_clip(
    clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan"
)
unet_low_noise = unetloader.load_unet(
    unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default"
)
unet_high_noise = unetloader.load_unet(
    unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default"
)
MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")


def _apply_patch_chain(unet_output, lora_name):
    """Run one UNet loader output through the shared patch chain.

    LoraLoaderModelOnly (strength 0.8) -> ModelSamplingSD3 (shift=8) ->
    PathchSageAttentionKJ (sage_attention="auto"). Returns the last node's output.
    """
    lora_output = loraloadermodelonly.load_lora_model_only(
        lora_name=lora_name,
        strength_model=0.8,
        model=get_value_at_index(unet_output, 0),
    )
    sampling_output = modelsamplingsd3.patch(shift=8, model=get_value_at_index(lora_output, 0))
    return pathchsageattentionkj.patch(
        sage_attention="auto", model=get_value_at_index(sampling_output, 0)
    )


print("Applying all patches to models...")
MODELS_AND_NODES["model_low_noise"] = _apply_patch_chain(
    unet_low_noise, "Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors"
)
MODELS_AND_NODES["model_high_noise"] = _apply_patch_chain(
    unet_high_noise, "Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors"
)

# Inference-time node instances: created exactly once here and reused per request.
for _node_name in (
    "CLIPTextEncode",
    "LoadImage",
    "CLIPVisionEncode",
    "WanFirstLastFrameToVideo",
    "KSamplerAdvanced",
    "VAEDecode",
    "CreateVideo",
    "SaveVideo",
):
    MODELS_AND_NODES[_node_name] = NODE_CLASS_MAPPINGS[_node_name]()

print("Moving final models to GPU...")
# The VAE is not part of this preload list.
_models_for_gpu = []
for _key in ("clip", "model_low_noise", "model_high_noise", "clip_vision"):
    _first_output = MODELS_AND_NODES[_key][0]
    # Objects exposing a `.patcher` (e.g. CLIP wrappers) contribute that patcher;
    # everything else is handed over as-is.
    _models_for_gpu.append(getattr(_first_output, "patcher", _first_output))
# force_patch_weights permanently merges the LoRA
model_management.load_models_gpu(_models_for_gpu, force_patch_weights=True)
print("All models loaded, patched, and on GPU. Gradio app is ready.")
# --- 4. Application Logic and Gradio Interface ---

# Default negative prompt (Chinese). Roughly: garish colors, overexposed, static,
# blurry details, subtitles, style, artwork, painting, picture, still, greyish
# overall, worst quality, low quality, JPEG artifacts, ugly, mutilated, extra
# fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed
# limbs, fused fingers, motionless frame, cluttered background, three legs,
# crowded background, walking backwards, overexposed.
# Shared by generate_video() and the hidden negative-prompt textbox so they stay in sync.
DEFAULT_NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"

# Largest seed torch.manual_seed accepts (unsigned 64-bit). random.randint is
# inclusive on both ends, so 2**64 itself must not be used as the upper bound.
MAX_SEED = 2**64 - 1


def calculate_video_dimensions(width, height, max_size=832, min_size=480):
    """Choose output video dimensions from the source image size.

    Square inputs map to ``(min_size, min_size)``. Otherwise the longer side is
    set to ``max_size`` and the shorter side follows the aspect ratio. Both sides
    are then rounded to the nearest multiple of 16 (and never go below 16).

    Returns:
        ``(video_width, video_height)``.
    """
    if width == height:
        return min_size, min_size
    aspect_ratio = width / height
    if width > height:
        video_width = max_size
        video_height = int(max_size / aspect_ratio)
    else:
        video_height = max_size
        video_width = int(max_size * aspect_ratio)
    video_width = max(16, round(video_width / 16) * 16)
    video_height = max(16, round(video_height / 16) * 16)
    return video_width, video_height


def resize_and_crop_to_match(target_image, reference_image):
    """Scale ``target_image`` to cover ``reference_image``'s size, then center-crop to it.

    Returns:
        A new PIL image with exactly ``reference_image.size``.
    """
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    # round() plus a clamp instead of int(): float error (e.g. 831.9999...) must not
    # leave the scaled image 1px smaller than the reference, which would make
    # crop() pad the missing row/column with black.
    new_width = max(ref_width, round(target_width * scale))
    new_height = max(ref_height, round(target_height * scale))
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))


def _save_temp_png(image, directory):
    """Save ``image`` as a uniquely named PNG inside ``directory``; return its full path."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=directory) as handle:
        image.save(handle, format="PNG")
    return handle.name


def _run_first_last_frame_workflow(
    start_image_name,
    end_image_name,
    prompt,
    negative_prompt,
    video_width,
    video_height,
    length,
    fps,
    progress,
):
    """Execute the pre-loaded ComfyUI first/last-frame graph and save the video.

    ``start_image_name``/``end_image_name`` are bare file names that LoadImage
    resolves inside ComfyUI's input directory. Returns SaveVideo's raw output.
    """
    # References to the global, already-patched models and node instances.
    clip = MODELS_AND_NODES["clip"]
    vae = MODELS_AND_NODES["vae"]
    model_low_final = MODELS_AND_NODES["model_low_noise"]
    model_high_final = MODELS_AND_NODES["model_high_noise"]
    clip_vision = MODELS_AND_NODES["clip_vision"]

    cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
    loadimage = MODELS_AND_NODES["LoadImage"]
    clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
    wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
    ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
    vaedecode = MODELS_AND_NODES["VAEDecode"]
    createvideo = MODELS_AND_NODES["CreateVideo"]
    savevideo = MODELS_AND_NODES["SaveVideo"]

    with torch.inference_mode():
        progress(0.1, desc="Encoding text and images...")
        positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
        negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))

        start_image_loaded = loadimage.load_image(image=start_image_name)
        end_image_loaded = loadimage.load_image(image=end_image_name)

        clip_vision_encoded_start = clipvisionencode.encode(
            crop="none",
            clip_vision=get_value_at_index(clip_vision, 0),
            image=get_value_at_index(start_image_loaded, 0),
        )
        clip_vision_encoded_end = clipvisionencode.encode(
            crop="none",
            clip_vision=get_value_at_index(clip_vision, 0),
            image=get_value_at_index(end_image_loaded, 0),
        )

        progress(0.2, desc="Preparing initial latents...")
        initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
            width=video_width,
            height=video_height,
            length=length,
            batch_size=1,
            positive=get_value_at_index(positive_conditioning, 0),
            negative=get_value_at_index(negative_conditioning, 0),
            vae=get_value_at_index(vae, 0),
            clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
            clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
            start_image=get_value_at_index(start_image_loaded, 0),
            end_image=get_value_at_index(end_image_loaded, 0),
        )

        ksampler_positive = get_value_at_index(initial_latents, 0)
        ksampler_negative = get_value_at_index(initial_latents, 1)
        ksampler_latent = get_value_at_index(initial_latents, 2)

        # Inclusive range [0, MAX_SEED]; the old upper bound 2**64 could overflow torch.manual_seed.
        seed = random.randint(0, MAX_SEED)

        # Two-pass sampling over one 8-step schedule: steps 0-4 with the high-noise
        # model (leftover noise kept), then the remaining steps with the low-noise model.
        progress(0.5, desc="Denoising (Step 1/2)...")
        latent_step1 = ksampleradvanced.sample(
            add_noise="enable",
            noise_seed=seed,
            steps=8,
            cfg=1,
            sampler_name="euler",
            scheduler="simple",
            start_at_step=0,
            end_at_step=4,
            return_with_leftover_noise="enable",
            model=get_value_at_index(model_high_final, 0),
            positive=ksampler_positive,
            negative=ksampler_negative,
            latent_image=ksampler_latent,
        )

        progress(0.7, desc="Denoising (Step 2/2)...")
        latent_step2 = ksampleradvanced.sample(
            add_noise="disable",  # no fresh noise in this pass, so the seed is not used for noise
            noise_seed=seed,
            steps=8,
            cfg=1,
            sampler_name="euler",
            scheduler="simple",
            start_at_step=4,
            end_at_step=10000,
            return_with_leftover_noise="disable",
            model=get_value_at_index(model_low_final, 0),
            positive=ksampler_positive,
            negative=ksampler_negative,
            latent_image=get_value_at_index(latent_step1, 0),
        )

        progress(0.8, desc="Decoding VAE...")
        decoded_images = vaedecode.decode(
            samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0)
        )

        progress(0.9, desc="Creating and saving video...")
        video_data = createvideo.create_video(fps=fps, images=get_value_at_index(decoded_images, 0))

        save_result = savevideo.save_video(
            filename_prefix="GradioVideo",
            format="mp4",
            codec="h264",
            video=get_value_at_index(video_data, 0),
        )

        progress(1.0, desc="Done!")

    return save_result


@spaces.GPU(duration=120)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    duration=33,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video that interpolates between a start and an end image, guided by text.

    Relies on the globally pre-loaded models and pre-instantiated ComfyUI nodes
    in ``MODELS_AND_NODES``.

    Args:
        start_image_pil: First frame (PIL image); its size drives the video size.
        end_image_pil: Last frame (PIL image); resized and center-cropped to the
            start frame's size.
        prompt: Positive text prompt describing the transition.
        negative_prompt: Negative text prompt.
        duration: Frame count passed as ``length`` to WanFirstLastFrameToVideo
            (the UI offers 33 ≈ 2 s and 66 ≈ 4 s at 16 fps).
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the saved MP4 inside ComfyUI's output directory.

    Raises:
        gr.Error: If either image is missing.
    """
    if start_image_pil is None or end_image_pil is None:
        raise gr.Error("Please upload both a start frame and an end frame.")

    FPS = 16

    print("Preprocessing images with Pillow...")
    processed_start_image = start_image_pil.copy()
    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)
    video_width, video_height = calculate_video_dimensions(
        processed_start_image.width, processed_start_image.height
    )

    # LoadImage resolves bare file names against ComfyUI's input directory, so
    # write there instead of assuming it is "./input".
    input_dir = folder_paths.get_input_directory()
    os.makedirs(input_dir, exist_ok=True)

    temp_paths = []
    try:
        for image in (processed_start_image, processed_end_image):
            temp_paths.append(_save_temp_png(image, input_dir))
        print(f"Images preprocessed for a {video_width}x{video_height} video and saved temporarily.")

        save_result = _run_first_last_frame_workflow(
            start_image_name=os.path.basename(temp_paths[0]),
            end_image_name=os.path.basename(temp_paths[1]),
            prompt=prompt,
            negative_prompt=negative_prompt,
            video_width=video_width,
            video_height=video_height,
            length=duration,
            fps=FPS,
            progress=progress,
        )
    finally:
        # Runs on failure too, so temporary inputs never accumulate in the input dir.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError as e:
                print(f"Error cleaning up temporary files: {e}")

    # SaveVideo writes into ComfyUI's output directory (plus an optional subfolder).
    saved = save_result["ui"]["images"][0]
    return os.path.join(
        folder_paths.get_output_directory(), saved.get("subfolder", ""), saved["filename"]
    )


css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''

with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) and the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA on ZeroGPU")

    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame")
                    end_image = gr.Image(type="pil", label="End Frame")

                prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")

            # Hidden advanced controls: their values still feed generate_video().
            with gr.Accordion("Advanced Settings", open=False, visible=False):
                duration = gr.Radio(
                    [("Short (2s)", 33), ("Mid (4s)", 66)],
                    value=33,
                    label="Video Duration",
                    visible=False
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value=DEFAULT_NEGATIVE_PROMPT,
                    visible=False
                )

            generate_button = gr.Button("Generate Video", variant="primary")

        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

    generate_button.click(
        fn=generate_video,
        inputs=[start_image, end_image, prompt, negative_prompt, duration],
        outputs=output_video
    )

    # Examples supply only the first three inputs; negative_prompt and duration
    # fall back to generate_video()'s defaults.
    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
            ["capyabara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=output_video,
        fn=generate_video,
        cache_examples="lazy",
    )

if __name__ == "__main__":
    app.launch(share=True)