multimodalart HF Staff committed on
Commit
e477e53
·
verified ·
1 Parent(s): c6f4804

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -25
app.py CHANGED
@@ -19,8 +19,11 @@ from optimization import optimize_pipeline_
19
 
20
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
21
 
22
- LANDSCAPE_WIDTH = 832
23
- LANDSCAPE_HEIGHT = 480
 
 
 
24
  MAX_SEED = np.iinfo(np.int32).max
25
 
26
  FIXED_FPS = 16
@@ -50,11 +53,14 @@ for i in range(3):
50
  torch.cuda.synchronize()
51
  torch.cuda.empty_cache()
52
 
 
 
 
53
  optimize_pipeline_(pipe,
54
- image=Image.new('RGB', (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT)),
55
  prompt='prompt',
56
- height=LANDSCAPE_HEIGHT,
57
- width=LANDSCAPE_WIDTH,
58
  num_frames=MAX_FRAMES_MODEL,
59
  )
60
 
@@ -62,28 +68,51 @@ optimize_pipeline_(pipe,
62
  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
63
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
64
 
65
-
66
  def resize_image(image: Image.Image) -> Image.Image:
67
- if image.height > image.width:
68
- transposed = image.transpose(Image.Transpose.ROTATE_90)
69
- resized = resize_image_landscape(transposed)
70
- return resized.transpose(Image.Transpose.ROTATE_270)
71
- return resize_image_landscape(image)
72
 
 
 
 
73
 
74
- def resize_image_landscape(image: Image.Image) -> Image.Image:
75
- target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
76
- width, height = image.size
77
- in_aspect = width / height
78
- if in_aspect > target_aspect:
79
- new_width = round(height * target_aspect)
80
- left = (width - new_width) // 2
81
- image = image.crop((left, 0, left + new_width, height))
 
 
 
 
 
 
 
 
 
 
 
82
  else:
83
- new_height = round(width / target_aspect)
84
- top = (height - new_height) // 2
85
- image = image.crop((0, top, width, top + new_height))
86
- return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  def get_duration(
89
  input_image,
@@ -147,7 +176,6 @@ def generate_video(
147
  gr.Error: If input_image is None (no image uploaded).
148
 
149
  Note:
150
- - The function automatically resizes the input image to the target dimensions
151
  - Frame count is calculated as duration_seconds * FIXED_FPS (24)
152
  - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
153
  - The function uses GPU acceleration via the @spaces.GPU decorator
@@ -185,7 +213,7 @@ with gr.Blocks() as demo:
185
  gr.Markdown("run Wan 2.2 in just 4-8 steps, with [Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
186
  with gr.Row():
187
  with gr.Column():
188
- input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
189
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
190
  duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
191
 
 
19
 
20
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
21
 
22
+ MAX_DIM = 832
23
+ MIN_DIM = 480
24
+ SQUARE_DIM = 640
25
+ MULTIPLE_OF = 16
26
+
27
  MAX_SEED = np.iinfo(np.int32).max
28
 
29
  FIXED_FPS = 16
 
53
  torch.cuda.synchronize()
54
  torch.cuda.empty_cache()
55
 
56
+ OPTIMIZE_WIDTH = 832
57
+ OPTIMIZE_HEIGHT = 624
58
+
59
  optimize_pipeline_(pipe,
60
+ image=Image.new('RGB', (OPTIMIZE_WIDTH, OPTIMIZE_HEIGHT)),
61
  prompt='prompt',
62
+ height=OPTIMIZE_HEIGHT,
63
+ width=OPTIMIZE_WIDTH,
64
  num_frames=MAX_FRAMES_MODEL,
65
  )
66
 
 
68
  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
69
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
70
 
 
71
  def resize_image(image: Image.Image) -> Image.Image:
72
+ """
73
+ Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
74
+ """
75
+ width, height = image.size
 
76
 
77
+ # Handle square case
78
+ if width == height:
79
+ return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
80
 
81
+ aspect_ratio = width / height
82
+
83
+ MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
84
+ MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
85
+
86
+ image_to_resize = image
87
+
88
+ if aspect_ratio > MAX_ASPECT_RATIO:
89
+ # Very wide image -> crop width to fit 832x480 aspect ratio
90
+ target_w, target_h = MAX_DIM, MIN_DIM
91
+ crop_width = int(round(height * MAX_ASPECT_RATIO))
92
+ left = (width - crop_width) // 2
93
+ image_to_resize = image.crop((left, 0, left + crop_width, height))
94
+ elif aspect_ratio < MIN_ASPECT_RATIO:
95
+ # Very tall image -> crop height to fit 480x832 aspect ratio
96
+ target_w, target_h = MIN_DIM, MAX_DIM
97
+ crop_height = int(round(width / MIN_ASPECT_RATIO))
98
+ top = (height - crop_height) // 2
99
+ image_to_resize = image.crop((0, top, width, top + crop_height))
100
  else:
101
+ if width > height: # Landscape
102
+ target_w = MAX_DIM
103
+ target_h = int(round(target_w / aspect_ratio))
104
+ else: # Portrait
105
+ target_h = MAX_DIM
106
+ target_w = int(round(target_h * aspect_ratio))
107
+
108
+ final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
109
+ final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
110
+
111
+ final_w = max(MIN_DIM, min(MAX_DIM, final_w))
112
+ final_h = max(MIN_DIM, min(MAX_DIM, final_h))
113
+
114
+ return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
115
+
116
 
117
  def get_duration(
118
  input_image,
 
176
  gr.Error: If input_image is None (no image uploaded).
177
 
178
  Note:
 
179
  - Frame count is calculated as duration_seconds * FIXED_FPS (16)
180
  - Output dimensions are adjusted to be multiples of MULTIPLE_OF (16)
181
  - The function uses GPU acceleration via the @spaces.GPU decorator
 
213
  gr.Markdown("run Wan 2.2 in just 4-8 steps, with [Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
214
  with gr.Row():
215
  with gr.Column():
216
+ input_image_component = gr.Image(type="pil", label="Input Image")
217
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
218
  duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
219