multimodalart's picture
Update app.py
e584055 verified
raw
history blame
12.4 kB
import os
import shutil
import sys
import subprocess
import asyncio
import uuid
from typing import Sequence, Mapping, Any, Union
from huggingface_hub import hf_hub_download
import spaces
def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    """Download a Hub file and expose it inside ``local_dir`` via a symlink.

    The file ``filename`` of repo ``repo_id`` is fetched with ``hf_hub_download``
    (any extra keyword arguments are forwarded to it). A symlink named after the
    file's base name is then created in ``local_dir`` (created if needed),
    pointing at the downloaded file. Anything already at that path is removed
    first. Note that ``os.remove`` cannot delete a directory.

    Returns:
        The path of the symlink inside ``local_dir``.
    """
    cached_file = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    link_path = os.path.join(local_dir, os.path.basename(filename))
    # islink() catches dangling symlinks, which exists() reports as missing.
    occupied = os.path.exists(link_path) or os.path.islink(link_path)
    if occupied:
        os.remove(link_path)
    os.symlink(cached_file, link_path)
    return link_path
# --- Model Downloads ---
# Every file is fetched via hf_hub_download_local, i.e. downloaded and symlinked
# into the ComfyUI model folder given as the third field.
print("Downloading models from Hugging Face Hub...")
text_encoder_repo = hf_hub_download_local(
    repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged",
    filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
    local_dir="models/text_encoders",
)
print(text_encoder_repo)
# (repo_id, filename, local_dir) for the remaining files, in download order.
_REMAINING_MODEL_FILES = (
    # Wan 2.2 i2v UNets: low-noise and high-noise variants.
    ("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", "models/unet"),
    ("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", "models/unet"),
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/vae/wan_2.1_vae.safetensors", "models/vae"),
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/clip_vision/clip_vision_h.safetensors", "models/clip_vision"),
    # 4-step Lightning LoRAs, one per UNet variant.
    ("Kijai/WanVideo_comfy", "Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", "models/loras"),
    ("Kijai/WanVideo_comfy", "Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", "models/loras"),
)
for _repo_id, _filename, _local_dir in _REMAINING_MODEL_FILES:
    hf_hub_download_local(repo_id=_repo_id, filename=_filename, local_dir=_local_dir)
