hzruo commited on
Commit
102a6d4
·
verified ·
1 Parent(s): 67328de

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +158 -109
main.py CHANGED
@@ -3,14 +3,16 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
  from fastapi.responses import StreamingResponse
4
  from fastapi.background import BackgroundTasks
5
  import requests
 
6
  import uuid
7
  import json
8
  import time
9
  from typing import Optional
10
  import asyncio
11
- from curl_cffi import requests as cffi_requests
12
- import re
13
  import os
 
14
 
15
  app = FastAPI()
16
  security = HTTPBearer()
@@ -73,44 +75,49 @@ async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(securi
73
  return token.replace("Bearer ", "") if token.startswith("Bearer ") else token
74
 
75
  async def check_image_status(session: requests.Session, job_id: str, headers: dict) -> Optional[str]:
76
- """
77
- 检查图片生成状态并获取生成的图片
78
-
79
- Args:
80
- session: 请求会话
81
- job_id: 任务ID
82
- headers: 请求头
83
-
84
- Returns:
85
- Optional[str]: base64格式的图片数据,如果生成失败则返回None
86
- """
87
- max_retries = 30 # 最多等待30秒
88
- for _ in range(max_retries):
89
  try:
 
90
  response = session.get(
91
  f'https://chat.akash.network/api/image-status?ids={job_id}',
92
  headers=headers
93
  )
 
94
  status_data = response.json()
95
 
96
  if status_data and isinstance(status_data, list) and len(status_data) > 0:
97
  job_info = status_data[0]
 
 
98
 
99
- # 如果result不为空,说明图片已生成
100
- if job_info.get("result"):
101
- return job_info["result"] # 直接返回base64数据
102
-
103
- # 如果状态是失败,则停止等待
104
- if job_info.get("status") == "failed":
105
- print(f"Image generation failed for job {job_id}")
 
 
 
 
 
 
 
 
106
  return None
 
 
 
 
107
 
108
  except Exception as e:
109
- print(f"Error checking image status: {e}")
110
-
111
- await asyncio.sleep(1) # 等待1秒后重试
112
 
113
- print(f"Timeout waiting for image generation job {job_id}")
114
  return None
115
 
116
  @app.get("/")
@@ -184,88 +191,34 @@ async def chat_completions(
184
  # 在处理消息时先判断模型类型
185
  if data.get('model') == 'AkashGen' and "<image_generation>" in msg_data:
186
  # 图片生成模型的特殊处理
187
- match = re.search(r"jobId='([^']+)' prompt='([^']+)' negative='([^']*)'", msg_data)
188
- if match:
189
- job_id, prompt, negative = match.groups()
190
- print(f"Starting image generation process for job_id: {job_id}")
191
-
192
- # 立即发送思考开始的消息
193
- start_time = time.time()
194
- think_msg = "<think>\n"
195
- think_msg += "🎨 Generating image...\n\n"
196
- think_msg += f"Prompt: {prompt}\n"
197
-
198
- # 发送思考开始消息 (使用标准 OpenAI 格式)
199
- chunk = {
200
- "id": f"chatcmpl-{chat_id}",
201
- "object": "chat.completion.chunk",
202
- "created": int(time.time()),
203
- "model": data.get('model'), # 使用请求中指定的模型
204
- "choices": [{
205
- "delta": {"content": think_msg},
206
- "index": 0,
207
- "finish_reason": None
208
- }]
209
- }
210
- yield f"data: {json.dumps(chunk)}\n\n"
211
-
212
- # 同步方式检查图片状态
213
- max_retries = 10
214
- retry_interval = 3
215
- result = None
216
-
217
- for attempt in range(max_retries):
218
- try:
219
- print(f"\nAttempt {attempt + 1}/{max_retries} for job {job_id}")
220
- status_response = cffi_requests.get(
221
- f'https://chat.akash.network/api/image-status?ids={job_id}',
222
- headers=headers,
223
- impersonate="chrome110"
224
- )
225
- print(f"Status response code: {status_response.status_code}")
226
- status_data = status_response.json()
227
- print(f"Status data: {json.dumps(status_data, indent=2)}")
228
-
229
- if status_data and isinstance(status_data, list) and len(status_data) > 0:
230
- job_info = status_data[0]
231
- print(f"Job status: {job_info.get('status')}")
232
-
233
- if job_info.get("result"):
234
- result = job_info['result']
235
- if result and not result.startswith("Failed"):
236
- break
237
- elif job_info.get("status") == "failed":
238
- result = None
239
- break
240
- except Exception as e:
241
- print(f"Error checking status: {e}")
242
-
243
- if attempt < max_retries - 1:
244
- time.sleep(retry_interval)
245
-
246
- # 发送结束消息
247
- elapsed_time = time.time() - start_time
248
- end_msg = f"\n🤔 Thinking for {elapsed_time:.1f}s...\n"
249
- end_msg += "</think>\n\n"
250
- if result and not result.startswith("Failed"):
251
- end_msg += f"![Generated Image]({result})"
252
- else:
253
- end_msg += "*Image generation failed or timed out.*\n"
254
-
255
- # 发送结束消息 (使用标准 OpenAI 格式)
256
- chunk = {
257
- "id": f"chatcmpl-{chat_id}",
258
- "object": "chat.completion.chunk",
259
- "created": int(time.time()),
260
- "model": data.get('model'), # 使用请求中指定的模型
261
- "choices": [{
262
- "delta": {"content": end_msg},
263
- "index": 0,
264
- "finish_reason": None
265
- }]
266
- }
267
- yield f"data: {json.dumps(chunk)}\n\n"
268
- continue
269
 
270
  content_buffer += msg_data
271
 
@@ -373,6 +326,102 @@ async def list_models(api_key: str = Depends(get_api_key)):
373
  print(f"Error in list_models: {e}")
374
  return {"error": str(e)}
375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  if __name__ == '__main__':
377
  import uvicorn
378
- uvicorn.run(app, host='0.0.0.0', port=7860)
 
3
  from fastapi.responses import StreamingResponse
4
  from fastapi.background import BackgroundTasks
5
  import requests
6
+ from curl_cffi import requests as cffi_requests # 保留这个,用于获取cookies
7
  import uuid
8
  import json
9
  import time
10
  from typing import Optional
11
  import asyncio
12
+ import base64
13
+ import tempfile
14
  import os
15
+ import re
16
 
17
  app = FastAPI()
18
  security = HTTPBearer()
 
75
  return token.replace("Bearer ", "") if token.startswith("Bearer ") else token
76
 
77
  async def check_image_status(session: requests.Session, job_id: str, headers: dict) -> Optional[str]:
78
+ """检查图片生成状态并获取生成的图片"""
79
+ max_retries = 30
80
+ for attempt in range(max_retries):
 
 
 
 
 
 
 
 
 
 
81
  try:
82
+ print(f"\nAttempt {attempt + 1}/{max_retries} for job {job_id}")
83
  response = session.get(
84
  f'https://chat.akash.network/api/image-status?ids={job_id}',
85
  headers=headers
86
  )
87
+ print(f"Status response code: {response.status_code}")
88
  status_data = response.json()
89
 
90
  if status_data and isinstance(status_data, list) and len(status_data) > 0:
91
  job_info = status_data[0]
92
+ status = job_info.get('status')
93
+ print(f"Job status: {status}")
94
 
95
+ # 只有当状态为 completed 时才处理结果
96
+ if status == "completed":
97
+ result = job_info.get("result")
98
+ if result and not result.startswith("Failed"):
99
+ print("Got valid result, attempting upload...")
100
+ image_url = await upload_to_xinyew(result, job_id)
101
+ if image_url:
102
+ print(f"Successfully uploaded image: {image_url}")
103
+ return image_url
104
+ print("Image upload failed")
105
+ return None
106
+ print("Invalid result received")
107
+ return None
108
+ elif status == "failed":
109
+ print(f"Job {job_id} failed")
110
  return None
111
+
112
+ # 如果状态是其他(如 pending),继续等待
113
+ await asyncio.sleep(1)
114
+ continue
115
 
116
  except Exception as e:
117
+ print(f"Error checking status: {e}")
118
+ return None
 
119
 
120
+ print(f"Timeout waiting for job {job_id}")
121
  return None
122
 
123
  @app.get("/")
 
191
  # 在处理消息时先判断模型类型
192
  if data.get('model') == 'AkashGen' and "<image_generation>" in msg_data:
193
  # 图片生成模型的特殊处理
194
+ async def process_and_send():
195
+ end_msg = await process_image_generation(msg_data, session, headers, chat_id)
196
+ if end_msg:
197
+ chunk = {
198
+ "id": f"chatcmpl-{chat_id}",
199
+ "object": "chat.completion.chunk",
200
+ "created": int(time.time()),
201
+ "model": data.get('model'),
202
+ "choices": [{
203
+ "delta": {"content": end_msg},
204
+ "index": 0,
205
+ "finish_reason": None
206
+ }]
207
+ }
208
+ return f"data: {json.dumps(chunk)}\n\n"
209
+ return None
210
+
211
+ # 创建新的事件循环
212
+ loop = asyncio.new_event_loop()
213
+ asyncio.set_event_loop(loop)
214
+ try:
215
+ result = loop.run_until_complete(process_and_send())
216
+ finally:
217
+ loop.close()
218
+
219
+ if result:
220
+ yield result
221
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  content_buffer += msg_data
224
 
 
326
  print(f"Error in list_models: {e}")
327
  return {"error": str(e)}
328
 
329
+ async def upload_to_xinyew(image_base64: str, job_id: str) -> Optional[str]:
330
+ """上传图片到新野图床并返回URL"""
331
+ try:
332
+ print(f"\n=== Starting image upload for job {job_id} ===")
333
+ print(f"Base64 data length: {len(image_base64)}")
334
+
335
+ # 解码base64图片数据
336
+ try:
337
+ image_data = base64.b64decode(image_base64.split(',')[1] if ',' in image_base64 else image_base64)
338
+ print(f"Decoded image data length: {len(image_data)} bytes")
339
+ except Exception as e:
340
+ print(f"Error decoding base64: {e}")
341
+ print(f"First 100 chars of base64: {image_base64[:100]}...")
342
+ return None
343
+
344
+ # 创建临时文件
345
+ with tempfile.NamedTemporaryFile(suffix='.jpeg', delete=False) as temp_file:
346
+ temp_file.write(image_data)
347
+ temp_file_path = temp_file.name
348
+
349
+ try:
350
+ filename = f"{job_id}.jpeg"
351
+ print(f"Using filename: {filename}")
352
+
353
+ # 准备文件上传
354
+ files = {
355
+ 'file': (filename, open(temp_file_path, 'rb'), 'image/jpeg')
356
+ }
357
+
358
+ print("Sending request to xinyew.cn...")
359
+ response = requests.post(
360
+ 'https://api.xinyew.cn/api/jdtc',
361
+ files=files,
362
+ timeout=30
363
+ )
364
+
365
+ print(f"Upload response status: {response.status_code}")
366
+ if response.status_code == 200:
367
+ result = response.json()
368
+ print(f"Upload response: {result}")
369
+
370
+ if result.get('errno') == 0:
371
+ url = result.get('data', {}).get('url')
372
+ if url:
373
+ print(f"Successfully got image URL: {url}")
374
+ return url
375
+ print("No URL in response data")
376
+ else:
377
+ print(f"Upload failed: {result.get('message')}")
378
+ else:
379
+ print(f"Upload failed with status {response.status_code}")
380
+ print(f"Response content: {response.text}")
381
+ return None
382
+
383
+ finally:
384
+ # 清理临时文件
385
+ try:
386
+ os.unlink(temp_file_path)
387
+ except Exception as e:
388
+ print(f"Error removing temp file: {e}")
389
+
390
+ except Exception as e:
391
+ print(f"Error in upload_to_xinyew: {e}")
392
+ import traceback
393
+ print(traceback.format_exc())
394
+ return None
395
+
396
+ async def process_image_generation(msg_data: str, session: requests.Session, headers: dict, chat_id: str) -> str:
397
+ """处理图片生成的逻辑"""
398
+ match = re.search(r"jobId='([^']+)' prompt='([^']+)' negative='([^']*)'", msg_data)
399
+ if match:
400
+ job_id, prompt, negative = match.groups()
401
+ print(f"Starting image generation process for job_id: {job_id}")
402
+
403
+ # 发送思考开始的消息
404
+ start_time = time.time()
405
+ end_msg = "<think>\n"
406
+ end_msg += "🎨 Generating image...\n\n"
407
+ end_msg += f"Prompt: {prompt}\n"
408
+
409
+ # 检查图片状态和上传
410
+ result = await check_image_status(session, job_id, headers)
411
+
412
+ # 发送结束消息
413
+ elapsed_time = time.time() - start_time
414
+ end_msg += f"\n🤔 Thinking for {elapsed_time:.1f}s...\n"
415
+ end_msg += "</think>\n\n"
416
+
417
+ if result: # result 现在是上传后的图片URL
418
+ end_msg += f"![Generated Image]({result})"
419
+ else:
420
+ end_msg += "*Image generation or upload failed.*\n"
421
+
422
+ return end_msg
423
+ return ""
424
+
425
  if __name__ == '__main__':
426
  import uvicorn
427
+ uvicorn.run(app, host='0.0.0.0', port=9000)