Wan2.1 / app.py
chaojiemao's picture
Update app.py
965dd18 verified
raw
history blame
11.6 kB
import copy
import os
import random
import time
from http import HTTPStatus

# Install the SDK before the dashscope imports below.
os.system('pip install dashscope')
import gradio as gr
import dashscope
from dashscope import VideoSynthesis

from examples import t2v_examples, i2v_examples
# DashScope credentials are read from the environment.
DASHSCOPE_API_KEY = os.environ.get('DASHSCOPE_API_KEY')
dashscope.api_key = DASHSCOPE_API_KEY

# How long entries stay in task_status, in seconds after they were registered.
KEEP_SUCCESS_TASK = 10 * 60 * 60  # finished tasks: 10 hours
KEEP_RUNING_TASK = 2 * 60 * 60  # unfinished tasks: 2 hours

# Maximum number of unfinished tasks registered within the last 1800 seconds.
LIMIT_RUNING_TASK = 10
def t2v_generation(prompt, resolution, watermark_wanx, seed = -1):
    """Generate a text-to-video clip synchronously (blocks until DashScope returns).

    The UI in this file wires ``t2v_generation_async`` instead; this is the
    blocking variant.

    Args:
        prompt: text prompt.
        resolution: ``"W*H"`` size string forwarded as ``size``.
        watermark_wanx: forwarded to the API unchanged.
        seed: non-negative seed; negative or ``None`` (cleared input) picks a
            random seed in ``[0, 2147483647]``.

    Returns:
        ``(video_url or None, gr.Button(visible=True))``.
    """
    # gr.Number yields a float (or None when cleared); normalise to an int seed.
    seed = int(seed) if seed is not None and seed >= 0 else random.randint(0, 2147483647)
    if not allow_task_num():
        gr.Info("Warning: The number of running tasks is too large, please wait for a while.")
        return None, gr.Button(visible=True)
    try:
        rsp = VideoSynthesis.call(model="wanx2.1-t2v-plus", prompt=prompt, seed=seed,
                                  watermark_wanx=watermark_wanx, size=resolution)
        if rsp.status_code != HTTPStatus.OK:
            # Report the service's own error instead of failing on a missing output.
            gr.Warning(f"Warning: {rsp.code}: {rsp.message}")
            return None, gr.Button(visible=True)
        return rsp.output.video_url, gr.Button(visible=True)
    except Exception as e:
        gr.Warning(f"Warning: {e}")
        return None, gr.Button(visible=True)
def t2v_generation_async(prompt, size, watermark_wanx, seed = -1):
    """Submit an asynchronous text-to-video task.

    Args:
        prompt: text prompt.
        size: ``"W*H"`` size string.
        watermark_wanx: forwarded to the API unchanged.
        seed: non-negative seed; negative or ``None`` picks a random seed in
            ``[0, 2147483647]``.

    Returns:
        ``(task_id, status, button)`` for the ``task_id``/``status`` states and
        the t2v button. On success the button is hidden while the task runs.
        When nothing was submitted (throttled or error) ``task_id`` is ``""``,
        the sentinel ``get_process_bar`` treats as "no running task".
    """
    seed = int(seed) if seed is not None and seed >= 0 else random.randint(0, 2147483647)
    if not allow_task_num():
        gr.Info("Warning: The number of running tasks is too large, please wait for a while.")
        # "" rather than None: a None id would be tracked as a running task.
        return "", True, gr.Button(visible=True)
    try:
        rsp = VideoSynthesis.async_call(model="wanx2.1-t2v-plus",
                                        prompt=prompt,
                                        size=size,
                                        seed=seed,
                                        watermark_wanx=watermark_wanx)
        if rsp.status_code != HTTPStatus.OK:
            gr.Warning(f"Warning: {rsp.code}: {rsp.message}")
            return "", True, gr.Button()
        return rsp.output.task_id, False, gr.Button(visible=False)
    except Exception as e:
        gr.Warning(f"Warning: {e}")
        return "", True, gr.Button()
