# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Gradio demo for VACE on Wan2.1 (1.3B): all-in-one video creation and editing.

The checkpoint is downloaded and the pipeline is loaded only on a private
copy of the Space that has a GPU; on the shared Space (or without a GPU) the
UI is shown but generation is disabled or refused with an explanatory error.
"""
import argparse
import datetime
import os
import sys

import gradio as gr
import imageio
import numpy as np
import torch
from huggingface_hub import snapshot_download

# True only on the original shared Space. SPACE_ID is not set outside
# Hugging Face Spaces, so fall back to '' instead of raising KeyError.
is_shared_ui = "fffiloni/Wan2.1-VACE-1.3B" in os.environ.get('SPACE_ID', '')
is_gpu_associated = torch.cuda.is_available()
# The model is only downloaded/loaded on a private copy that has a GPU.
should_load_model = not is_shared_ui and is_gpu_associated

if should_load_model:
    snapshot_download(
        repo_id="Wan-AI/Wan2.1-VACE-1.3B",
        local_dir="./models/Wan2.1-VACE-1.3B",
    )

# Put the grandparent directory of this file on sys.path so `wan` resolves;
# the wan imports must stay after this line.
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan  # noqa: E402,F401
from wan import WanVace, WanVaceMP  # noqa: E402
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS  # noqa: E402

DEFAULT_NEGATIVE_PROMPT = (
    "Bright and saturated tones, overexposed, static, unclear details, subtitles, style, work, "
    "painting, frame, still, overall grayish, worst quality, low quality, JPEG compression artifacts, "
    "ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, "
    "misshapen limbs, fused fingers, motionless frame, cluttered background, three legs, "
    "crowded background, walking backwards."
)

# Banner texts shown at the top of the page (see create_ui).
DUPLICATE_BADGE_HTML = """
Duplicate this Space
"""

WARNING_DUPLICATE_HTML = '''

Attention: this Space need to be duplicated to work

To make it work, duplicate the Space and run it on your own profile using a private GPU (L40s recommended).
A L40s costs US$1.80/h.

Duplicate this Space to start experimenting with this demo

'''

WARNING_READY_HTML = '''

You have successfully associated a GPU to this Space 🎉

You will be billed by the minute from when you activated the GPU until when it is turned off.

'''

WARNING_SETGPU_HTML = '''

You have successfully duplicated the Wan2.1-VACE-1.3B Space 🎉

There's only one step left before you can properly play with this demo: attribute a GPU to it (via the Settings tab) and run the app below. You will be billed by the minute from when you activate the GPU until when it is turned off.

🔥   Set recommended GPU

'''

css = """
div#warning-duplicate {
    background-color: #ebf5ff;
    padding: 0 16px 16px;
    margin: 20px 0;
    color: #030303!important;
}
div#warning-duplicate > .gr-prose > h2, div#warning-duplicate > .gr-prose > p {
    color: #0f4592!important;
}
div#warning-duplicate strong {
    color: #0f4592;
}
p.actions {
    display: flex;
    align-items: center;
    margin: 20px 0;
}
div#warning-duplicate .actions a {
    display: inline-block;
    margin-right: 10px;
}
div#warning-setgpu {
    background-color: #fff4eb;
    padding: 0 16px 16px;
    margin: 20px 0;
    color: #030303!important;
}
div#warning-setgpu > .gr-prose > h2, div#warning-setgpu > .gr-prose > p {
    color: #92220f!important;
}
div#warning-setgpu a, div#warning-setgpu b {
    color: #91230f;
}
div#warning-setgpu p.actions > a {
    display: inline-block;
    background: #1f1f23;
    border-radius: 40px;
    padding: 6px 24px;
    color: antiquewhite;
    text-decoration: none;
    font-weight: 600;
    font-size: 1.2em;
}
div#warning-ready {
    background-color: #ecfdf5;
    padding: 0 16px 16px;
    margin: 20px 0;
    color: #030303!important;
}
div#warning-ready > .gr-prose > h2, div#warning-ready > .gr-prose > p {
    color: #057857!important;
}
.custom-color {
    color: #030303 !important;
}
"""


class FixedSizeQueue:
    """Newest-first list holding at most ``max_size`` items."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.queue = []

    def add(self, item):
        """Insert ``item`` at the front; drop the oldest item when over capacity."""
        self.queue.insert(0, item)
        if len(self.queue) > self.max_size:
            self.queue.pop()

    def get(self):
        """Return the underlying list (newest first). This is not a copy."""
        return self.queue

    def __repr__(self):
        return str(self.queue)


