Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -1,68 +1,53 @@ 
     | 
|
| 1 | 
         
             
            import torch
         
     | 
| 2 | 
         
            -
            from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
         
     | 
| 3 | 
         
             
            from diffusers.utils import export_to_video
         
     | 
| 4 | 
         
            -
            from transformers import CLIPVisionModel
         
     | 
| 5 | 
         
             
            import gradio as gr
         
     | 
| 6 | 
         
             
            import tempfile
         
     | 
| 7 | 
         
             
            import spaces
         
     | 
| 8 | 
         
            -
            from huggingface_hub import hf_hub_download
         
     | 
| 9 | 
         
             
            import numpy as np
         
     | 
| 10 | 
         
             
            from PIL import Image
         
     | 
| 11 | 
         
             
            import random
         
     | 
| 12 | 
         | 
| 13 | 
         
            -
            MODEL_ID = " 
     | 
| 14 | 
         
            -
            LORA_REPO_ID = "Kijai/WanVideo_comfy"
         
     | 
| 15 | 
         
            -
            LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
         
     | 
| 16 | 
         
            -
             
     | 
| 17 | 
         
            -
            image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
         
     | 
| 18 | 
         
             
            vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
         
     | 
| 19 | 
         
            -
            pipe = WanImageToVideoPipeline.from_pretrained(
         
     | 
| 20 | 
         
            -
                MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
         
     | 
| 21 | 
         
            -
            )
         
     | 
| 22 | 
         
            -
            pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
         
     | 
| 23 | 
         
            -
            pipe.to("cuda")
         
     | 
| 24 | 
         | 
| 25 | 
         
            -
             
     | 
| 26 | 
         
            -
             
     | 
| 27 | 
         
            -
             
     | 
| 28 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 29 | 
         | 
| 
         | 
|
| 30 | 
         
             
            MOD_VALUE = 32
         
     | 
| 31 | 
         
            -
            DEFAULT_H_SLIDER_VALUE =  
     | 
| 32 | 
         
             
            DEFAULT_W_SLIDER_VALUE = 896
         
     | 
| 33 | 
         
            -
            NEW_FORMULA_MAX_AREA =  
     | 
| 34 | 
         
            -
             
     | 
| 35 | 
         
            -
             
     | 
| 36 | 
         
            -
            SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
         
     | 
| 37 | 
         
             
            MAX_SEED = np.iinfo(np.int32).max
         
     | 
| 38 | 
         
            -
             
     | 
| 39 | 
         
             
            FIXED_FPS = 24
         
     | 
| 40 | 
         
            -
            MIN_FRAMES_MODEL =  
     | 
| 41 | 
         
            -
            MAX_FRAMES_MODEL =  
     | 
| 42 | 
         | 
| 43 | 
         
             
            default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
         
     | 
| 44 | 
         
             
            default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
         
     | 
| 45 | 
         | 
| 46 | 
         
            -
             
     | 
| 47 | 
         
            -
            def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
         
     | 
| 48 | 
         
            -
                                             min_slider_h, max_slider_h,
         
     | 
| 49 | 
         
            -
                                             min_slider_w, max_slider_w,
         
     | 
| 50 | 
         
            -
                                             default_h, default_w):
         
     | 
| 51 | 
         
             
                orig_w, orig_h = pil_image.size
         
     | 
| 52 | 
         
             
                if orig_w <= 0 or orig_h <= 0:
         
     | 
| 53 | 
         
             
                    return default_h, default_w
         
     | 
| 54 | 
         
            -
             
     | 
| 55 | 
         
             
                aspect_ratio = orig_h / orig_w
         
     | 
| 56 | 
         
            -
             
     | 
| 57 | 
         
             
                calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
         
     | 
| 58 | 
         
             
                calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
         
     | 
| 59 | 
         
            -
             
     | 
| 60 | 
         
             
                calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
         
     | 
| 61 | 
         
             
                calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
         
     | 
| 62 | 
         
            -
             
     | 
| 63 | 
         
             
                new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
         
     | 
| 64 | 
         
             
                new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
         
     | 
| 65 | 
         
            -
             
     | 