del _repo_id, _filename, _local_dir
print("Downloads complete.")
# --- 2. Let ComfyUI's main.py handle all initial setup ---
# `main` is imported purely for the side effects of its top-level code; the name
# `main` is not used again below.
# NOTE(review): presumably this must run before the `comfy` / `nodes` imports that
# follow (e.g. for sys.path or CLI-argument setup). Confirm before reordering.
print("Importing ComfyUI's main.py for setup...")
import main
print("ComfyUI main imported.")
# --- 3. Now we can import the rest of the necessary modules ---
import torch
import gradio as gr
from comfy import model_management
from PIL import Image
import random
import nodes
# --- 4. Manually trigger the node initialization ---
# nodes.init_extra_nodes() returns an awaitable, so it is driven to completion on a
# freshly created event loop. That loop is also installed as this thread's current
# loop and is left open, bound to the module-level name `loop`.
# NOTE(review): presumably kept open in case later ComfyUI code expects a current
# event loop. Confirm before replacing this with asyncio.run(), which closes its loop.
print("Initializing ComfyUI nodes...")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(nodes.init_extra_nodes())
print("Nodes initialized.")
# --- Helper function ---
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, or ``obj["result"][index]`` if the first lookup raises KeyError.

    Node outputs in this file are indexed through this helper. The fallback
    presumably covers outputs that wrap their values in a mapping under
    ``"result"``. Any other exception (e.g. IndexError) propagates unchanged.
    """
    try:
        direct_hit = obj[index]
    except KeyError:
        wrapped_outputs = obj["result"]
        return wrapped_outputs[index]
    return direct_hit
# --- ZeroGPU: Pre-load models and instantiate nodes globally ---
# One instance of each ComfyUI node class, looked up by its registered name in
# nodes.NODE_CLASS_MAPPINGS. Presumably some of these (e.g. the KJ sage-attention
# node) only appear after init_extra_nodes() above has run. Confirm if reordering.
# The instances are reused by the model-loading code below and by generate_video.
cliploader = nodes.NODE_CLASS_MAPPINGS["CLIPLoader"]()
cliptextencode = nodes.NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
unetloader = nodes.NODE_CLASS_MAPPINGS["UNETLoader"]()
vaeloader = nodes.NODE_CLASS_MAPPINGS["VAELoader"]()
clipvisionloader = nodes.NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
loadimage = nodes.NODE_CLASS_MAPPINGS["LoadImage"]()
clipvisionencode = nodes.NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
loraloadermodelonly = nodes.NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
modelsamplingsd3 = nodes.NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
# "PathchSageAttentionKJ" (sic) is a mapping key, so it must match the registered
# node name exactly. Do not "fix" the spelling here.
pathchsageattentionkj = nodes.NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
wanfirstlastframetovideo = nodes.NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
ksampleradvanced = nodes.NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
vaedecode = nodes.NODE_CLASS_MAPPINGS["VAEDecode"]()
createvideo = nodes.NODE_CLASS_MAPPINGS["CreateVideo"]()
savevideo = nodes.NODE_CLASS_MAPPINGS["SaveVideo"]()
# Load every model once at import time so each generate_video call reuses them.
# Numeric suffixes (_38, _37, _91, ...) presumably mirror node IDs from the
# original ComfyUI workflow.
# Text encoder for prompts, explicitly placed on the CPU.
cliploader_38 = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu")
# The two Wan 2.2 i2v UNets (low-noise and high-noise), default weight dtype.
unetloader_37_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
unetloader_91_high_noise = unetloader.load_unet(unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
vaeloader_39 = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
clipvisionloader_49 = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")
# Each UNet gets its matching 4-step Lightning LoRA at strength 0.8 ...
loraloadermodelonly_94_high = loraloadermodelonly.load_lora_model_only(lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", strength_model=0.8, model=get_value_at_index(unetloader_91_high_noise, 0))
loraloadermodelonly_95_low = loraloadermodelonly.load_lora_model_only(lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", strength_model=0.8, model=get_value_at_index(unetloader_37_low_noise, 0))
# ... then a ModelSamplingSD3 patch with shift=8, then the sage-attention patch
# ("auto"). The *_98_low and *_96_high results are the models used by the two
# KSamplerAdvanced passes in generate_video.
modelsamplingsd3_93_low = modelsamplingsd3.patch(shift=8, model=get_value_at_index(loraloadermodelonly_95_low, 0))
pathchsageattentionkj_98_low = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(modelsamplingsd3_93_low, 0))
modelsamplingsd3_79_high = modelsamplingsd3.patch(shift=8, model=get_value_at_index(loraloadermodelonly_94_high, 0))
pathchsageattentionkj_96_high = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(modelsamplingsd3_79_high, 0))
# Warm the GPU up front. Take the first output of each loader, use its `.patcher`
# attribute when it has one (otherwise the output itself), skip dict outputs, and
# hand the result to ComfyUI's model_management.load_models_gpu.
# NOTE(review): only the base loader / LoRA outputs are listed here, not the
# shift- and sage-attention-patched copies used for sampling. Presumably they share
# the same underlying weights; confirm.
model_loaders = [cliploader_38, unetloader_37_low_noise, unetloader_91_high_noise, vaeloader_39, clipvisionloader_49, loraloadermodelonly_94_high, loraloadermodelonly_95_low]
valid_models = [getattr(loader[0], 'patcher', loader[0]) for loader in model_loaders if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)]
model_management.load_models_gpu(valid_models)
# --- App Logic ---
def calculate_dimensions(width, height, max_side=832, square_side=480, multiple=16):
    """Choose the generation resolution for an input image of ``width`` x ``height``.

    Rules:
    - Square inputs map to ``square_side`` x ``square_side``.
    - Otherwise the longer side becomes ``max_side``. The shorter side is scaled
      by the same factor (truncated to an int).
    - Both sides are then floored to a multiple of ``multiple``.
    - Each side is clamped to at least ``multiple``, so very elongated inputs can
      no longer produce a 0-pixel side.

    The defaults reproduce the original hard-coded behaviour (832 / 480 / 16).

    Args:
        width: Source image width in pixels.
        height: Source image height in pixels.
        max_side: Target length of the longer side for non-square inputs.
        square_side: Output side length for square inputs.
        multiple: Granularity both output sides are floored to.

    Returns:
        ``(new_width, new_height)``.
    """
    if width == height:
        return square_side, square_side
    if width > height:
        new_width, new_height = max_side, int(height * (max_side / width))
    else:
        new_width, new_height = int(width * (max_side / height)), max_side

    def _snap(side):
        # Floor to the required multiple, but never below one multiple: e.g. a
        # 2000x10 input would otherwise give a short side of 4 -> 0 pixels.
        return max(multiple, (side // multiple) * multiple)

    return _snap(new_width), _snap(new_height)
def _wan_frame_count(duration_seconds, fps=16, max_frames=81):
    """Convert a requested clip duration into a frame count of the form ``4k + 1``.

    Steps:
    - ``int(duration_seconds * fps)`` is clamped to at least 1.
    - It is rounded *up* to the next ``4k + 1`` value.
    - It is capped at the largest ``4k + 1`` value not exceeding ``max_frames``
      (the default 81 is ``4*20 + 1``).

    Why: Wan video models work on ``4n + 1`` frames.
    NOTE(review): ComfyUI's WanFirstLastFrameToVideo sizes its latent as
    ``(length - 1) // 4 + 1`` frames and places the end image at index
    ``length - 1``. For other lengths that last frame appears to fall outside what
    the VAE encodes, weakening the last-frame conditioning. Re-check if the node
    implementation changes.
    """
    requested = max(1, int(duration_seconds * fps))
    ceiling = ((max(1, max_frames) - 1) // 4) * 4 + 1
    rounded_up = ((requested + 2) // 4) * 4 + 1
    return min(rounded_up, ceiling)


@spaces.GPU(duration=120)
def generate_video(prompt, first_image_path, last_image_path, duration_seconds, progress=gr.Progress(track_tqdm=True)):
    """Run the Wan 2.2 first/last-frame workflow and return the saved MP4 path.

    Pipeline:
    - Both frames are resized with Lanczos to the resolution that
      ``calculate_dimensions`` picks for the *first* image.
    - The resized frames are written as uniquely named PNGs into ``input/`` and
      loaded through ComfyUI's LoadImage node. They are deleted again before returning.
    - Sampling is two KSamplerAdvanced passes over 8 total steps: steps 0-4 on the
      high-noise model (leftover noise kept), then steps 4-end on the low-noise model.
    - The decoded frames are saved as an H.264 MP4 at 16 fps.

    Args:
        prompt: Positive text prompt.
        first_image_path: Path to the first-frame image (Gradio ``type="filepath"``).
        last_image_path: Path to the last-frame image.
        duration_seconds: Requested length in seconds. It is converted to a
            ``4k + 1`` frame count at 16 fps, at most 81 frames.
        progress: Gradio progress tracker; ``track_tqdm=True`` tracks tqdm bars.

    Returns:
        Path of the saved video, ``output/[<subfolder>/]<filename>``, relative to
        the working directory.

    Raises:
        gr.Error: If either input image is missing.
    """
    if not first_image_path or not last_image_path:
        raise gr.Error("Please upload both a first frame and a last frame image.")

    FPS, MAX_FRAMES = 16, 81
    # Resized copies are written here and handed to LoadImage by bare filename.
    # NOTE(review): assumes ./input is ComfyUI's input directory (cwd == ComfyUI root).
    temp_dir = "input"
    os.makedirs(temp_dir, exist_ok=True)
    temp_paths = []  # every resized copy created below; removed in `finally`

    try:
        with torch.inference_mode():
            # --- Python Image Preprocessing using Pillow ---
            print("Preprocessing images with Pillow...")
            # Open the first image once: it fixes the target size and is resized itself.
            with Image.open(first_image_path) as img:
                target_width, target_height = calculate_dimensions(*img.size)
                first_resized = img.resize((target_width, target_height), Image.Resampling.LANCZOS)
            resized_first_path = os.path.join(temp_dir, f"first_frame_resized_{uuid.uuid4().hex}.png")
            print(resized_first_path)
            temp_paths.append(resized_first_path)
            first_resized.save(resized_first_path)

            # The last image is forced to the same size so both frames match.
            with Image.open(last_image_path) as img:
                last_resized = img.resize((target_width, target_height), Image.Resampling.LANCZOS)
            resized_last_path = os.path.join(temp_dir, f"last_frame_resized_{uuid.uuid4().hex}.png")
            print(resized_last_path)
            temp_paths.append(resized_last_path)
            last_resized.save(resized_last_path)
            print(f"Images resized to {target_width}x{target_height} and saved temporarily.")
            # --- End Preprocessing ---

            length_in_frames = _wan_frame_count(duration_seconds, FPS, MAX_FRAMES)
            print(f"Requested duration: {duration_seconds}s. Calculated frames: {length_in_frames}")

            # Load the pre-processed images into ComfyUI
            loaded_first_image = loadimage.load_image(image=os.path.basename(resized_first_path))
            loaded_last_image = loadimage.load_image(image=os.path.basename(resized_last_path))

            cliptextencode_6 = cliptextencode.encode(text=prompt, clip=get_value_at_index(cliploader_38, 0))
            cliptextencode_7_negative = cliptextencode.encode(text="low quality, worst quality, jpeg artifacts, ugly, deformed, blurry", clip=get_value_at_index(cliploader_38, 0))
            clipvisionencode_51 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(loaded_first_image, 0))
            clipvisionencode_87 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(loaded_last_image, 0))
            wanfirstlastframetovideo_83 = wanfirstlastframetovideo.EXECUTE_NORMALIZED(width=target_width, height=target_height, length=length_in_frames, batch_size=1, positive=get_value_at_index(cliptextencode_6, 0), negative=get_value_at_index(cliptextencode_7_negative, 0), vae=get_value_at_index(vaeloader_39, 0), clip_vision_start_image=get_value_at_index(clipvisionencode_51, 0), clip_vision_end_image=get_value_at_index(clipvisionencode_87, 0), start_image=get_value_at_index(loaded_first_image, 0), end_image=get_value_at_index(loaded_last_image, 0))
            ksampler_positive = get_value_at_index(wanfirstlastframetovideo_83, 0)
            ksampler_negative = get_value_at_index(wanfirstlastframetovideo_83, 1)
            ksampler_latent = get_value_at_index(wanfirstlastframetovideo_83, 2)

            # random.randint is inclusive on both ends. The old upper bound of 2**64
            # could produce a seed one past the largest 64-bit value a seed may take.
            ksampleradvanced_101 = ksampleradvanced.sample(add_noise="enable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1, sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4, return_with_leftover_noise="enable", model=get_value_at_index(pathchsageattentionkj_96_high, 0), positive=ksampler_positive, negative=ksampler_negative, latent_image=ksampler_latent)
            ksampleradvanced_102 = ksampleradvanced.sample(add_noise="disable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1, sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000, return_with_leftover_noise="disable", model=get_value_at_index(pathchsageattentionkj_98_low, 0), positive=ksampler_positive, negative=ksampler_negative, latent_image=get_value_at_index(ksampleradvanced_101, 0))
            vaedecode_8 = vaedecode.decode(samples=get_value_at_index(ksampleradvanced_102, 0), vae=get_value_at_index(vaeloader_39, 0))
            # Encode at the same FPS used to derive the frame count.
            createvideo_104 = createvideo.create_video(fps=FPS, images=get_value_at_index(vaedecode_8, 0))
            savevideo_103 = savevideo.save_video(filename_prefix="ComfyUI_Video", format="mp4", codec="h264", video=get_value_at_index(createvideo_104, 0))
            print("** DEBUG ** ", savevideo_103)
    finally:
        # Best-effort cleanup. The resized PNGs are only read by the load_image calls
        # above, so they are not needed afterwards. A failed delete is reported but
        # must not fail the request.
        for temp_path in temp_paths:
            try:
                os.remove(temp_path)
            except OSError as err:
                print(f"Warning: could not remove temporary file {temp_path}: {err}")

    saved = savevideo_103["ui"]["images"][0]
    # Honour a subfolder if SaveVideo reports one; with this prefix it is normally
    # empty, giving the same "output/<filename>" path as before.
    return os.path.join("output", saved.get("subfolder") or "", saved["filename"])
# --- Gradio Interface ---
# Built at import time. The left column holds the inputs and the button; the right
# column shows the generated video. Layout follows declaration order.
with gr.Blocks() as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) on ZeroGPU")
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", value="a man dancing in the street, cinematic")
            duration_slider = gr.Slider(minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Video Duration (seconds)")
            with gr.Row():
                # type="filepath": generate_video receives each image as a file path.
                first_image = gr.Image(label="First Frame", type="filepath")
                last_image = gr.Image(label="Last Frame", type="filepath")
            generate_btn = gr.Button("Generate Video")
        with gr.Column(scale=2):
            output_video = gr.Video(label="Generated Video")
    # The `inputs` order must match generate_video's positional parameters.
    generate_btn.click(fn=generate_video, inputs=[prompt_input, first_image, last_image, duration_slider], outputs=[output_video])
    # The example images are created by the __main__ block if missing.
    # NOTE(review): this Blocks graph is built before that block runs. Confirm that
    # gr.Examples tolerates the files not existing yet at construction time.
    gr.Examples(examples=[["a beautiful woman, cinematic", "examples/start.png", "examples/end.png", 2.5]], inputs=[prompt_input, first_image, last_image, duration_slider])
if __name__ == "__main__":
    # Make sure the files referenced by gr.Examples exist. Solid-colour 512x512
    # placeholders are written only when an image is absent.
    if not os.path.exists("examples"):
        os.makedirs("examples")
    for placeholder_path, placeholder_color in (("examples/start.png", "red"), ("examples/end.png", "blue")):
        if os.path.exists(placeholder_path):
            continue
        Image.new("RGB", (512, 512), color=placeholder_color).save(placeholder_path)

    # Registers "temp_resized" under the "input" key via
    # folder_paths.add_model_folder_path.
    # NOTE(review): generate_video writes its resized frames to ./input and passes
    # bare filenames to LoadImage. Confirm this registration actually affects where
    # LoadImage looks; it may be a no-op for that purpose.
    import folder_paths
    folder_paths.add_model_folder_path("input", "temp_resized")
    app.launch()