def i2v_generation(prompt, image, watermark_wanx, seed = -1):
    """Generate an image-to-video clip synchronously (blocks until DashScope returns).

    The UI in this file wires ``i2v_generation_async`` instead; this is the
    blocking variant.

    Args:
        prompt: text prompt.
        image: image reference forwarded as ``img_url``.
        watermark_wanx: forwarded to the API unchanged.
        seed: non-negative seed; negative or ``None`` picks a random seed in
            ``[0, 2147483647]``.

    Returns:
        The video URL, or ``None`` when throttled or on error.
    """
    seed = int(seed) if seed is not None and seed >= 0 else random.randint(0, 2147483647)
    # Same throttle as t2v_generation.
    if not allow_task_num():
        gr.Info("Warning: The number of running tasks is too large, please wait for a while.")
        return None
    try:
        rsp = VideoSynthesis.call(model="wanx2.1-i2v-plus", prompt=prompt, img_url=image,
                                  seed=seed,
                                  watermark_wanx=watermark_wanx)
        if rsp.status_code != HTTPStatus.OK:
            gr.Warning(f"Warning: {rsp.code}: {rsp.message}")
            return None
        return rsp.output.video_url
    except Exception as e:
        gr.Warning(f"Warning: {e}")
        return None
def i2v_generation_async(prompt, image, watermark_wanx, seed = -1):
    """Submit an asynchronous image-to-video task.

    Args:
        prompt: text prompt.
        image: uploaded image (a file path from the ``gr.Image`` input),
            forwarded as ``img_url``.
        watermark_wanx: forwarded to the API unchanged.
        seed: non-negative seed; negative or ``None`` picks a random seed in
            ``[0, 2147483647]``.

    Returns:
        ``(task_id, status, button)`` for the ``task_id``/``status`` states and
        the i2v button. When nothing was submitted ``task_id`` is ``""`` (the
        "no running task" sentinel) and ``status`` is True.
    """
    seed = int(seed) if seed is not None and seed >= 0 else random.randint(0, 2147483647)
    if not allow_task_num():
        gr.Info("Warning: The number of running tasks is too large, please wait for a while.")
        return "", True, gr.Button(visible=True)
    if not image:
        # Fail fast with a clear message instead of sending img_url=None.
        gr.Warning("Warning: Please upload an input image.")
        return "", True, gr.Button()
    try:
        rsp = VideoSynthesis.async_call(model="wanx2.1-i2v-plus", prompt=prompt, seed=seed,
                                        img_url=image, watermark_wanx=watermark_wanx)
        if rsp.status_code != HTTPStatus.OK:
            gr.Warning(f"Warning: {rsp.code}: {rsp.message}")
            return "", True, gr.Button()
        return rsp.output.task_id, False, gr.Button(visible=False)
    except Exception as e:
        gr.Warning(f"Warning: {e}")
        return "", True, gr.Button()
def get_result_with_task_id(task_id):
    """Poll DashScope once for the outcome of an asynchronous task.

    Args:
        task_id: id returned by ``VideoSynthesis.async_call``. ``""`` (or
            ``None``) means no task was submitted.

    Returns:
        ``(status, video_url)``. ``status`` is True when polling can stop:
        either a non-empty video URL is available, or the task ended as
        ``FAILED``/``CANCELED`` (then ``video_url`` is None). Any error while
        fetching yields ``(False, None)``, so the caller simply polls again.
    """
    if not task_id:
        return True, None
    try:
        rsp = VideoSynthesis.fetch(task=task_id)
        task_state = rsp.output.task_status
        if task_state in ("FAILED", "CANCELED"):
            # Terminal without a video: stop polling.
            gr.Info(f"Warning: task running {task_state}")
            return True, None
        # A missing or empty URL is treated as "not finished yet".
        video_url = rsp.output.video_url or None
        return video_url is not None, video_url
    except Exception as e:
        # Best effort: log and let the next progress tick retry.
        print(f"Fetching task {task_id} failed: {e}")
        return False, None
# In-memory progress table shared by all sessions, keyed by task id. Each entry
# holds "value" (progress %), "status" (finished?), "time" (registration
# timestamp) and "url" (result video URL); entries are created in get_process_bar.
task_status = {}
def allow_task_num():
    """Return True if another task may be submitted.

    Counts the entries that are still unfinished and were registered less than
    1800 seconds ago; submission is allowed while that count is below
    LIMIT_RUNING_TASK.
    """
    recent_unfinished = sum(
        1
        for entry in task_status.values()
        if not entry["status"] and entry["time"] + 1800 > time.time()
    )
    return recent_unfinished < LIMIT_RUNING_TASK