class VACEInference:
    """Builds the VACE-Wan Gradio UI and runs generation through a Wan pipeline."""

    def __init__(self, cfg, skip_load=False, gallery_share=False, gallery_share_limit=5):
        """
        Args:
            cfg: parsed CLI namespace. ``save_dir`` is always read; ``mp``,
                ``model_name`` and ``ckpt_dir`` (plus ``ulysses_size`` /
                ``ring_size`` for multi-GPU) are read only when loading.
            skip_load: if True no pipeline is built and ``self.pipe`` stays None.
            gallery_share: if True the refresh button shows ``gallery_share_data``.
            gallery_share_limit: capacity of ``gallery_share_data``.
        """
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        # Always defined so generate() can refuse cleanly when nothing is loaded.
        self.pipe = None
        if skip_load:
            return
        # Use the cfg we were given rather than a module-level `args` global.
        if not cfg.mp:
            self.pipe = WanVace(
                config=WAN_CONFIGS[cfg.model_name],
                checkpoint_dir=cfg.ckpt_dir,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
        else:
            self.pipe = WanVaceMP(
                config=WAN_CONFIGS[cfg.model_name],
                checkpoint_dir=cfg.ckpt_dir,
                use_usp=True,
                ulysses_size=cfg.ulysses_size,
                ring_size=cfg.ring_size,
            )

    @staticmethod
    def _ref_image(name):
        """Create one optional reference-image upload slot labelled/identified by ``name``."""
        return gr.Image(label=name, height=200, interactive=True, type='filepath',
                        image_mode='RGB', sources=['upload'], elem_id=name, format='png')

    def create_ui(self, *args, **kwargs):
        """Lay out the demo; must be called inside an active ``gr.Blocks`` context.

        Every input/output component is stored on ``self`` for set_callbacks().
        """
        gr.Markdown("# VACE-WAN 1.3B Demo")
        gr.Markdown("All-in-One Video Creation and Editing")
        gr.HTML(DUPLICATE_BADGE_HTML)
        with gr.Column():
            # Banner depends on where the app runs: the shared Space, a private
            # copy with a GPU, or a private copy still waiting for a GPU.
            if is_shared_ui:
                gr.HTML(WARNING_DUPLICATE_HTML, elem_id="warning-duplicate")
            elif is_gpu_associated:
                gr.HTML(WARNING_READY_HTML, elem_id="warning-ready")
            else:
                gr.HTML(WARNING_SETGPU_HTML, elem_id="warning-setgpu")

        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video", sources=['upload'], value=None, interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask", sources=['upload'], value=None, interactive=True)

        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = self._ref_image('src_ref_image_1')
                    self.src_ref_image_2 = self._ref_image('src_ref_image_2')
                    self.src_ref_image_3 = self._ref_image('src_ref_image_3')

        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False, placeholder="positive_prompt_input", elem_id='positive_prompt',
                    container=True, autofocus=True, elem_classes='type_row', visible=True, lines=2)
                self.negative_prompt = gr.Textbox(
                    show_label=False, value=DEFAULT_NEGATIVE_PROMPT, placeholder="negative_prompt_input",
                    elem_id='negative_prompt', container=True, autofocus=False, elem_classes='type_row',
                    visible=True, interactive=True, lines=1)

        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(
                        label='shift_scale', minimum=0.0, maximum=100.0, step=1.0, value=16.0,
                        interactive=True)
                    # The step count is locked on the shared Space.
                    self.sample_steps = gr.Slider(
                        label='sample_steps', minimum=1, maximum=100, step=1, value=25,
                        interactive=not is_shared_ui)
                    self.context_scale = gr.Slider(
                        label='context_scale', minimum=0.0, maximum=2.0, step=0.1, value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale', minimum=1, maximum=10, step=0.5, value=5.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(minimum=-1, maximum=10000000, value=2025, label="Seed")

        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                # 480x832 by default (an earlier revision used 720x1280).
                self.output_height = gr.Textbox(label='resolutions_height', value=480, interactive=True)
                self.output_width = gr.Textbox(label='resolutions_width', value=832, interactive=True)
                self.frame_rate = gr.Textbox(label='frame_rate', value=16, interactive=True)
                self.num_frames = gr.Textbox(label='num_frames', value=81, interactive=True)

        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run', elem_classes='type_row', elem_id='generate_button', visible=True,
                    interactive=not is_shared_ui)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄

        self.output_gallery = gr.Gallery(
            label="output_gallery", value=[], interactive=False, allow_preview=True, preview=True)

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3,
                 prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed,
                 output_height, output_width, frame_rate, num_frames, progress=gr.Progress(track_tqdm=True)):
        """Run one VACE generation, save it as mp4 and return ``[video_path]``.

        ``output_gallery`` is wired as an input but unused; ``progress`` lets
        Gradio mirror the pipeline's tqdm bars.

        Raises:
            gr.Error: if no model is loaded, the height/width/frame_rate/num_frames
                fields are not integers, the resolution has no SIZE_CONFIGS entry,
                or the video cannot be written.
        """
        if self.pipe is None:
            raise gr.Error("The model is not loaded: attribute a GPU to this Space "
                           "(via the Settings tab) and restart it.")
        try:
            output_height, output_width, frame_rate, num_frames = (
                int(output_height), int(output_width), int(frame_rate), int(num_frames))
        except (TypeError, ValueError) as e:
            raise gr.Error("resolutions_height, resolutions_width, frame_rate and num_frames "
                           f"must be integers: {e}") from e
        size_key = f"{output_width}*{output_height}"
        if size_key not in SIZE_CONFIGS:
            raise gr.Error(f"Unsupported resolution {size_key}. "
                           f"Supported (width*height): {', '.join(SIZE_CONFIGS)}")

        src_ref_images = [x for x in (src_ref_image_1, src_ref_image_2, src_ref_image_3) if x is not None]
        # prepare_source takes batches, hence the single-element lists.
        src_video, src_mask, src_ref_images = self.pipe.prepare_source(
            [src_video], [src_mask], [src_ref_images],
            num_frames=num_frames, image_size=SIZE_CONFIGS[size_key], device=self.pipe.device)

        video = self.pipe.generate(
            prompt, src_video, src_mask, src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            # Slider values may arrive as floats; step counts and seeds must be ints.
            sampling_steps=int(sample_steps),
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=int(infer_seed),
            offload_model=True)

        # Portable strftime codes only ('%-H' is a glibc extension); microseconds
        # keep two runs in the same second from sharing a folder/file.
        name = datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')
        save_dir_path = os.path.join('./output', name)
        os.makedirs(save_dir_path, exist_ok=True)
        print(f"✅ Folder created: {save_dir_path}")
        video_path = os.path.join(save_dir_path, f'cur_gallery_{name}.mp4')

        # /2 + 0.5 and clamp map an assumed [-1, 1] range to [0, 1]; permute moves
        # axis 0 last, presumably (C, T, H, W) -> (T, H, W, C) — TODO confirm.
        video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255) \
            .cpu().numpy().astype(np.uint8)
        try:
            # The context manager closes the writer even if append_data fails.
            with imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8,
                                    macro_block_size=1) as writer:
                for frame in video_frames:
                    writer.append_data(frame)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}") from e
        print(video_path)
        # Shared-gallery mode (gallery_share_data.add) is intentionally disabled.
        return [video_path]

    def set_callbacks(self, **kwargs):
        """Wire the Run and refresh buttons; call after create_ui()."""
        self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1,
                           self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt,
                           self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale,
                           self.infer_seed, self.output_height, self.output_width, self.frame_rate,
                           self.num_frames]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(self.generate, inputs=self.gen_inputs, outputs=self.gen_outputs, queue=True)
        # NOTE(review): generate() never calls gallery_share_data.add, so with
        # gallery_share=True this replaces the gallery with an empty list.
        self.refresh_button.click(
            lambda x: self.gallery_share_data.get() if self.gallery_share else x,
            inputs=[self.output_gallery], outputs=[self.output_gallery])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
    parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
    parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
    parser.add_argument('--root_path', dest='root_path', help='', default=None)
    parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
    parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
    # Model options are registered unconditionally: they are unused when the
    # model is not loaded, and this guarantees VACEInference finds them.
    parser.add_argument("--model_name", type=str, default="vace-1.3B", choices=list(WAN_CONFIGS.keys()),
                        help="The model name to run.")
    parser.add_argument("--ulysses_size", type=int, default=1,
                        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument("--ring_size", type=int, default=1,
                        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument("--ckpt_dir", type=str, default='models/Wan2.1-VACE-1.3B/',
                        help="The path to the checkpoint directory.")
    parser.add_argument("--offload_to_cpu", action="store_true",
                        help="Offloading unnecessary computations to CPU.")
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    with gr.Blocks(css=css) as demo:
        infer_gr = VACEInference(args, skip_load=not should_load_model, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()

    allowed_paths = [args.save_dir]
    demo.queue(status_update_rate=1).launch(
        server_name=args.server_name,
        server_port=args.server_port,
        root_path=args.root_path,
        allowed_paths=allowed_paths,
        show_error=True,
        debug=True,
    )