| 66 | 
         
             
                return new_h, new_w
         
     | 
| 67 | 
         | 
| 68 | 
         
             
            def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
         
     | 
| 
         @@ -78,85 +63,45 @@ def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_ 
     | 
|
| 78 | 
         
             
                except Exception as e:
         
     | 
| 79 | 
         
             
                    gr.Warning("Error attempting to calculate new dimensions")
         
     | 
| 80 | 
         
             
                    return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
         
     | 
| 81 | 
         
            -
             
     | 
| 82 | 
         
             
            def get_duration(input_image, prompt, height, width, 
         
     | 
| 83 | 
         
             
                               negative_prompt, duration_seconds,
         
     | 
| 84 | 
         
             
                               guidance_scale, steps,
         
     | 
| 85 | 
         
             
                               seed, randomize_seed, 
         
     | 
| 86 | 
         
             
                               progress):
         
     | 
| 87 | 
         
            -
                if steps > 4 and duration_seconds >  
     | 
| 88 | 
         
             
                    return 90
         
     | 
| 89 | 
         
            -
                elif steps > 4 or duration_seconds >  
     | 
| 90 | 
         
             
                    return 75
         
     | 
| 91 | 
         
             
                else:
         
     | 
| 92 | 
         
             
                    return 60
         
     | 
| 93 | 
         | 
| 94 | 
         
             
            @spaces.GPU(duration=get_duration)
         
     | 
| 95 | 
         
            -
            def generate_video(input_image, prompt, height, width, 
         
     | 
| 96 | 
         
            -
                               negative_prompt=default_negative_prompt, duration_seconds = 2,
         
     | 
| 97 | 
         
            -
                               guidance_scale = 1, steps = 4,
         
     | 
| 98 | 
         
            -
                               seed = 42, randomize_seed = False, 
         
     | 
| 99 | 
         
            -
                               progress=gr.Progress(track_tqdm=True)):
         
     | 
| 100 | 
         
            -
                """
         
     | 
| 101 | 
         
            -
                Generate a video from an input image using the Wan 2.1 I2V model with CausVid LoRA.
         
     | 
| 102 | 
         
            -
                
         
     | 
| 103 | 
         
            -
                This function takes an input image and generates a video animation based on the provided
         
     | 
| 104 | 
         
            -
                prompt and parameters. It uses the Wan 2.1 14B Image-to-Video model with CausVid LoRA
         
     | 
| 105 | 
         
            -
                for fast generation in 4-8 steps.
         
     | 
| 106 | 
         
            -
                
         
     | 
| 107 | 
         
            -
                Args:
         
     | 
| 108 | 
         
            -
                    input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
         
     | 
| 109 | 
         
            -
                    prompt (str): Text prompt describing the desired animation or motion.
         
     | 
| 110 | 
         
            -
                    height (int): Target height for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         
     | 
| 111 | 
         
            -
                    width (int): Target width for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         
     | 
| 112 | 
         
            -
                    negative_prompt (str, optional): Negative prompt to avoid unwanted elements. 
         
     | 
| 113 | 
         
            -
                        Defaults to default_negative_prompt (contains unwanted visual artifacts).
         
     | 
| 114 | 
         
            -
                    duration_seconds (float, optional): Duration of the generated video in seconds.
         
     | 
| 115 | 
         
            -
                        Defaults to 2. Clamped between MIN_FRAMES_MODEL/FIXED_FPS and MAX_FRAMES_MODEL/FIXED_FPS.
         
     | 
| 116 | 
         
            -
                    guidance_scale (float, optional): Controls adherence to the prompt. Higher values = more adherence.
         
     | 
| 117 | 
         
            -
                        Defaults to 1.0. Range: 0.0-20.0.
         
     | 
| 118 | 
         
            -
                    steps (int, optional): Number of inference steps. More steps = higher quality but slower.
         
     | 
| 119 | 
         
            -
                        Defaults to 4. Range: 1-30.
         
     | 
| 120 | 
         
            -
                    seed (int, optional): Random seed for reproducible results. Defaults to 42.
         
     | 
| 121 | 
         
            -
                        Range: 0 to MAX_SEED (2147483647).
         
     | 
| 122 | 
         
            -
                    randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
         
     | 
| 123 | 
         
            -
                        Defaults to False.
         
     | 
| 124 | 
         
            -
                    progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
         
     | 
| 125 | 
         
            -
                
         
     | 
| 126 | 
         
            -
                Returns:
         
     | 
| 127 | 
         
            -
                    tuple: A tuple containing:
         
     | 
| 128 | 
         
            -
                        - video_path (str): Path to the generated video file (.mp4)
         
     | 
| 129 | 
         
            -
                        - current_seed (int): The seed used for generation (useful when randomize_seed=True)
         
     | 
| 130 | 
         
            -
                
         
     | 
| 131 | 
         
            -
                Raises:
         
     | 
| 132 | 
         
            -
                    gr.Error: If input_image is None (no image uploaded).
         
     | 
| 133 | 
         
            -
                
         
     | 
| 134 | 
         
            -
                Note:
         
     | 
| 135 | 
         
            -
                    - The function automatically resizes the input image to the target dimensions
         
     | 
| 136 | 
         
            -
                    - Frame count is calculated as duration_seconds * FIXED_FPS (24)
         
     | 
| 137 | 
         
            -
                    - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
         
     | 
| 138 | 
         
            -
                    - The function uses GPU acceleration via the @spaces.GPU decorator
         
     | 
| 139 | 
         
            -
                    - Generation time varies based on steps and duration (see get_duration function)
         
     | 
| 140 | 
         
            -
                """
         
     | 