def clean_task_status():
    """Remove stale entries from ``task_status`` in place.

    Finished entries are dropped KEEP_SUCCESS_TASK seconds after they were
    registered, unfinished ones after KEEP_RUNING_TASK seconds. The ``""``
    entry ("no task submitted" sentinel) is never removed.
    """
    now = time.time()
    # Iterate over a snapshot of the keys since entries are popped in the loop;
    # copying the nested entries (as deepcopy did) is unnecessary.
    for task_id in list(task_status):
        if task_id == "":
            continue
        entry = task_status[task_id]
        keep_for = KEEP_SUCCESS_TASK if entry["status"] else KEEP_RUNING_TASK
        if entry["time"] + keep_for < now:
            task_status.pop(task_id)
def cost_time(task_id):
    """Value for the "Cost Time(secs)" box.

    While ``task_id`` is registered in ``task_status`` and not finished,
    returns the seconds since it was registered, formatted with two decimals.
    Otherwise returns a bare ``gr.Textbox()``.
    """
    entry = task_status.get(task_id)
    if entry is None or entry["status"]:
        return gr.Textbox()
    elapsed = time.time() - entry["time"]
    return "{:.2f}".format(elapsed)
def get_process_bar(task_id, status):
    """Advance the progress entry of ``task_id`` and return the slider update.

    Used as the progress slider's value callable (``every=3`` in the UI). The
    first call for an id registers it in ``task_status``. A falsy id (``""``
    or ``None``: nothing was submitted) is registered as already finished at
    100%.

    Returns:
        ``gr.Slider`` whose value is the current percentage; the label
        alternates between two suffixes on odd/even values.
    """
    clean_task_status()
    if task_id not in task_status:
        has_task = bool(task_id)
        task_status[task_id] = {
            "value": 0 if has_task else 100,
            "status": status if has_task else True,
            "time": time.time(),
            "url": None
        }
    entry = task_status[task_id]
    if not entry["status"]:
        # Query the backend only from 10% on, and then only on every 5th
        # value, to limit fetch calls; other ticks just advance the bar.
        if entry["value"] >= 10 and entry["value"] % 5 == 0:
            status, video_url = get_result_with_task_id(task_id)
        else:
            status, video_url = False, None
        entry["status"] = status
        entry["url"] = video_url
    if entry["status"]:
        entry["value"] = 100
    else:
        entry["value"] += 1
    # An unfinished task never shows 100%: drop back to 95 and keep polling.
    if entry["value"] >= 100 and not entry["status"]:
        entry["value"] = 95
    value = entry["value"]
    return gr.Slider(label= f"({value}%)Generating" if value%2==1 else f"({value}%)Generating.....", value=value)
