multimodalart HF Staff committed on
Commit
c8a8fcf
·
verified ·
1 Parent(s): 7deba66

Create app.py

Files changed (1)
  1. app.py +254 -0
app.py ADDED
@@ -0,0 +1,254 @@
+ import os
+ import random
+ import sys
+ from typing import Sequence, Mapping, Any, Union
+ import spaces
+ import torch
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from comfy import model_management
+ from PIL import Image
+
+ # --- Helper Functions from original script ---
+
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
+     try:
+         return obj[index]
+     except KeyError:
+         return obj["result"][index]
+
+ def find_path(name: str, path: str = None) -> str:
+     if path is None:
+         path = os.getcwd()
+     if name in os.listdir(path):
+         path_name = os.path.join(path, name)
+         print(f"{name} found: {path_name}")
+         return path_name
+     parent_directory = os.path.dirname(path)
+     if parent_directory == path:
+         return None
+     return find_path(name, parent_directory)
+
+ def add_comfyui_directory_to_sys_path() -> None:
+     comfyui_path = find_path("ComfyUI")
+     if comfyui_path is not None and os.path.isdir(comfyui_path):
+         sys.path.append(comfyui_path)
+         print(f"'{comfyui_path}' added to sys.path")
+
+ def add_extra_model_paths() -> None:
+     try:
+         from main import load_extra_path_config
+     except ImportError:
+         from utils.extra_config import load_extra_path_config
+     extra_model_paths = find_path("extra_model_paths.yaml")
+     if extra_model_paths is not None:
+         load_extra_path_config(extra_model_paths)
+     else:
+         print("Could not find the extra_model_paths config file.")
+
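+ # ComfyUI custom nodes expect a PromptServer to exist; the helper below creates one
+ # on a private asyncio event loop (without serving HTTP) so init_extra_nodes() can
+ # register custom node packs headlessly.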
+ def import_custom_nodes() -> None:
+     import asyncio
+     import execution
+     from nodes import init_extra_nodes
+     import server
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+     server_instance = server.PromptServer(loop)
+     execution.PromptQueue(server_instance)
+     init_extra_nodes()
+
+ # --- Setup and Model Downloads ---
+
+ add_comfyui_directory_to_sys_path()
+ add_extra_model_paths()
+ import_custom_nodes()
+ from nodes import NODE_CLASS_MAPPINGS
+
+ print("Downloading models from Hugging Face Hub...")
+ # Text Encoder
+ hf_hub_download(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
+ # UNETs
+ hf_hub_download(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
+ hf_hub_download(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
+ # VAE
+ hf_hub_download(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/vae/wan_2.1_vae.safetensors", local_dir="models/vae")
+ # CLIP Vision
+ hf_hub_download(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/clip_vision/clip_vision_h.safetensors", local_dir="models/clip_vision")
+ # LoRAs
+ hf_hub_download(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", local_dir="models/loras")
+ hf_hub_download(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
+ print("Downloads complete.")
+
+ # --- ZeroGPU: Pre-load models and instantiate nodes globally ---
+
+ # Instantiate Nodes
+ cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
+ cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
+ unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
+ vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
+ clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
+ loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
+ clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
+ loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
+ modelsamplingsd3 = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
+ pathchsageattentionkj = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
+ wanfirstlastframetovideo = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
+ ksampleradvanced = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
+ vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
+ createvideo = NODE_CLASS_MAPPINGS["CreateVideo"]()
+ savevideo = NODE_CLASS_MAPPINGS["SaveVideo"]()
+ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()  # For dynamic resizing
+
+ # Load Models
+ cliploader_38 = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu")
+ unetloader_37_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
+ unetloader_91_high_noise = unetloader.load_unet(unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
+ vaeloader_39 = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
+ clipvisionloader_49 = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")
+
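+ # Each expert (high-noise and low-noise) gets its Lightning LoRA at strength 0.8,
+ # a ModelSamplingSD3 sigma shift of 8, and the SageAttention patch from KJNodes.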
+ # Apply LoRAs and Patches
+ loraloadermodelonly_94_high = loraloadermodelonly.load_lora_model_only(lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", strength_model=0.8, model=get_value_at_index(unetloader_91_high_noise, 0))
+ loraloadermodelonly_95_low = loraloadermodelonly.load_lora_model_only(lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", strength_model=0.8, model=get_value_at_index(unetloader_37_low_noise, 0))
+ modelsamplingsd3_93_low = modelsamplingsd3.patch(shift=8, model=get_value_at_index(loraloadermodelonly_95_low, 0))
+ pathchsageattentionkj_98_low = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(modelsamplingsd3_93_low, 0))
+ modelsamplingsd3_79_high = modelsamplingsd3.patch(shift=8, model=get_value_at_index(loraloadermodelonly_94_high, 0))
+ pathchsageattentionkj_96_high = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(modelsamplingsd3_79_high, 0))
+
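+ # Everything above runs once at import time; handing the patched models to ComfyUI's
+ # model management below is meant to keep them resident so each @spaces.GPU call only
+ # has to execute the per-request sampling graph.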
+ # Pre-load models to GPU
+ model_loaders = [cliploader_38, unetloader_37_low_noise, unetloader_91_high_noise, vaeloader_39, clipvisionloader_49, loraloadermodelonly_94_high, loraloadermodelonly_95_low]
+ valid_models = [getattr(loader[0], 'patcher', loader[0]) for loader in model_loaders if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)]
+ model_management.load_models_gpu(valid_models)
+
+ # --- Custom Logic for this App ---
+
+ def calculate_dimensions(image_path):
+     with Image.open(image_path) as img:
+         width, height = img.size
+
+     if width == height:
+         return 480, 480
+
+     if width > height:
+         new_width = 832
+         new_height = int(height * (832 / width))
+     else:
+         new_height = 832
+         new_width = int(width * (832 / height))
+
+     # Ensure dimensions are multiples of 16
+     new_width = (new_width // 16) * 16
+     new_height = (new_height // 16) * 16
+
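+     # Example: a 1920x1080 input maps to 832x468, which snaps down to 832x464.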
+     return new_width, new_height
+
+ # --- Main Generation Function ---
+
+ @spaces.GPU(duration=120)
+ def generate_video(prompt, first_image_path, last_image_path):
+     # This function now only handles per-request logic
+     with torch.inference_mode():
+         # Calculate target dimensions based on the first image
+         target_width, target_height = calculate_dimensions(first_image_path)
+
+         # 1. Load and resize images
+         # Since LoadImage returns a tensor, we pass it to the resize node
+         loaded_first_image = loadimage.load_image(image=first_image_path)
+         resized_first_image = imageresize.execute(
+             width=target_width, height=target_height, interpolation="bicubic",
+             method="stretch", condition="always", multiple_of=1,
+             image=get_value_at_index(loaded_first_image, 0)
+         )
+
+         loaded_last_image = loadimage.load_image(image=last_image_path)
+         resized_last_image = imageresize.execute(
+             width=target_width, height=target_height, interpolation="bicubic",
+             method="stretch", condition="always", multiple_of=1,
+             image=get_value_at_index(loaded_last_image, 0)
+         )
+
+         # 2. Encode text and images
+         cliptextencode_6 = cliptextencode.encode(text=prompt, clip=get_value_at_index(cliploader_38, 0))
+         cliptextencode_7_negative = cliptextencode.encode(
+             text="low quality, worst quality, jpeg artifacts, ugly, deformed, blurry",
+             clip=get_value_at_index(cliploader_38, 0),
+         )
+         clipvisionencode_51 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(resized_first_image, 0))
+         clipvisionencode_87 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(resized_last_image, 0))
+
+         # 3. Prepare latents for video generation
+         wanfirstlastframetovideo_83 = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
+             width=target_width, height=target_height, length=33, batch_size=1,
+             positive=get_value_at_index(cliptextencode_6, 0),
+             negative=get_value_at_index(cliptextencode_7_negative, 0),
+             vae=get_value_at_index(vaeloader_39, 0),
+             clip_vision_start_image=get_value_at_index(clipvisionencode_51, 0),
+             clip_vision_end_image=get_value_at_index(clipvisionencode_87, 0),
+             start_image=get_value_at_index(resized_first_image, 0),
+             end_image=get_value_at_index(resized_last_image, 0),
+         )
+
+         # 4. KSampler pipeline
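+         # Two-stage sampling: the high-noise expert runs steps 0-4 (noise added,
+         # leftover noise returned), then the low-noise expert finishes from step 4.
+         # cfg=1 and the short 8-step schedule reflect the Lightning LoRAs loaded above.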
+         ksampleradvanced_101 = ksampleradvanced.sample(
+             add_noise="enable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1,
+             sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
+             return_with_leftover_noise="enable", model=get_value_at_index(pathchsageattentionkj_96_high, 0),
+             positive=get_value_at_index(wanfirstlastframetovideo_83, 0),
+             negative=get_value_at_index(wanfirstlastframetovideo_83, 1),
+             latent_image=get_value_at_index(wanfirstlastframetovideo_83, 2),
+         )
+         ksampleradvanced_102 = ksampleradvanced.sample(
+             add_noise="disable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1,
+             sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
+             return_with_leftover_noise="disable", model=get_value_at_index(pathchsageattentionkj_98_low, 0),
+             positive=get_value_at_index(wanfirstlastframetovideo_83, 0),
+             negative=get_value_at_index(wanfirstlastframetovideo_83, 1),
+             latent_image=get_value_at_index(ksampleradvanced_101, 0),
+         )
+
+         # 5. Decode and save video
+         vaedecode_8 = vaedecode.decode(samples=get_value_at_index(ksampleradvanced_102, 0), vae=get_value_at_index(vaeloader_39, 0))
+         createvideo_104 = createvideo.create_video(fps=16, images=get_value_at_index(vaedecode_8, 0))
+         savevideo_103 = savevideo.save_video(filename_prefix="ComfyUI_Video", format="mp4", codec="libx264", video=get_value_at_index(createvideo_104, 0))
+
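+         # SaveVideo writes the file under ComfyUI's output directory and reports the
+         # generated filename in its UI dict, which is used to build the returned path.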
+         # Return the path to the saved video
+         video_filename = savevideo_103['ui']['videos'][0]['filename']
+         return f"output/{video_filename}"
+
+ # --- Gradio Interface ---
+
+ with gr.Blocks() as app:
+     gr.Markdown("# Wan 2.2 First/Last Frame to Video")
+     gr.Markdown("Provide a starting image, an ending image, and a text prompt to generate a video transitioning between them.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt_input = gr.Textbox(label="Prompt", value="the guy turns")
+             first_image = gr.Image(label="First Frame", type="filepath")
+             last_image = gr.Image(label="Last Frame", type="filepath")
+             generate_btn = gr.Button("Generate Video")
+         with gr.Column(scale=2):
+             output_video = gr.Video(label="Generated Video")
+
+     generate_btn.click(
+         fn=generate_video,
+         inputs=[prompt_input, first_image, last_image],
+         outputs=[output_video]
+     )
+
+     gr.Examples(
+         examples=[
+             ["a beautiful woman, cinematic", "examples/start.png", "examples/end.png"]
+         ],
+         inputs=[prompt_input, first_image, last_image]
+     )
+
+ if __name__ == "__main__":
+     # Create example images if they don't exist
+     if not os.path.exists("examples"):
+         os.makedirs("examples")
+     if not os.path.exists("examples/start.png"):
+         Image.new('RGB', (512, 512), color='red').save('examples/start.png')
+     if not os.path.exists("examples/end.png"):
+         Image.new('RGB', (512, 512), color='blue').save('examples/end.png')
+
+     app.launch()