| 141 | 
         
            -
                if input_image is None:
         
     | 
| 142 | 
         
            -
                    raise gr.Error("Please upload an input image.")
         
     | 
| 143 | 
         
            -
             
     | 
| 144 | 
         
             
                target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
         
     | 
| 145 | 
         
             
                target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
         
     | 
| 146 | 
         
            -
             
     | 
| 147 | 
         
             
                num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
         
     | 
| 148 | 
         
            -
                
         
     | 
| 149 | 
         
            -
                current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
         
     | 
| 150 | 
         | 
| 151 | 
         
            -
                 
     | 
| 152 | 
         | 
| 153 | 
         
            -
                 
     | 
| 154 | 
         
            -
                     
     | 
| 155 | 
         
            -
             
     | 
| 156 | 
         
            -
                         
     | 
| 157 | 
         
            -
             
     | 
| 158 | 
         
            -
             
     | 
| 159 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 160 | 
         | 
| 161 | 
         
             
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
         
     | 
| 162 | 
         
             
                    video_path = tmpfile.name
         
     | 
| 
         @@ -164,14 +109,15 @@ def generate_video(input_image, prompt, height, width, 
     | 
|
| 164 | 
         
             
                return video_path, current_seed
         
     | 
| 165 | 
         | 
| 166 | 
         
             
            with gr.Blocks() as demo:
         
     | 
| 167 | 
         
            -
                gr.Markdown("# Fast  
     | 
| 168 | 
         
            -
                gr.Markdown("[ 
     | 
| 
         | 
|
| 169 | 
         
             
                with gr.Row():
         
     | 
| 170 | 
         
             
                    with gr.Column():
         
     | 
| 171 | 
         
            -
                        input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
         
     | 
| 172 | 
         
             
                        prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
         
     | 
| 173 | 
         
             
                        duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
         
     | 
| 174 | 
         
            -
             
     | 
| 175 | 
         
             
                        with gr.Accordion("Advanced Settings", open=False):
         
     | 
| 176 | 
         
             
                            negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
         
     | 
| 177 | 
         
             
                            seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
         
     | 
| 
         @@ -179,9 +125,8 @@ with gr.Blocks() as demo: 
     | 
|
| 179 | 
         
             
                            with gr.Row():
         
     | 
| 180 | 
         
             
                                height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
         
     | 
| 181 | 
         
             
                                width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
         
     | 
| 182 | 
         
            -
                            steps_slider = gr.Slider(minimum=1, maximum= 
     | 
| 183 | 
         
            -
                            guidance_scale_input = gr.Slider(minimum=0.0, maximum= 
     | 
| 184 | 
         
            -
             
     | 
| 185 | 
         
             
                        generate_button = gr.Button("Generate Video", variant="primary")
         
     | 
| 186 | 
         
             
                    with gr.Column():
         
     | 
| 187 | 
         
             
                        video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
         
     | 
| 
         @@ -191,13 +136,13 @@ with gr.Blocks() as demo: 
     | 
|
| 191 | 
         
             
                    inputs=[input_image_component, height_input, width_input],
         
     | 
| 192 | 
         
             
                    outputs=[height_input, width_input]
         
     | 
| 193 | 
         
             
                )
         
     | 
| 194 | 
         
            -
             
     | 
| 195 | 
         
            -
                input_image_component.clear( 
     | 
| 196 | 
         
             
                    fn=handle_image_upload_for_dims_wan,
         
     | 
| 197 | 
         
             
                    inputs=[input_image_component, height_input, width_input],
         
     | 
| 198 | 
         
             
                    outputs=[height_input, width_input]
         
     | 
| 199 | 
         
             
                )
         
     | 
| 200 | 
         
            -
             
     | 
| 201 | 
         
             
                ui_inputs = [
         
     | 
| 202 | 
         
             
                    input_image_component, prompt_input, height_input, width_input,
         
     | 
| 203 | 
         
             
                    negative_prompt_input, duration_seconds_input,
         
     | 
| 
         @@ -208,10 +153,10 @@ with gr.Blocks() as demo: 
     | 
|
| 208 | 
         
             
                gr.Examples(
         
     | 
| 209 | 
         
             
                    examples=[ 
         
     | 
| 210 | 
         
             
                        ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
         
     | 
| 211 | 
         
            -
                        [ 
     | 
| 212 | 
         
             
                    ],
         
     | 
| 213 | 
         
             
                    inputs=[input_image_component, prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
         
     | 
| 214 | 
         
             
                )
         
     | 
| 215 | 
         | 
| 216 | 
         
             
            if __name__ == "__main__":
         
     | 
| 217 | 
         
            -
                demo.queue().launch( 
     | 
| 
         | 
|
| 1 | 
         
             
            import torch
         
     | 
| 2 | 
         
            +
            from diffusers import AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline, UniPCMultistepScheduler
         
     | 
| 3 | 
         
             
            from diffusers.utils import export_to_video
         
     | 
| 
         | 
|
| 4 | 
         
             
            import gradio as gr
         
     | 
| 5 | 
         
             
            import tempfile
         
     | 
| 6 | 
         
             
            import spaces
         
     | 
| 
         | 
|
| 7 | 
         
             
            import numpy as np
         
     | 
| 8 | 
         
             
            from PIL import Image
         
     | 
| 9 | 
         
             
            import random
         
     | 
| 10 | 
         | 
| 11 | 
         
            +
            MODEL_ID = "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers"
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 12 | 
         
             
            vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 13 | 
         | 
| 14 | 
         
            +
            # Initialize pipelines
         
     | 
| 15 | 
         
            +
            text_to_video_pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
         
     | 
| 16 | 
         
            +
            image_to_video_pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            for pipe in [text_to_video_pipe, image_to_video_pipe]:
         
     | 
| 19 | 
         
            +
                pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
         
     | 
| 20 | 
         
            +
                pipe.to("cuda")
         
     | 
| 21 | 
         | 
| 22 | 
         
            +
            # Constants
         
     | 
| 23 | 
         
             
            MOD_VALUE = 32
         
     | 
| 24 | 
         
            +
            DEFAULT_H_SLIDER_VALUE = 896
         
     | 
| 25 | 
         
             
            DEFAULT_W_SLIDER_VALUE = 896
         
     | 
| 26 | 
         
            +
            NEW_FORMULA_MAX_AREA = 720.0 * 1024
         
     | 
| 27 | 
         
            +
            SLIDER_MIN_H, SLIDER_MAX_H = 256, 1280
         
     | 
| 28 | 
         
            +
            SLIDER_MIN_W, SLIDER_MAX_W = 256, 1280
         
     | 
| 
         | 
|
| 29 | 
         
             
            MAX_SEED = np.iinfo(np.int32).max
         
     | 
| 
         | 
|
| 30 | 
         
             
            FIXED_FPS = 24
         
     | 
| 31 | 
         
            +
            MIN_FRAMES_MODEL = 25
         
     | 
| 32 | 
         
            +
            MAX_FRAMES_MODEL = 193
         
     | 
| 33 | 
         | 
| 34 | 
         
             
            default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
         
     | 
| 35 | 
         
             
            default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
         
     | 
| 36 | 
         | 
| 37 | 
         
            +
            def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area, min_slider_h, max_slider_h, min_slider_w, max_slider_w, default_h, default_w):
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 38 | 
         
             
                orig_w, orig_h = pil_image.size
         
     | 
| 39 | 
         
             
                if orig_w <= 0 or orig_h <= 0:
         
     | 
| 40 | 
         
             
                    return default_h, default_w
         
     | 
| 
         | 
|
| 41 | 
         
             
                aspect_ratio = orig_h / orig_w
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
             
                calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
         
     | 
| 44 | 
         
             
                calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
         
     | 
| 
         | 
|
| 45 | 
         
             
                calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
         
     | 
| 46 | 
         
             
                calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
             
                new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
         
     | 
| 49 | 
         
             
                new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
             
                return new_h, new_w
         
     | 
| 52 | 
         | 
| 53 | 
         
             
            def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
         
     | 
| 
         | 
|
| 63 | 
         
             
                except Exception as e:
         
     | 
| 64 | 
         
             
                    gr.Warning("Error attempting to calculate new dimensions")
         
     | 
| 65 | 
         
             
                    return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
         
     | 
| 66 | 
         
            +
                    
         
     | 
| 67 | 
         
             
            def get_duration(input_image, prompt, height, width, 
         
     | 
| 68 | 
         
             
                               negative_prompt, duration_seconds,
         
     | 
| 69 | 
         
             
                               guidance_scale, steps,
         
     | 
| 70 | 
         
             
                               seed, randomize_seed, 
         
     | 
| 71 | 
         
             
                               progress):
         
     | 
| 72 | 
         
            +
                if steps > 4 and duration_seconds > 4:
         
     | 
| 73 | 
         
             
                    return 90
         
     | 
| 74 | 
         
            +
                elif steps > 4 or duration_seconds > 4:
         
     | 
| 75 | 
         
             
                    return 75
         
     | 
| 76 | 
         
             
                else:
         
     | 
| 77 | 
         
             
                    return 60
         
     | 
| 78 | 
         | 
| 79 | 
         
             
            @spaces.GPU(duration=get_duration)
         
     | 
| 80 | 
         
            +
            def generate_video(input_image, prompt, height, width, negative_prompt=default_negative_prompt, duration_seconds=2, guidance_scale=1, steps=4, seed=42, randomize_seed=False, progress=gr.Progress(track_tqdm=True)):
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 81 | 
         
             
                target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
         
     | 
| 82 | 
         
             
                target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
             
                num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
         
     | 
| 
         | 
|
| 
         | 
|
| 85 | 
         | 
| 86 | 
         
            +
                current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
         
     | 
| 87 | 
         | 
| 88 | 
         
            +
                if input_image is not None:
         
     | 
| 89 | 
         
            +
                    resized_image = input_image.resize((target_w, target_h))
         
     | 
| 90 | 
         
            +
                    with torch.inference_mode():
         
     | 
| 91 | 
         
            +
                        output_frames_list = image_to_video_pipe(
         
     | 
| 92 | 
         
            +
                            image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
         
     | 
| 93 | 
         
            +
                            height=target_h, width=target_w, num_frames=num_frames,
         
     | 
| 94 | 
         
            +
                            guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
         
     | 
| 95 | 
         
            +
                            generator=torch.Generator(device="cuda").manual_seed(current_seed)
         
     | 
| 96 | 
         
            +
                        ).frames[0]
         
     | 
| 97 | 
         
            +
                else:
         
     | 
| 98 | 
         
            +
                    with torch.inference_mode():
         
     | 
| 99 | 
         
            +
                        output_frames_list = text_to_video_pipe(
         
     | 
| 100 | 
         
            +
                            prompt=prompt, negative_prompt=negative_prompt,
         
     | 
| 101 | 
         
            +
                            height=target_h, width=target_w, num_frames=num_frames,
         
     | 
| 102 | 
         
            +
                            guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
         
     | 
| 103 | 
         
            +
                            generator=torch.Generator(device="cuda").manual_seed(current_seed)
         
     | 
| 104 | 
         
            +
                        ).frames[0]
         
     | 
| 105 | 
         | 
| 106 | 
         
             
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
         
     | 
| 107 | 
         
             
                    video_path = tmpfile.name
         
     | 
| 
         | 
|
| 109 | 
         
             
                return video_path, current_seed
         
     | 
| 110 | 
         | 
| 111 | 
         
             
            with gr.Blocks() as demo:
         
     | 
| 112 | 
         
            +
                gr.Markdown("# Fast Wan 2.1 TI2V 5B Demo")
         
     | 
| 113 | 
         
            +
                gr.Markdown("""This Demo is using [FastWan2.2-TI2V-5B](https://huggingface.co/FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers) which is fine-tuned with Sparse-distill method which allows wan to generate high quality videos in 3-5 steps.""")
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
             
                with gr.Row():
         
     | 
| 116 | 
         
             
                    with gr.Column():
         
     | 
| 117 | 
         
            +
                        input_image_component = gr.Image(type="pil", label="Input Image (optional, auto-resized to target H/W)")
         
     | 
| 118 | 
         
             
                        prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
         
     | 
| 119 | 
         
             
                        duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
             
                        with gr.Accordion("Advanced Settings", open=False):
         
     | 
| 122 | 
         
             
                            negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
         
     | 
| 123 | 
         
             
                            seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
         
     | 
| 
         | 
|
| 125 | 
         
             
                            with gr.Row():
         
     | 
| 126 | 
         
             
                                height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
         
     | 
| 127 | 
         
             
                                width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
         
     | 
| 128 | 
         
            +
                            steps_slider = gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Inference Steps")
         
     | 
| 129 | 
         
            +
                            guidance_scale_input = gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=1.0, label="Guidance Scale")
         
     | 
| 
         | 
|
| 130 | 
         
             
                        generate_button = gr.Button("Generate Video", variant="primary")
         
     | 
| 131 | 
         
             
                    with gr.Column():
         
     | 
| 132 | 
         
             
                        video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
         
     | 
| 
         | 
|
| 136 | 
         
             
                    inputs=[input_image_component, height_input, width_input],
         
     | 
| 137 | 
         
             
                    outputs=[height_input, width_input]
         
     | 
| 138 | 
         
             
                )
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
                input_image_component.clear(
         
     | 
| 141 | 
         
             
                    fn=handle_image_upload_for_dims_wan,
         
     | 
| 142 | 
         
             
                    inputs=[input_image_component, height_input, width_input],
         
     | 
| 143 | 
         
             
                    outputs=[height_input, width_input]
         
     | 
| 144 | 
         
             
                )
         
     | 
| 145 | 
         
            +
             
     | 
| 146 | 
         
             
                ui_inputs = [
         
     | 
| 147 | 
         
             
                    input_image_component, prompt_input, height_input, width_input,
         
     | 
| 148 | 
         
             
                    negative_prompt_input, duration_seconds_input,
         
     | 
| 
         | 
|
| 153 | 
         
             
                gr.Examples(
         
     | 
| 154 | 
         
             
                    examples=[ 
         
     | 
| 155 | 
         
             
                        ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
         
     | 
| 156 | 
         
            +
                        [None, "a penguin playfully dancing in the snow, Antarctica", 1024, 720],
         
     | 
| 157 | 
         
             
                    ],
         
     | 
| 158 | 
         
             
                    inputs=[input_image_component, prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
         
     | 
| 159 | 
         
             
                )
         
     | 
| 160 | 
         | 
| 161 | 
         
             
            if __name__ == "__main__":
         
     | 
| 162 | 
         
            +
                demo.queue().launch()
         
     |