import os
import random
import sys
import tempfile
from typing import Sequence, Mapping, Any, Union

import spaces
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from comfy import model_management
def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    """Download a file from the Hub and expose it under local_dir via a symlink.

    ComfyUI expects model files in flat directories like models/unet, so the nested
    Hub filename is symlinked into local_dir using only its basename.
    """
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    base_filename = os.path.basename(filename)
    target_path = os.path.join(local_dir, base_filename)
    if os.path.exists(target_path) or os.path.islink(target_path):
        os.remove(target_path)
    os.symlink(downloaded_path, target_path)
    return target_path
# --- Model Downloads ---
print("Downloading models from Hugging Face Hub...")

text_encoder_repo = hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
print(text_encoder_repo)

hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/vae/wan_2.1_vae.safetensors", local_dir="models/vae")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/clip_vision/clip_vision_h.safetensors", local_dir="models/clip_vision")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", local_dir="models/loras")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")

print("Downloads complete.")
# --- Boilerplate code from the original script ---
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Returns the value at the given index of a sequence or mapping.

    If the object is a sequence (like a list or string), returns the value at the given index.
    If the object is a mapping (like a dictionary), returns the value at the index-th key.
    Some custom nodes return a dictionary; in those cases, we look under the "result" key.

    Args:
        obj (Union[Sequence, Mapping]): The object to retrieve the value from.
        index (int): The index of the value to retrieve.

    Returns:
        Any: The value at the given index.

    Raises:
        IndexError: If the index is out of bounds for the object and the object is not a mapping.
    """
    try:
        return obj[index]
    except KeyError:
        # Fallback for custom node outputs that are dictionaries keyed by "result"
        if isinstance(obj, Mapping) and "result" in obj:
            return obj["result"][index]
        raise
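# Illustrative usage (not from the original script): ComfyUI node methods return
# tuples, so get_value_at_index(cliploader.load_clip(...), 0) pulls the CLIP object
# out of the single-element result tuple.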
def find_path(name: str, path: str = None) -> str:
    """
    Recursively looks at parent folders starting from the given path until it finds the given name.
    Returns the path as a string if found, or None otherwise.
    """
    if path is None:
        path = os.getcwd()

    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"'{name}' found: {path_name}")
        return path_name

    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None

    return find_path(name, parent_directory)
def add_comfyui_directory_to_sys_path() -> None:
    """
    Add 'ComfyUI' to sys.path.
    """
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
    else:
        print("Could not find ComfyUI directory. Please run from a parent folder of ComfyUI.")
def add_extra_model_paths() -> None:
    """
    Parse the optional extra_model_paths.yaml file and register the extra model paths with ComfyUI.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        print(
            "Could not import load_extra_path_config from main.py. This might be okay if you don't use it."
        )
        return

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find an optional 'extra_model_paths.yaml' config file.")
def import_custom_nodes() -> None:
    """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS.

    This function sets up a new asyncio event loop, initializes the PromptServer,
    creates a PromptQueue, and initializes the custom nodes.
    """
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    # Custom nodes may reference the PromptServer/PromptQueue during import, so both
    # are instantiated here even though this script never serves HTTP requests.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)

    loop.run_until_complete(init_extra_nodes(init_custom_nodes=True))
# --- Model Loading and Caching ---
MODELS_AND_NODES = {}

print("Setting up ComfyUI paths...")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()

print("Importing custom nodes...")
import_custom_nodes()

# Now that paths are set up, we can import from nodes
from nodes import NODE_CLASS_MAPPINGS
import folder_paths  # a module-level import is already accessible throughout this script

print("Loading models into memory. This may take a few minutes...")
# Load core models (text encoder, UNETs, VAE)
cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
MODELS_AND_NODES["clip"] = cliploader.load_clip(
    clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu"
)

unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
unet_low_noise = unetloader.load_unet(
    unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
    weight_dtype="default",
)
unet_high_noise = unetloader.load_unet(
    unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
    weight_dtype="default",
)

vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
# Load LoRAs
loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
MODELS_AND_NODES["model_low_noise"] = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
    strength_model=0.8,
    model=get_value_at_index(unet_low_noise, 0),
)
MODELS_AND_NODES["model_high_noise"] = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
    strength_model=0.8,
    model=get_value_at_index(unet_high_noise, 0),
)
# Load Vision model
clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(
    clip_name="clip_vision_h.safetensors"
)
# Instantiate all required node classes
MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
MODELS_AND_NODES["ModelSamplingSD3"] = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
# "PathchSageAttentionKJ" (sic) matches the name this node is registered under in ComfyUI-KJNodes.
MODELS_AND_NODES["PathchSageAttentionKJ"] = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()
print("Pre-loading main models onto GPU...")
model_loaders = [
    MODELS_AND_NODES["clip"],
    MODELS_AND_NODES["vae"],
    MODELS_AND_NODES["model_low_noise"],   # UNET + Lightning LoRA (low noise)
    MODELS_AND_NODES["model_high_noise"],  # UNET + Lightning LoRA (high noise)
    MODELS_AND_NODES["clip_vision"],
]
# Some loader outputs expose a ModelPatcher via a .patcher attribute, others are the
# patcher itself, so normalize before handing them to load_models_gpu.
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
])
print("All models loaded successfully!")
# --- Main Video Generation Logic ---
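# The `spaces` import above is only useful if the generation function actually requests a
# ZeroGPU worker. A minimal sketch, assuming this Space runs on ZeroGPU: decorate the
# function with `spaces.GPU`. The duration value below is an assumption and should be
# tuned to the workflow's actual runtime.
@spaces.GPU(duration=120)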
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    # Default negative prompt kept in Chinese as shipped with Wan workflows; it lists
    # defects to avoid: overexposure, blur, subtitles, grayish tone, JPEG artifacts,
    # deformed or extra fingers, static frames, cluttered backgrounds, and similar.
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
    duration=2,
    progress=gr.Progress(track_tqdm=True),
):
    """
    The main function to generate a video based on user inputs.
    This function is called every time the user clicks the 'Generate' button.
    """
    FPS = 16
    num_frames = max(2, int(duration * FPS))

    clip = MODELS_AND_NODES["clip"]
    vae = MODELS_AND_NODES["vae"]
    model_low_noise = MODELS_AND_NODES["model_low_noise"]
    model_high_noise = MODELS_AND_NODES["model_high_noise"]
    clip_vision = MODELS_AND_NODES["clip_vision"]

    cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
    loadimage = MODELS_AND_NODES["LoadImage"]
    clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
    modelsamplingsd3 = MODELS_AND_NODES["ModelSamplingSD3"]
    pathchsageattentionkj = MODELS_AND_NODES["PathchSageAttentionKJ"]
    wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
    ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
    vaedecode = MODELS_AND_NODES["VAEDecode"]
    createvideo = MODELS_AND_NODES["CreateVideo"]
    savevideo = MODELS_AND_NODES["SaveVideo"]

    # Save uploaded images to temporary files
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as start_file, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as end_file:
        start_image_pil.save(start_file.name)
        end_image_pil.save(end_file.name)
        start_image_path = start_file.name
        end_image_path = end_file.name
    with torch.inference_mode():
        progress(0.1, desc="Encoding text and images...")

        # --- Workflow execution ---
        positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
        negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))

        start_image_loaded = loadimage.load_image(image=start_image_path)
        end_image_loaded = loadimage.load_image(image=end_image_path)

        clip_vision_encoded_start = clipvisionencode.encode(
            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0)
        )
        clip_vision_encoded_end = clipvisionencode.encode(
            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0)
        )
        progress(0.2, desc="Preparing initial latents...")
        initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
            width=480, height=480, length=num_frames, batch_size=1,
            positive=get_value_at_index(positive_conditioning, 0),
            negative=get_value_at_index(negative_conditioning, 0),
            vae=get_value_at_index(vae, 0),
            clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
            clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
            start_image=get_value_at_index(start_image_loaded, 0),
            end_image=get_value_at_index(end_image_loaded, 0),
        )
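        # As used below, the node's result is indexed as 0 = positive conditioning,
        # 1 = negative conditioning, and 2 = the empty video latent.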
        progress(0.3, desc="Patching models...")
        model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_noise, 0))
        model_low_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
        model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_noise, 0))
        model_high_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
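        # Wan 2.2 I2V splits denoising across two experts: the high-noise UNET handles the
        # first sampler pass (steps 0-4, keeping leftover noise) and the low-noise UNET
        # finishes from step 4 onward, hence the two KSamplerAdvanced calls below.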
        progress(0.5, desc="Running KSampler (Step 1/2)...")
        latent_step1 = ksampleradvanced.sample(
            add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
            return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
            positive=get_value_at_index(initial_latents, 0),
            negative=get_value_at_index(initial_latents, 1),
            latent_image=get_value_at_index(initial_latents, 2),
        )

        progress(0.7, desc="Running KSampler (Step 2/2)...")
        latent_step2 = ksampleradvanced.sample(
            add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
            return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
            positive=get_value_at_index(initial_latents, 0),
            negative=get_value_at_index(initial_latents, 1),
            latent_image=get_value_at_index(latent_step1, 0),
        )
        progress(0.8, desc="Decoding VAE...")
        decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))

        progress(0.9, desc="Creating and saving video...")
        video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))

        # Save the video to ComfyUI's output directory; the hard-coded "output/" prefix
        # below assumes the script runs from the ComfyUI root.
        save_result = savevideo.save_video(
            filename_prefix="GradioVideo", format="mp4", codec="h264",
            video=get_value_at_index(video_data, 0),
        )

    progress(1.0, desc="Done!")
    return f"output/{save_result['ui']['images'][0]['filename']}"
css = '''
.fillable{max-width: 980px !important}
.dark .progress-text {color: white}
'''
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) on ZeroGPU")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                start_image = gr.Image(type="pil", label="Start Frame")
                end_image = gr.Image(type="pil", label="End Frame")
            prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images", value="transition")
            with gr.Accordion("Advanced Settings", open=False):
                duration = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=2.0,
                    step=0.1,
                    label="Video Duration (seconds)",
                    info="Longer videos take longer to generate"
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
                    visible=False
                )
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    generate_button.click(
        fn=generate_video,
        inputs=[start_image, end_image, prompt, negative_prompt, duration],
        outputs=output_video
    )
    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns"],
            ["capybara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=output_video,
        fn=generate_video,
        cache_examples="lazy",
    )
if __name__ == "__main__":
    # The Blocks app is already built above as `app`; launch it directly.
    app.launch(share=True)