# UI: an input column (tabs for text-to-video / image-to-video) next to a
# result column. Generation is asynchronous: a Generate button submits a task
# and stores its id in the ``task_id`` state; the progress slider re-evaluates
# get_process_bar (every=3), and each slider change runs process_change, which
# shows the video once the task is marked finished.
with gr.Blocks() as demo:
    gr.HTML("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
WanX
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/WanX-AI/WanX2.1-T2V-1.3B">WanX2.1-T2V-1.3B</a> |
<a href="https://huggingface.co/WanX-AI/WanX2.1-T2V-14B">WanX2.1-T2V-14B</a> |
<a href="https://huggingface.co/WanX-AI/WanX2.1-I2V-14B-480P">WanX2.1-I2V-14B-480P</a> |
<a href="https://huggingface.co/WanX-AI/WanX2.1-I2V-14B-720P">WanX2.1-I2V-14B-720P</a>
</div>
""")
    # Session state: current task id ("" = none), its finished flag, and the
    # active tab ("t2v" / "i2v").
    task_id = gr.State(value="")
    status = gr.State(value=False)
    task = gr.State(value="t2v")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Tabs():
                    # Text to Video Tab
                    with gr.TabItem("Text to Video") as t2v_tab:
                        with gr.Row():
                            txt2vid_prompt = gr.Textbox(
                                label="Prompt",
                                placeholder="Describe the video you want to generate",
                                lines=19,
                            )
                        with gr.Row():
                            resolution = gr.Dropdown(
                                label="Resolution",
                                choices=["1280*720", "960*960", "720*1280", "1088*832", "832*1088"],
                                value="1280*720",
                            )
                        with gr.Row():
                            run_t2v_button = gr.Button("Generate Video")
                    # Image to Video Tab
                    with gr.TabItem("Image to Video") as i2v_tab:
                        with gr.Row():
                            with gr.Column():
                                img2vid_image = gr.Image(
                                    type="filepath",
                                    label="Upload Input Image",
                                    elem_id="image_upload",
                                )
                                img2vid_prompt = gr.Textbox(
                                    label="Prompt",
                                    placeholder="Describe the video you want to generate",
                                    value="",
                                    lines=5,
                                )
                        with gr.Row():
                            run_i2v_button = gr.Button("Generate Video")
        with gr.Column():
            with gr.Row():
                result_gallery = gr.Video(label='WanX Generated Video',
                                          interactive=False,
                                          height=500)
            with gr.Row():
                watermark_wanx = gr.Checkbox(label="Watermark", value=True, container=False)
                seed = gr.Number(label="Seed", value=-1, container=True)
                # Separate name so the component does not shadow the
                # cost_time() function it polls.
                cost_time_box = gr.Number(label="Cost Time(secs)", value=cost_time, interactive=False,
                                          every=2, inputs=[task_id], container=True)
                process_bar = gr.Slider(show_label=True, label="", value=get_process_bar, maximum=100,
                                        interactive=True, every=3, inputs=[task_id, status], container=True)
    fake_video = gr.Video(label='WanX Examples', visible=False, interactive=False)
    with gr.Row(visible=True) as t2v_eg:
        gr.Examples(t2v_examples,
                    inputs=[txt2vid_prompt, result_gallery],
                    outputs=[result_gallery])
    with gr.Row(visible=False) as i2v_eg:
        gr.Examples(i2v_examples,
                    inputs=[img2vid_prompt, img2vid_image, result_gallery],
                    outputs=[result_gallery])

    def process_change(task_id, task):
        """Show the result and restore the active tab's button once the task is finished.

        Returns updates for (result_gallery, run_t2v_button, run_i2v_button).
        """
        # .get(): the id may not be registered yet (registration happens in
        # get_process_bar) or may have been removed by clean_task_status.
        entry = task_status.get(task_id)
        if entry is not None and entry["status"]:
            video_url = entry["url"]
            ret_t2v_btn = gr.Button(visible=True) if task == 't2v' else gr.Button()
            ret_i2v_btn = gr.Button(visible=True) if task == 'i2v' else gr.Button()
            return gr.Video(value=video_url), ret_t2v_btn, ret_i2v_btn
        return gr.Video(value=None), gr.Button(), gr.Button()
    process_bar.change(process_change, inputs=[task_id, task],
                       outputs=[result_gallery, run_t2v_button, run_i2v_button])

    def switch_i2v_tab():
        """Show the image-to-video examples and mark i2v as the active task."""
        return gr.Row(visible=False), gr.Row(visible=True), "i2v"

    def switch_t2v_tab():
        """Show the text-to-video examples and mark t2v as the active task."""
        return gr.Row(visible=True), gr.Row(visible=False), "t2v"
    i2v_tab.select(switch_i2v_tab, outputs=[t2v_eg, i2v_eg, task])
    t2v_tab.select(switch_t2v_tab, outputs=[t2v_eg, i2v_eg, task])
    run_t2v_button.click(
        fn=t2v_generation_async,
        inputs=[txt2vid_prompt, resolution, watermark_wanx, seed],
        outputs=[task_id, status, run_t2v_button],
    )
    run_i2v_button.click(
        fn=i2v_generation_async,
        inputs=[img2vid_prompt, img2vid_image, watermark_wanx, seed],
        outputs=[task_id, status, run_i2v_button],
    )
# Enable the request queue (at most 10 waiting events) and start the server;
# queue() returns the Blocks instance, so the calls can be chained.
demo.queue(max_size=10).launch()