Commit 941a8cc committed by multimodalart (HF Staff) · verified · Parent: d24ffa4

Update app.py

Files changed (1): app.py (+46 −27)
app.py CHANGED
@@ -1,44 +1,40 @@
 import os
-
-if os.getcwd() != '/home/user/app':
-    os.chdir('/home/user/app')
-
 import sys
-import spaces
 import subprocess
 import asyncio
+import uuid
 from typing import Sequence, Mapping, Any, Union
 
+# --- 2. Let ComfyUI's main.py handle all initial setup ---
 print("Importing ComfyUI's main.py for setup...")
 import main
 print("ComfyUI main imported.")
 
-
+# --- 3. Now we can import the rest of the necessary modules ---
 import torch
 import gradio as gr
 from huggingface_hub import hf_hub_download
 from comfy import model_management
+import spaces
 from PIL import Image
 import random
-import nodes  # Import nodes after main has set everything up
-
+import nodes
 
-# --- Manually trigger the node initialization ---
-# This step is normally done inside main.start_comfyui(), but we do it here.
-# It loads all built-in, extra, and custom nodes into the NODE_CLASS_MAPPINGS.
+# --- 4. Manually trigger the node initialization ---
 print("Initializing ComfyUI nodes...")
 loop = asyncio.new_event_loop()
 asyncio.set_event_loop(loop)
 loop.run_until_complete(nodes.init_extra_nodes())
 print("Nodes initialized.")
 
-# --- Helper function from the original script ---
+# --- Helper function ---
 def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
     try:
         return obj[index]
     except KeyError:
         return obj["result"][index]
 
+
 # --- Model Downloads ---
 print("Downloading models from Hugging Face Hub...")
 hf_hub_download(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
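A note on the `get_value_at_index` helper kept above: ComfyUI node methods usually return a plain tuple, but some return a mapping whose outputs sit under a "result" key; integer-indexing a dict raises KeyError, which the helper catches to fall through to `obj["result"][index]`. A minimal sketch of both shapes (the sample values are illustrative, not taken from the app):

def get_value_at_index(obj, index):
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]

assert get_value_at_index(("img",), 0) == "img"              # tuple-style node output
assert get_value_at_index({"result": ("img",)}, 0) == "img"  # dict-style node output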
@@ -52,7 +48,6 @@ print("Downloads complete.")
 
 
 # --- ZeroGPU: Pre-load models and instantiate nodes globally ---
-# This part will now work because NODE_CLASS_MAPPINGS is correctly populated.
 cliploader = nodes.NODE_CLASS_MAPPINGS["CLIPLoader"]()
 cliptextencode = nodes.NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
 unetloader = nodes.NODE_CLASS_MAPPINGS["UNETLoader"]()
@@ -68,7 +63,6 @@ ksampleradvanced = nodes.NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
 vaedecode = nodes.NODE_CLASS_MAPPINGS["VAEDecode"]()
 createvideo = nodes.NODE_CLASS_MAPPINGS["CreateVideo"]()
 savevideo = nodes.NODE_CLASS_MAPPINGS["SaveVideo"]()
-imageresize = nodes.NODE_CLASS_MAPPINGS["ImageResize+"]()
 
 cliploader_38 = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu")
 unetloader_37_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
@@ -88,8 +82,7 @@ valid_models = [getattr(loader[0], 'patcher', loader[0]) for loader in model_loa
 model_management.load_models_gpu(valid_models)
 
 # --- App Logic ---
-def calculate_dimensions(image_path):
-    with Image.open(image_path) as img: width, height = img.size
+def calculate_dimensions(width, height):
     if width == height: return 480, 480
     if width > height: new_width, new_height = 832, int(height * (832 / width))
     else: new_height, new_width = 832, int(width * (832 / height))
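For reference, the reworked `calculate_dimensions(width, height)` pins the longer side to 832 and scales the shorter side proportionally, with square inputs falling back to 480×480; the function's return statement sits outside this hunk, so only the shown arithmetic is restated below. A standalone recomputation (the `bucket` name is mine, not the app's):

def bucket(width, height):
    if width == height: return 480, 480
    if width > height: return 832, int(height * (832 / width))
    return int(width * (832 / height)), 832

assert bucket(512, 512) == (480, 480)     # square fallback
assert bucket(1920, 1080) == (832, 468)   # int(1080 * 832 / 1920) = 468
assert bucket(1080, 1920) == (468, 832)   # portrait mirrors landscape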
@@ -97,23 +90,46 @@ def calculate_dimensions(image_path):
 
 @spaces.GPU(duration=120)
 def generate_video(prompt, first_image_path, last_image_path, duration_seconds):
+    # Create a temporary directory for resized images
+    temp_dir = f"temp_resized_{uuid.uuid4().hex}"
+    os.makedirs(temp_dir, exist_ok=True)
+
     with torch.inference_mode():
+        # --- Python Image Preprocessing using Pillow ---
+        print("Preprocessing images with Pillow...")
+        with Image.open(first_image_path) as img:
+            orig_width, orig_height = img.size
+
+        target_width, target_height = calculate_dimensions(orig_width, orig_height)
+
+        # Resize first image
+        with Image.open(first_image_path) as img:
+            img_resized = img.resize((target_width, target_height), Image.Resampling.LANCZOS)
+            resized_first_path = os.path.join(temp_dir, "first_frame_resized.png")
+            img_resized.save(resized_first_path)
+
+        # Resize second image to match the target dimensions
+        with Image.open(last_image_path) as img:
+            img_resized = img.resize((target_width, target_height), Image.Resampling.LANCZOS)
+            resized_last_path = os.path.join(temp_dir, "last_frame_resized.png")
+            img_resized.save(resized_last_path)
+        print(f"Images resized to {target_width}x{target_height} and saved temporarily.")
+        # --- End Preprocessing ---
+
         FPS, MAX_FRAMES = 16, 81
         length_in_frames = max(1, min(int(duration_seconds * FPS), MAX_FRAMES))
         print(f"Requested duration: {duration_seconds}s. Calculated frames: {length_in_frames}")
-        target_width, target_height = calculate_dimensions(first_image_path)
-
-        loaded_first_image = loadimage.load_image(image=first_image_path)
-        resized_first_image = imageresize.execute(width=target_width, height=target_height, interpolation="bicubic", method="stretch", image=get_value_at_index(loaded_first_image, 0))
-        loaded_last_image = loadimage.load_image(image=last_image_path)
-        resized_last_image = imageresize.execute(width=target_width, height=target_height, interpolation="bicubic", method="stretch", image=get_value_at_index(loaded_last_image, 0))
-
+
+        # Load the pre-processed images into ComfyUI
+        loaded_first_image = loadimage.load_image(image=os.path.basename(resized_first_path))
+        loaded_last_image = loadimage.load_image(image=os.path.basename(resized_last_path))
+
         cliptextencode_6 = cliptextencode.encode(text=prompt, clip=get_value_at_index(cliploader_38, 0))
         cliptextencode_7_negative = cliptextencode.encode(text="low quality, worst quality, jpeg artifacts, ugly, deformed, blurry", clip=get_value_at_index(cliploader_38, 0))
-        clipvisionencode_51 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(resized_first_image, 0))
-        clipvisionencode_87 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(resized_last_image, 0))
+        clipvisionencode_51 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(loaded_first_image, 0))
+        clipvisionencode_87 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(loaded_last_image, 0))
 
-        wanfirstlastframetovideo_83 = wanfirstlastframetovideo.EXECUTE_NORMALIZED(width=target_width, height=target_height, length=length_in_frames, batch_size=1, positive=get_value_at_index(cliptextencode_6, 0), negative=get_value_at_index(cliptextencode_7_negative, 0), vae=get_value_at_index(vaeloader_39, 0), clip_vision_start_image=get_value_at_index(clipvisionencode_51, 0), clip_vision_end_image=get_value_at_index(clipvisionencode_87, 0), start_image=get_value_at_index(resized_first_image, 0), end_image=get_value_at_index(resized_last_image, 0))
+        wanfirstlastframetovideo_83 = wanfirstlastframetovideo.EXECUTE_NORMALIZED(width=target_width, height=target_height, length=length_in_frames, batch_size=1, positive=get_value_at_index(cliptextencode_6, 0), negative=get_value_at_index(cliptextencode_7_negative, 0), vae=get_value_at_index(vaeloader_39, 0), clip_vision_start_image=get_value_at_index(clipvisionencode_51, 0), clip_vision_end_image=get_value_at_index(clipvisionencode_87, 0), start_image=get_value_at_index(loaded_first_image, 0), end_image=get_value_at_index(loaded_last_image, 0))
 
         ksampler_positive = get_value_at_index(wanfirstlastframetovideo_83, 0)
         ksampler_negative = get_value_at_index(wanfirstlastframetovideo_83, 1)
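The duration-to-frames clamp in the hunk above converts the requested seconds to frames at 16 FPS, caps the result at 81 frames (roughly five seconds), and floors it at one frame. A quick check of the arithmetic:

FPS, MAX_FRAMES = 16, 81
for secs in (0.01, 2.5, 5, 10):
    print(secs, max(1, min(int(secs * FPS), MAX_FRAMES)))
# 0.01 -> 1   (int(0.16) = 0, lifted to the 1-frame floor)
# 2.5  -> 40
# 5    -> 80
# 10   -> 81  (capped at MAX_FRAMES)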
@@ -128,7 +144,7 @@ def generate_video(prompt, first_image_path, last_image_path, duration_seconds):
 
         return f"output/{savevideo_103['ui']['videos'][0]['filename']}"
 
-# --- Gradio Interface (no changes needed) ---
+# --- Gradio Interface ---
 with gr.Blocks() as app:
     gr.Markdown("# Wan 2.2 First/Last Frame to Video")
     gr.Markdown("Provide a starting image, an ending image, a text prompt, and a desired duration to generate a video transitioning between them.")
@@ -149,4 +165,7 @@ if __name__ == "__main__":
     if not os.path.exists("examples"): os.makedirs("examples")
     if not os.path.exists("examples/start.png"): Image.new('RGB', (512, 512), color='red').save('examples/start.png')
     if not os.path.exists("examples/end.png"): Image.new('RGB', (512, 512), color='blue').save('examples/end.png')
+    # Set the input directory for LoadImage to find the temp files
+    import folder_paths
+    folder_paths.add_model_folder_path("input", "temp_resized")
     app.launch()
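The commit's per-call `temp_resized_{uuid.uuid4().hex}` directory keeps concurrent requests from overwriting each other's resized frames. A minimal sketch of the same isolation pattern; the cleanup step is an addition of mine, the commit itself leaves the directories in place:

import os, shutil, uuid

temp_dir = f"temp_resized_{uuid.uuid4().hex}"  # unique per call
os.makedirs(temp_dir, exist_ok=True)
try:
    resized_first_path = os.path.join(temp_dir, "first_frame_resized.png")
    # ... save the resized frames here, then hand their basenames to LoadImage ...
finally:
    shutil.rmtree(temp_dir, ignore_errors=True)  # not in the commit; avoids orphaned dirs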
 