kelseye commited on
Commit
0b909d1
·
verified ·
1 Parent(s): 8c52de3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import oss2
4
+ import sys
5
+ import uuid
6
+ import shutil
7
+ import time
8
+ import gradio as gr
9
+ import requests
10
+
11
+ os.system("pip install dashscope")
12
+ import dashscope
13
+ from dashscope.utils.oss_utils import check_and_upload_local
14
+
15
+ DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
16
+ dashscope.api_key = DASHSCOPE_API_KEY
17
+
18
+
19
+ class WanAnimateApp:
20
+ def __init__(self, url, get_url):
21
+ self.url = url
22
+ self.get_url = get_url
23
+
24
+ def predict(
25
+ self,
26
+ ref_img,
27
+ video,
28
+ model_id,
29
+ model,
30
+ ):
31
+ # Upload files to OSS if needed and get URLs
32
+ _, image_url = check_and_upload_local(model_id, ref_img, DASHSCOPE_API_KEY)
33
+ _, video_url = check_and_upload_local(model_id, video, DASHSCOPE_API_KEY)
34
+
35
+ # Prepare the request payload
36
+ payload = {
37
+ "model": model_id,
38
+ "input": {
39
+ "image_url": image_url,
40
+ "video_url": video_url
41
+ },
42
+ "parameters": {
43
+ "check_image": True,
44
+ "mode": model,
45
+ }
46
+ }
47
+
48
+ # Set up headers
49
+ headers = {
50
+ "X-DashScope-Async": "enable",
51
+ "X-DashScope-OssResourceResolve": "enable",
52
+ "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
53
+ "Content-Type": "application/json"
54
+ }
55
+
56
+ # Make the initial API request
57
+ url = self.url
58
+ response = requests.post(url, json=payload, headers=headers)
59
+
60
+ # Check if request was successful
61
+ if response.status_code != 200:
62
+ raise Exception(f"Initial request failed with status code {response.status_code}: {response.text}")
63
+
64
+ # Get the task ID from response
65
+ result = response.json()
66
+ task_id = result.get("output", {}).get("task_id")
67
+ if not task_id:
68
+ raise Exception("Failed to get task ID from response")
69
+
70
+ # Poll for results
71
+ get_url = f"{self.get_url}/{task_id}"
72
+ headers = {
73
+ "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
74
+ "Content-Type": "application/json"
75
+ }
76
+
77
+ while True:
78
+ response = requests.get(get_url, headers=headers)
79
+ if response.status_code != 200:
80
+ raise Exception(f"Failed to get task status: {response.status_code}: {response.text}")
81
+
82
+ result = response.json()
83
+ print(result)
84
+ task_status = result.get("output", {}).get("task_status")
85
+
86
+ if task_status == "SUCCEEDED":
87
+ # Task completed successfully, return video URL
88
+ video_url = result["output"]["results"]["video_url"]
89
+ return video_url, "SUCCEEDED"
90
+ elif task_status == "FAILED":
91
+ # Task failed, raise an exception with error message
92
+ error_msg = result.get("output", {}).get("message", "Unknown error")
93
+ code_msg = result.get("output", {}).get("code", "Unknown code")
94
+ print(f"\n\nTask failed: {error_msg} Code: {code_msg} TaskId: {task_id}\n\n")
95
+ return None, f"Task failed: {error_msg} Code: {code_msg} TaskId: {task_id}"
96
+ # raise Exception(f"Task failed: {error_msg} TaskId: {task_id}")
97
+ else:
98
+ # Task is still running, wait and retry
99
+ time.sleep(5) # Wait 5 seconds before polling again
100
+
101
+ def start_app():
102
+ import argparse
103
+ parser = argparse.ArgumentParser(description="Wan2.2-Animate 视频生成工具")
104
+ args = parser.parse_args()
105
+
106
+ url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/video-synthesis/"
107
+ # url = "https://poc-dashscope.aliyuncs.com/api/v1/services/aigc/image2video/video-synthesis"
108
+
109
+ get_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/"
110
+ # get_url = f"https://poc-dashscope.aliyuncs.com/api/v1/tasks"
111
+ app = WanAnimateApp(url=url, get_url=get_url)
112
+
113
+ with gr.Blocks(title="Wan2.2-Animate 视频生成") as demo:
114
+ gr.HTML("""
115
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
116
+ Wan2.2-Animate
117
+ </div>
118
+ """)
119
+ gr.Markdown("基于参考图像和骨骼序列的人物驱动和替换视频生成")
120
+
121
+ with gr.Row():
122
+ with gr.Column():
123
+ ref_img = gr.Image(
124
+ label="Reference Image(参考图像)",
125
+ type="filepath",
126
+ sources=["upload"],
127
+ )
128
+
129
+ video = gr.Video(
130
+ label="Template Video(模版视频)",
131
+ sources=["upload"],
132
+ )
133
+
134
+ with gr.Row():
135
+ model_id = gr.Dropdown(
136
+ label="模型名称",
137
+ choices=["wan2.2-animate-move", "wan2.2-animate-mix"],
138
+ value="wan2.2-animate-move",
139
+ info="支持mov和mix模型"
140
+ )
141
+
142
+ model = gr.Dropdown(
143
+ label="模式",
144
+ choices=["wan-pro", "wan-std"],
145
+ value="wan-pro",
146
+ info="支持标准模型std和专业模式pro两个版本"
147
+ )
148
+
149
+ run_button = gr.Button("Generate Video(生成视频)")
150
+
151
+ with gr.Column():
152
+ output_video = gr.Video(label="Output Video(输出视频)")
153
+ output_status = gr.Textbox(label="Status")
154
+
155
+ run_button.click(
156
+ fn=app.predict,
157
+ inputs=[
158
+ ref_img,
159
+ video,
160
+ model_id,
161
+ model,
162
+ ],
163
+ outputs=[output_video, output_status],
164
+ )
165
+
166
+ # examples_dir = "examples"
167
+ # if os.path.exists(examples_dir):
168
+ # example_data = []
169
+
170
+ # files_dict = {}
171
+ # for file in os.listdir(examples_dir):
172
+ # file_path = os.path.join(examples_dir, file)
173
+ # name, ext = os.path.splitext(file)
174
+
175
+ # if ext.lower() in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"]:
176
+ # if name not in files_dict:
177
+ # files_dict[name] = {}
178
+ # files_dict[name]["image"] = file_path
179
+ # elif ext.lower() in [".mp3", ".wav"]:
180
+ # if name not in files_dict:
181
+ # files_dict[name] = {}
182
+ # files_dict[name]["audio"] = file_path
183
+
184
+ # for name, files in files_dict.items():
185
+ # if "image" in files and "audio" in files:
186
+ # example_data.append([
187
+ # files["image"],
188
+ # files["audio"],
189
+ # "480P"
190
+ # ])
191
+
192
+ # if example_data:
193
+ # gr.Examples(
194
+ # examples=example_data,
195
+ # inputs=[ref_img, video, resolution],
196
+ # outputs=output_video,
197
+ # fn=app.predict,
198
+ # cache_examples=False,
199
+ # )
200
+
201
+ demo.launch()
202
+
203
+
204
+ if __name__ == "__main__":
205
+ start_app()