KingNish committed
Commit c19ca3d · Parent(s): c388e94

Update app.py

Files changed (1):
  app.py +55 -110
app.py CHANGED
@@ -1,68 +1,53 @@
  import torch
- from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
+ from diffusers import AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline, UniPCMultistepScheduler
  from diffusers.utils import export_to_video
- from transformers import CLIPVisionModel
  import gradio as gr
  import tempfile
  import spaces
- from huggingface_hub import hf_hub_download
  import numpy as np
  from PIL import Image
  import random

- MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
- LORA_REPO_ID = "Kijai/WanVideo_comfy"
- LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
-
- image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
+ MODEL_ID = "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers"
  vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
- pipe = WanImageToVideoPipeline.from_pretrained(
-     MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
- )
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
- pipe.to("cuda")

- causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
- pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
- pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
- pipe.fuse_lora()
+ # Initialize pipelines
+ text_to_video_pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
+ image_to_video_pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
+
+ for pipe in [text_to_video_pipe, image_to_video_pipe]:
+     pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
+     pipe.to("cuda")

+ # Constants
  MOD_VALUE = 32
- DEFAULT_H_SLIDER_VALUE = 512
+ DEFAULT_H_SLIDER_VALUE = 896
  DEFAULT_W_SLIDER_VALUE = 896
- NEW_FORMULA_MAX_AREA = 480.0 * 832.0
-
- SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
- SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
+ NEW_FORMULA_MAX_AREA = 720.0 * 1024
+ SLIDER_MIN_H, SLIDER_MAX_H = 256, 1280
+ SLIDER_MIN_W, SLIDER_MAX_W = 256, 1280
  MAX_SEED = np.iinfo(np.int32).max
-
  FIXED_FPS = 24
- MIN_FRAMES_MODEL = 8
- MAX_FRAMES_MODEL = 81
+ MIN_FRAMES_MODEL = 25
+ MAX_FRAMES_MODEL = 193

  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
  default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"

-
- def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
-                                   min_slider_h, max_slider_h,
-                                   min_slider_w, max_slider_w,
-                                   default_h, default_w):
+ def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area, min_slider_h, max_slider_h, min_slider_w, max_slider_w, default_h, default_w):
      orig_w, orig_h = pil_image.size
      if orig_w <= 0 or orig_h <= 0:
          return default_h, default_w
-
      aspect_ratio = orig_h / orig_w
-
+
      calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
      calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
-
      calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
      calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
-
+
      new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
      new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
-
+
      return new_h, new_w

  def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
@@ -78,85 +63,45 @@ def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_
      except Exception as e:
          gr.Warning("Error attempting to calculate new dimensions")
          return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-
+
  def get_duration(input_image, prompt, height, width,
                   negative_prompt, duration_seconds,
                   guidance_scale, steps,
                   seed, randomize_seed,
                   progress):
-     if steps > 4 and duration_seconds > 2:
+     if steps > 4 and duration_seconds > 4:
          return 90
-     elif steps > 4 or duration_seconds > 2:
+     elif steps > 4 or duration_seconds > 4:
          return 75
      else:
          return 60

  @spaces.GPU(duration=get_duration)
- def generate_video(input_image, prompt, height, width,
-                    negative_prompt=default_negative_prompt, duration_seconds = 2,
-                    guidance_scale = 1, steps = 4,
-                    seed = 42, randomize_seed = False,
-                    progress=gr.Progress(track_tqdm=True)):
-     """
-     Generate a video from an input image using the Wan 2.1 I2V model with CausVid LoRA.
-
-     This function takes an input image and generates a video animation based on the provided
-     prompt and parameters. It uses the Wan 2.1 14B Image-to-Video model with CausVid LoRA
-     for fast generation in 4-8 steps.
-
-     Args:
-         input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
-         prompt (str): Text prompt describing the desired animation or motion.
-         height (int): Target height for the output video. Will be adjusted to multiple of MOD_VALUE (32).
-         width (int): Target width for the output video. Will be adjusted to multiple of MOD_VALUE (32).
-         negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
-             Defaults to default_negative_prompt (contains unwanted visual artifacts).
-         duration_seconds (float, optional): Duration of the generated video in seconds.
-             Defaults to 2. Clamped between MIN_FRAMES_MODEL/FIXED_FPS and MAX_FRAMES_MODEL/FIXED_FPS.
-         guidance_scale (float, optional): Controls adherence to the prompt. Higher values = more adherence.
-             Defaults to 1.0. Range: 0.0-20.0.
-         steps (int, optional): Number of inference steps. More steps = higher quality but slower.
-             Defaults to 4. Range: 1-30.
-         seed (int, optional): Random seed for reproducible results. Defaults to 42.
-             Range: 0 to MAX_SEED (2147483647).
-         randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
-             Defaults to False.
-         progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
-
-     Returns:
-         tuple: A tuple containing:
-             - video_path (str): Path to the generated video file (.mp4)
-             - current_seed (int): The seed used for generation (useful when randomize_seed=True)
-
-     Raises:
-         gr.Error: If input_image is None (no image uploaded).
-
-     Note:
-         - The function automatically resizes the input image to the target dimensions
-         - Frame count is calculated as duration_seconds * FIXED_FPS (24)
-         - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
-         - The function uses GPU acceleration via the @spaces.GPU decorator
-         - Generation time varies based on steps and duration (see get_duration function)
-     """
-     if input_image is None:
-         raise gr.Error("Please upload an input image.")
-
+ def generate_video(input_image, prompt, height, width, negative_prompt=default_negative_prompt, duration_seconds=2, guidance_scale=1, steps=4, seed=42, randomize_seed=False, progress=gr.Progress(track_tqdm=True)):
      target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
      target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
-
+
      num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
-
-     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

-     resized_image = input_image.resize((target_w, target_h))
+     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

-     with torch.inference_mode():
-         output_frames_list = pipe(
-             image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
-             height=target_h, width=target_w, num_frames=num_frames,
-             guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
-             generator=torch.Generator(device="cuda").manual_seed(current_seed)
-         ).frames[0]
+     if input_image is not None:
+         resized_image = input_image.resize((target_w, target_h))
+         with torch.inference_mode():
+             output_frames_list = image_to_video_pipe(
+                 image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
+                 height=target_h, width=target_w, num_frames=num_frames,
+                 guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
+                 generator=torch.Generator(device="cuda").manual_seed(current_seed)
+             ).frames[0]
+     else:
+         with torch.inference_mode():
+             output_frames_list = text_to_video_pipe(
+                 prompt=prompt, negative_prompt=negative_prompt,
+                 height=target_h, width=target_w, num_frames=num_frames,
+                 guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
+                 generator=torch.Generator(device="cuda").manual_seed(current_seed)
+             ).frames[0]

      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
          video_path = tmpfile.name
@@ -164,14 +109,15 @@ def generate_video(input_image, prompt, height, width,
      return video_path, current_seed

  with gr.Blocks() as demo:
-     gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA")
-     gr.Markdown("[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan 2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors) and is compatible with 🧨 diffusers")
+     gr.Markdown("# Fast Wan 2.1 TI2V 5B Demo")
+     gr.Markdown("""This Demo is using [FastWan2.2-TI2V-5B](https://huggingface.co/FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers) which is fine-tuned with Sparse-distill method which allows wan to generate high quality videos in 3-5 steps.""")
+
      with gr.Row():
          with gr.Column():
-             input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
+             input_image_component = gr.Image(type="pil", label="Input Image (optional, auto-resized to target H/W)")
              prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
              duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
-
+
              with gr.Accordion("Advanced Settings", open=False):
                  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
@@ -179,9 +125,8 @@ with gr.Blocks() as demo:
                  with gr.Row():
                      height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
                      width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
-                     steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
-                     guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)
-
+                     steps_slider = gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Inference Steps")
+                     guidance_scale_input = gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=1.0, label="Guidance Scale")
              generate_button = gr.Button("Generate Video", variant="primary")
          with gr.Column():
              video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
@@ -191,13 +136,13 @@ with gr.Blocks() as demo:
          inputs=[input_image_component, height_input, width_input],
          outputs=[height_input, width_input]
      )
-
-     input_image_component.clear(
+
+     input_image_component.clear(
          fn=handle_image_upload_for_dims_wan,
          inputs=[input_image_component, height_input, width_input],
          outputs=[height_input, width_input]
      )
-
+
      ui_inputs = [
          input_image_component, prompt_input, height_input, width_input,
          negative_prompt_input, duration_seconds_input,
@@ -208,10 +153,10 @@ with gr.Blocks() as demo:
      gr.Examples(
          examples=[
              ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
-             ["forg.jpg", "the frog jumps around", 448, 832],
+             [None, "a penguin playfully dancing in the snow, Antarctica", 1024, 720],
          ],
          inputs=[input_image_component, prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
      )

  if __name__ == "__main__":
-     demo.queue().launch(mcp_server=True)
+     demo.queue().launch()
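
As a quick sanity check on the new defaults, the resizing and frame-count math from app.py can be reproduced on its own, with no GPU required. A minimal sketch, with the constants copied from the new version of the file (the example input size 1920x1080 is only illustrative):

import numpy as np

MOD_VALUE = 32
NEW_FORMULA_MAX_AREA = 720.0 * 1024
SLIDER_MIN_H, SLIDER_MAX_H = 256, 1280
SLIDER_MIN_W, SLIDER_MAX_W = 256, 1280
FIXED_FPS = 24
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 25, 193

def suggested_dims(orig_w, orig_h):
    # Same steps as _calculate_new_dimensions_wan: scale to roughly
    # NEW_FORMULA_MAX_AREA pixels at the original aspect ratio, snap both
    # sides down to multiples of MOD_VALUE, then clamp to the slider range.
    aspect_ratio = orig_h / orig_w
    calc_h = round(np.sqrt(NEW_FORMULA_MAX_AREA * aspect_ratio))
    calc_w = round(np.sqrt(NEW_FORMULA_MAX_AREA / aspect_ratio))
    calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)
    calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)
    new_h = int(np.clip(calc_h, SLIDER_MIN_H, (SLIDER_MAX_H // MOD_VALUE) * MOD_VALUE))
    new_w = int(np.clip(calc_w, SLIDER_MIN_W, (SLIDER_MAX_W // MOD_VALUE) * MOD_VALUE))
    return new_h, new_w

def frame_count(duration_seconds):
    # Duration in seconds -> frame count, clamped to the model's supported range.
    return int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))

print(suggested_dims(1920, 1080))  # (640, 1120) for a 16:9 source
print(frame_count(2.0))            # 48 frames at 24 fps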
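
The generation path can also be exercised outside Gradio. A minimal sketch of the text-to-video branch, assuming a CUDA GPU with enough memory; the model ID, pipeline class, scheduler settings, steps, and guidance value simply mirror the demo above and are not independently verified, and the resolution and negative prompt here are illustrative placeholders:

import torch
from diffusers import AutoencoderKLWan, WanPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video

MODEL_ID = "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers"

# Same loading sequence as app.py: fp32 VAE, bf16 pipeline, UniPC with flow_shift=8.0.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")

# 2 seconds at 24 fps = 48 frames; 4 steps and guidance_scale=1.0 match the demo defaults.
# The negative prompt is abbreviated; the demo uses the longer default_negative_prompt.
frames = pipe(
    prompt="a penguin playfully dancing in the snow, Antarctica",
    negative_prompt="worst quality, low quality, blurred details, watermark, text, static",
    height=704, width=1024, num_frames=48,
    guidance_scale=1.0, num_inference_steps=4,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]

export_to_video(frames, "penguin.mp4", fps=24)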