import os # PyTorch 2.8 (temporary hack) os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces') # --- 1. Model Download and Setup (Diffusers Backend) --- import spaces import torch from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline from diffusers.models.transformers.transformer_wan import WanTransformer3DModel from diffusers.utils.export_utils import export_to_video import gradio as gr import tempfile import numpy as np from PIL import Image import random import gc # Import the optimization function from the separate file from optimization import optimize_pipeline_ # --- Constants and Model Loading --- MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" # --- NEW: Flexible Dimension Constants --- MAX_DIMENSION = 832 MIN_DIMENSION = 480 DIMENSION_MULTIPLE = 16 SQUARE_SIZE = 480 MAX_SEED = np.iinfo(np.int32).max FIXED_FPS = 16 MIN_FRAMES_MODEL = 8 MAX_FRAMES_MODEL = 81 MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1) MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1) default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝," print("Loading models into memory. This may take a few minutes...") pipe = WanImageToVideoPipeline.from_pretrained( MODEL_ID, transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers', subfolder='transformer', torch_dtype=torch.bfloat16, device_map='cuda', ), transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers', subfolder='transformer_2', torch_dtype=torch.bfloat16, device_map='cuda', ), torch_dtype=torch.bfloat16, ) pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=8.0) pipe.to('cuda') print("Optimizing pipeline...") for i in range(3): gc.collect() torch.cuda.synchronize() torch.cuda.empty_cache() # Calling the imported optimization function with a placeholder image for compilation tracing optimize_pipeline_(pipe, image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)), # Use representative dims prompt='prompt', height=MIN_DIMENSION, width=MAX_DIMENSION, num_frames=MAX_FRAMES_MODEL, ) print("All models loaded and optimized. Gradio app is ready.") # --- 2. Image Processing and Application Logic --- def process_image_for_video(image: Image.Image) -> Image.Image: """ Resizes an image based on the following rules for video generation: 1. The longest side will be scaled down to MAX_DIMENSION if it's larger. 2. The shortest side will be scaled up to MIN_DIMENSION if it's smaller. 3. The final dimensions will be rounded to the nearest multiple of DIMENSION_MULTIPLE. 4. Square images are resized to a fixed SQUARE_SIZE. The aspect ratio is preserved as closely as possible. """ width, height = image.size # Rule 4: Handle square images if width == height: return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS) # Determine target dimensions while preserving aspect ratio aspect_ratio = width / height new_width, new_height = width, height # Rule 1: Scale down if too large if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION: if aspect_ratio > 1: # Landscape scale = MAX_DIMENSION / new_width else: # Portrait scale = MAX_DIMENSION / new_height new_width *= scale new_height *= scale # Rule 2: Scale up if too small if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION: if aspect_ratio > 1: # Landscape scale = MIN_DIMENSION / new_height else: # Portrait scale = MIN_DIMENSION / new_width new_width *= scale new_height *= scale # Rule 3: Round to the nearest multiple of DIMENSION_MULTIPLE final_width = int(round(new_width / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE) final_height = int(round(new_height / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE) # Ensure final dimensions are at least the minimum final_width = max(final_width, MIN_DIMENSION if aspect_ratio < 1 else SQUARE_SIZE) final_height = max(final_height, MIN_DIMENSION if aspect_ratio > 1 else SQUARE_SIZE) return image.resize((final_width, final_height), Image.Resampling.LANCZOS) def resize_and_crop_to_match(target_image, reference_image): """Resizes and center-crops the target image to match the reference image's dimensions.""" ref_width, ref_height = reference_image.size target_width, target_height = target_image.size scale = max(ref_width / target_width, ref_height / target_height) new_width, new_height = int(target_width * scale), int(target_height * scale) resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS) left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2 return resized.crop((left, top, left + ref_width, top + ref_height)) @spaces.GPU(duration=120) def generate_video( start_image_pil, end_image_pil, prompt, negative_prompt=default_negative_prompt, duration_seconds=2.1, steps=8, guidance_scale=1, guidance_scale_2=1, seed=42, randomize_seed=False, progress=gr.Progress(track_tqdm=True) ): """ Generates a video by interpolating between a start and end image, guided by a text prompt, using the diffusers Wan2.2 pipeline. """ if start_image_pil is None or end_image_pil is None: raise gr.Error("Please upload both a start and an end image.") progress(0.1, desc="Preprocessing images...") # Step 1: Process the start image to get our target dimensions based on the new rules. processed_start_image = process_image_for_video(start_image_pil) # Step 2: Make the end image match the *exact* dimensions of the processed start image. processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image) target_height, target_width = processed_start_image.height, processed_start_image.width # Handle seed and frame count current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed) num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL) progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...") output_frames_list = pipe( image=processed_start_image, last_image=processed_end_image, prompt=prompt, negative_prompt=negative_prompt, height=target_height, width=target_width, num_frames=num_frames, guidance_scale=float(guidance_scale), guidance_scale_2=float(guidance_scale_2), num_inference_steps=int(steps), generator=torch.Generator(device="cuda").manual_seed(current_seed), ).frames[0] progress(0.9, desc="Encoding and saving video...") with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile: video_path = tmpfile.name export_to_video(output_frames_list, video_path, fps=FIXED_FPS) progress(1.0, desc="Done!") return video_path, current_seed # --- 3. Gradio User Interface --- (No changes needed here) css = ''' .fillable{max-width: 1100px !important} .dark .progress-text {color: white} ''' with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app: gr.Markdown("# Wan 2.2 First/Last Frame Video Fast") gr.Markdown("Based on the [Wan 2.2 First/Last Frame workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/), applied to 🧨 Diffusers + [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA") with gr.Row(): with gr.Column(): with gr.Group(): with gr.Row(): start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"]) end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"]) prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images") with gr.Accordion("Advanced Settings", open=False): duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=2.1, label="Video Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.") negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3) steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference Steps") guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise") guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise") with gr.Row(): seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42) randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True) generate_button = gr.Button("Generate Video", variant="primary") with gr.Column(): output_video = gr.Video(label="Generated Video", autoplay=True) # Define the inputs list for the click event ui_inputs = [ start_image, end_image, prompt, negative_prompt_input, duration_seconds_input, steps_slider, guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox ] # The seed_input is both an input and an output to reflect the randomly generated seed ui_outputs = [output_video, seed_input] generate_button.click( fn=generate_video, inputs=ui_inputs, outputs=ui_outputs ) gr.Examples( examples=[ ["poli_tower.png", "tower_takes_off.png", "the man turns around"], ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"], ["capyabara_zoomed.png", "capyabara.webp", "a dramatic dolly zoom"], ], inputs=[start_image, end_image, prompt], outputs=ui_outputs, fn=generate_video, cache_examples="lazy", ) if __name__ == "__main__": app.launch(share=True)