Update main.py
Browse files
main.py
CHANGED
@@ -3,14 +3,16 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
3 |
from fastapi.responses import StreamingResponse
|
4 |
from fastapi.background import BackgroundTasks
|
5 |
import requests
|
|
|
6 |
import uuid
|
7 |
import json
|
8 |
import time
|
9 |
from typing import Optional
|
10 |
import asyncio
|
11 |
-
|
12 |
-
import
|
13 |
import os
|
|
|
14 |
|
15 |
app = FastAPI()
|
16 |
security = HTTPBearer()
|
@@ -73,44 +75,49 @@ async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(securi
|
|
73 |
return token.replace("Bearer ", "") if token.startswith("Bearer ") else token
|
74 |
|
75 |
async def check_image_status(session: requests.Session, job_id: str, headers: dict) -> Optional[str]:
|
76 |
-
"""
|
77 |
-
|
78 |
-
|
79 |
-
Args:
|
80 |
-
session: 请求会话
|
81 |
-
job_id: 任务ID
|
82 |
-
headers: 请求头
|
83 |
-
|
84 |
-
Returns:
|
85 |
-
Optional[str]: base64格式的图片数据,如果生成失败则返回None
|
86 |
-
"""
|
87 |
-
max_retries = 30 # 最多等待30秒
|
88 |
-
for _ in range(max_retries):
|
89 |
try:
|
|
|
90 |
response = session.get(
|
91 |
f'https://chat.akash.network/api/image-status?ids={job_id}',
|
92 |
headers=headers
|
93 |
)
|
|
|
94 |
status_data = response.json()
|
95 |
|
96 |
if status_data and isinstance(status_data, list) and len(status_data) > 0:
|
97 |
job_info = status_data[0]
|
|
|
|
|
98 |
|
99 |
-
#
|
100 |
-
if
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
return None
|
|
|
|
|
|
|
|
|
107 |
|
108 |
except Exception as e:
|
109 |
-
print(f"Error checking
|
110 |
-
|
111 |
-
await asyncio.sleep(1) # 等待1秒后重试
|
112 |
|
113 |
-
print(f"Timeout waiting for
|
114 |
return None
|
115 |
|
116 |
@app.get("/")
|
@@ -184,88 +191,34 @@ async def chat_completions(
|
|
184 |
# 在处理消息时先判断模型类型
|
185 |
if data.get('model') == 'AkashGen' and "<image_generation>" in msg_data:
|
186 |
# 图片生成模型的特殊处理
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
"
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
result = None
|
216 |
-
|
217 |
-
for attempt in range(max_retries):
|
218 |
-
try:
|
219 |
-
print(f"\nAttempt {attempt + 1}/{max_retries} for job {job_id}")
|
220 |
-
status_response = cffi_requests.get(
|
221 |
-
f'https://chat.akash.network/api/image-status?ids={job_id}',
|
222 |
-
headers=headers,
|
223 |
-
impersonate="chrome110"
|
224 |
-
)
|
225 |
-
print(f"Status response code: {status_response.status_code}")
|
226 |
-
status_data = status_response.json()
|
227 |
-
print(f"Status data: {json.dumps(status_data, indent=2)}")
|
228 |
-
|
229 |
-
if status_data and isinstance(status_data, list) and len(status_data) > 0:
|
230 |
-
job_info = status_data[0]
|
231 |
-
print(f"Job status: {job_info.get('status')}")
|
232 |
-
|
233 |
-
if job_info.get("result"):
|
234 |
-
result = job_info['result']
|
235 |
-
if result and not result.startswith("Failed"):
|
236 |
-
break
|
237 |
-
elif job_info.get("status") == "failed":
|
238 |
-
result = None
|
239 |
-
break
|
240 |
-
except Exception as e:
|
241 |
-
print(f"Error checking status: {e}")
|
242 |
-
|
243 |
-
if attempt < max_retries - 1:
|
244 |
-
time.sleep(retry_interval)
|
245 |
-
|
246 |
-
# 发送结束消息
|
247 |
-
elapsed_time = time.time() - start_time
|
248 |
-
end_msg = f"\n🤔 Thinking for {elapsed_time:.1f}s...\n"
|
249 |
-
end_msg += "</think>\n\n"
|
250 |
-
if result and not result.startswith("Failed"):
|
251 |
-
end_msg += f""
|
252 |
-
else:
|
253 |
-
end_msg += "*Image generation failed or timed out.*\n"
|
254 |
-
|
255 |
-
# 发送结束消息 (使用标准 OpenAI 格式)
|
256 |
-
chunk = {
|
257 |
-
"id": f"chatcmpl-{chat_id}",
|
258 |
-
"object": "chat.completion.chunk",
|
259 |
-
"created": int(time.time()),
|
260 |
-
"model": data.get('model'), # 使用请求中指定的模型
|
261 |
-
"choices": [{
|
262 |
-
"delta": {"content": end_msg},
|
263 |
-
"index": 0,
|
264 |
-
"finish_reason": None
|
265 |
-
}]
|
266 |
-
}
|
267 |
-
yield f"data: {json.dumps(chunk)}\n\n"
|
268 |
-
continue
|
269 |
|
270 |
content_buffer += msg_data
|
271 |
|
@@ -373,6 +326,102 @@ async def list_models(api_key: str = Depends(get_api_key)):
|
|
373 |
print(f"Error in list_models: {e}")
|
374 |
return {"error": str(e)}
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
if __name__ == '__main__':
|
377 |
import uvicorn
|
378 |
-
uvicorn.run(app, host='0.0.0.0', port=
|
|
|
3 |
from fastapi.responses import StreamingResponse
|
4 |
from fastapi.background import BackgroundTasks
|
5 |
import requests
|
6 |
+
from curl_cffi import requests as cffi_requests # 保留这个,用于获取cookies
|
7 |
import uuid
|
8 |
import json
|
9 |
import time
|
10 |
from typing import Optional
|
11 |
import asyncio
|
12 |
+
import base64
|
13 |
+
import tempfile
|
14 |
import os
|
15 |
+
import re
|
16 |
|
17 |
app = FastAPI()
|
18 |
security = HTTPBearer()
|
|
|
75 |
return token.replace("Bearer ", "") if token.startswith("Bearer ") else token
|
76 |
|
77 |
async def check_image_status(session: requests.Session, job_id: str, headers: dict) -> Optional[str]:
|
78 |
+
"""检查图片生成状态并获取生成的图片"""
|
79 |
+
max_retries = 30
|
80 |
+
for attempt in range(max_retries):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
try:
|
82 |
+
print(f"\nAttempt {attempt + 1}/{max_retries} for job {job_id}")
|
83 |
response = session.get(
|
84 |
f'https://chat.akash.network/api/image-status?ids={job_id}',
|
85 |
headers=headers
|
86 |
)
|
87 |
+
print(f"Status response code: {response.status_code}")
|
88 |
status_data = response.json()
|
89 |
|
90 |
if status_data and isinstance(status_data, list) and len(status_data) > 0:
|
91 |
job_info = status_data[0]
|
92 |
+
status = job_info.get('status')
|
93 |
+
print(f"Job status: {status}")
|
94 |
|
95 |
+
# 只有当状态为 completed 时才处理结果
|
96 |
+
if status == "completed":
|
97 |
+
result = job_info.get("result")
|
98 |
+
if result and not result.startswith("Failed"):
|
99 |
+
print("Got valid result, attempting upload...")
|
100 |
+
image_url = await upload_to_xinyew(result, job_id)
|
101 |
+
if image_url:
|
102 |
+
print(f"Successfully uploaded image: {image_url}")
|
103 |
+
return image_url
|
104 |
+
print("Image upload failed")
|
105 |
+
return None
|
106 |
+
print("Invalid result received")
|
107 |
+
return None
|
108 |
+
elif status == "failed":
|
109 |
+
print(f"Job {job_id} failed")
|
110 |
return None
|
111 |
+
|
112 |
+
# 如果状态是其他(如 pending),继续等待
|
113 |
+
await asyncio.sleep(1)
|
114 |
+
continue
|
115 |
|
116 |
except Exception as e:
|
117 |
+
print(f"Error checking status: {e}")
|
118 |
+
return None
|
|
|
119 |
|
120 |
+
print(f"Timeout waiting for job {job_id}")
|
121 |
return None
|
122 |
|
123 |
@app.get("/")
|
|
|
191 |
# 在处理消息时先判断模型类型
|
192 |
if data.get('model') == 'AkashGen' and "<image_generation>" in msg_data:
|
193 |
# 图片生成模型的特殊处理
|
194 |
+
async def process_and_send():
|
195 |
+
end_msg = await process_image_generation(msg_data, session, headers, chat_id)
|
196 |
+
if end_msg:
|
197 |
+
chunk = {
|
198 |
+
"id": f"chatcmpl-{chat_id}",
|
199 |
+
"object": "chat.completion.chunk",
|
200 |
+
"created": int(time.time()),
|
201 |
+
"model": data.get('model'),
|
202 |
+
"choices": [{
|
203 |
+
"delta": {"content": end_msg},
|
204 |
+
"index": 0,
|
205 |
+
"finish_reason": None
|
206 |
+
}]
|
207 |
+
}
|
208 |
+
return f"data: {json.dumps(chunk)}\n\n"
|
209 |
+
return None
|
210 |
+
|
211 |
+
# 创建新的事件循环
|
212 |
+
loop = asyncio.new_event_loop()
|
213 |
+
asyncio.set_event_loop(loop)
|
214 |
+
try:
|
215 |
+
result = loop.run_until_complete(process_and_send())
|
216 |
+
finally:
|
217 |
+
loop.close()
|
218 |
+
|
219 |
+
if result:
|
220 |
+
yield result
|
221 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
content_buffer += msg_data
|
224 |
|
|
|
326 |
print(f"Error in list_models: {e}")
|
327 |
return {"error": str(e)}
|
328 |
|
329 |
+
async def upload_to_xinyew(image_base64: str, job_id: str) -> Optional[str]:
|
330 |
+
"""上传图片到新野图床并返回URL"""
|
331 |
+
try:
|
332 |
+
print(f"\n=== Starting image upload for job {job_id} ===")
|
333 |
+
print(f"Base64 data length: {len(image_base64)}")
|
334 |
+
|
335 |
+
# 解码base64图片数据
|
336 |
+
try:
|
337 |
+
image_data = base64.b64decode(image_base64.split(',')[1] if ',' in image_base64 else image_base64)
|
338 |
+
print(f"Decoded image data length: {len(image_data)} bytes")
|
339 |
+
except Exception as e:
|
340 |
+
print(f"Error decoding base64: {e}")
|
341 |
+
print(f"First 100 chars of base64: {image_base64[:100]}...")
|
342 |
+
return None
|
343 |
+
|
344 |
+
# 创建临时文件
|
345 |
+
with tempfile.NamedTemporaryFile(suffix='.jpeg', delete=False) as temp_file:
|
346 |
+
temp_file.write(image_data)
|
347 |
+
temp_file_path = temp_file.name
|
348 |
+
|
349 |
+
try:
|
350 |
+
filename = f"{job_id}.jpeg"
|
351 |
+
print(f"Using filename: {filename}")
|
352 |
+
|
353 |
+
# 准备文件上传
|
354 |
+
files = {
|
355 |
+
'file': (filename, open(temp_file_path, 'rb'), 'image/jpeg')
|
356 |
+
}
|
357 |
+
|
358 |
+
print("Sending request to xinyew.cn...")
|
359 |
+
response = requests.post(
|
360 |
+
'https://api.xinyew.cn/api/jdtc',
|
361 |
+
files=files,
|
362 |
+
timeout=30
|
363 |
+
)
|
364 |
+
|
365 |
+
print(f"Upload response status: {response.status_code}")
|
366 |
+
if response.status_code == 200:
|
367 |
+
result = response.json()
|
368 |
+
print(f"Upload response: {result}")
|
369 |
+
|
370 |
+
if result.get('errno') == 0:
|
371 |
+
url = result.get('data', {}).get('url')
|
372 |
+
if url:
|
373 |
+
print(f"Successfully got image URL: {url}")
|
374 |
+
return url
|
375 |
+
print("No URL in response data")
|
376 |
+
else:
|
377 |
+
print(f"Upload failed: {result.get('message')}")
|
378 |
+
else:
|
379 |
+
print(f"Upload failed with status {response.status_code}")
|
380 |
+
print(f"Response content: {response.text}")
|
381 |
+
return None
|
382 |
+
|
383 |
+
finally:
|
384 |
+
# 清理临时文件
|
385 |
+
try:
|
386 |
+
os.unlink(temp_file_path)
|
387 |
+
except Exception as e:
|
388 |
+
print(f"Error removing temp file: {e}")
|
389 |
+
|
390 |
+
except Exception as e:
|
391 |
+
print(f"Error in upload_to_xinyew: {e}")
|
392 |
+
import traceback
|
393 |
+
print(traceback.format_exc())
|
394 |
+
return None
|
395 |
+
|
396 |
+
async def process_image_generation(msg_data: str, session: requests.Session, headers: dict, chat_id: str) -> str:
|
397 |
+
"""处理图片生成的逻辑"""
|
398 |
+
match = re.search(r"jobId='([^']+)' prompt='([^']+)' negative='([^']*)'", msg_data)
|
399 |
+
if match:
|
400 |
+
job_id, prompt, negative = match.groups()
|
401 |
+
print(f"Starting image generation process for job_id: {job_id}")
|
402 |
+
|
403 |
+
# 发送思考开始的消息
|
404 |
+
start_time = time.time()
|
405 |
+
end_msg = "<think>\n"
|
406 |
+
end_msg += "🎨 Generating image...\n\n"
|
407 |
+
end_msg += f"Prompt: {prompt}\n"
|
408 |
+
|
409 |
+
# 检查图片状态和上传
|
410 |
+
result = await check_image_status(session, job_id, headers)
|
411 |
+
|
412 |
+
# 发送结束消息
|
413 |
+
elapsed_time = time.time() - start_time
|
414 |
+
end_msg += f"\n🤔 Thinking for {elapsed_time:.1f}s...\n"
|
415 |
+
end_msg += "</think>\n\n"
|
416 |
+
|
417 |
+
if result: # result 现在是上传后的图片URL
|
418 |
+
end_msg += f""
|
419 |
+
else:
|
420 |
+
end_msg += "*Image generation or upload failed.*\n"
|
421 |
+
|
422 |
+
return end_msg
|
423 |
+
return ""
|
424 |
+
|
425 |
if __name__ == '__main__':
|
426 |
import uvicorn
|
427 |
+
uvicorn.run(app, host='0.0.0.0', port=9000)
|