from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
import httpx
import json
import asyncio
import time
import base64
from typing import Optional, Dict, List
from io import BytesIO

QWEN_API_URL = "https://chat.qwenlm.ai/api/chat/completions"
QWEN_MODELS_URL = "https://chat.qwenlm.ai/api/models"
QWEN_FILES_URL = "https://chat.qwenlm.ai/api/v1/files/"
MAX_RETRIES = 3
RETRY_DELAY = 1

cached_models = None
cached_models_timestamp = 0
CACHE_TTL = 60 * 60

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
client = httpx.AsyncClient()


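# Serve the front-end page (index.html) at the root path.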
@app.get("/")
async def root():
    return FileResponse("index.html")


async def sleep(seconds: float):
    await asyncio.sleep(seconds)


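# Decode a base64 string (optionally a full data URL) into an in-memory file object.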
async def base64_to_file(base64_str: str) -> BytesIO:
    try:
        if ',' in base64_str:
            base64_str = base64_str.split(',', 1)[1]
        image_data = base64.b64decode(base64_str)
        return BytesIO(image_data)
    except Exception as e:
        raise Exception(f"Failed to convert base64 to file: {str(e)}")


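# Upload image bytes to the Qwen files endpoint and return the uploaded file's ID.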
async def upload_image_to_qwen(auth_header: str, image_data: BytesIO) -> str:
    try:
        files = {'file': ('image.jpg', image_data, 'image/jpeg')}
        headers = {
            "Authorization": auth_header,
            "accept": "application/json"
        }

        async with httpx.AsyncClient() as upload_client:
            response = await upload_client.post(
                QWEN_FILES_URL,
                headers=headers,
                files=files
            )

        if response.is_success:
            data = response.json()
            if not data.get('id'):
                raise Exception("File upload failed: No valid file ID returned")
            return data['id']
        else:
            raise Exception(f"File upload failed with status {response.status_code}")

    except Exception as e:
        raise Exception(f"Failed to upload image: {str(e)}")


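# Rewrite OpenAI-style messages: any inline base64 image_url parts are uploaded
# and replaced with Qwen image references.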
async def process_messages(messages: List[Dict], auth_header: str) -> List[Dict]:
    processed_messages = []

    for message in messages:
        if isinstance(message.get('content'), list):
            new_content = []
            for content in message['content']:
                if (content.get('type') == 'image_url' and
                        content.get('image_url', {}).get('url', '').startswith('data:')):
                    image_data = await base64_to_file(content['image_url']['url'])
                    image_id = await upload_image_to_qwen(auth_header, image_data)
                    new_content.append({
                        'type': 'image',
                        'image': image_id
                    })
                else:
                    new_content.append(content)
            message['content'] = new_content
        processed_messages.append(message)

    return processed_messages


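# Send a request through the shared client, retrying on 5xx responses, HTML error
# pages, and network errors with a linearly increasing delay.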
async def fetch_with_retry(url: str, options: Dict, retries: int = MAX_RETRIES):
    last_error = None

    for i in range(retries):
        try:
            response = await client.request(
                method=options.get("method", "GET"),
                url=url,
                headers=options.get("headers", {}),
                json=options.get("json"),
            )

            if response.is_success:
                return response

            content_type = response.headers.get("content-type", "")
            if response.status_code >= 500 or "text/html" in content_type:
                last_error = {
                    "status": response.status_code,
                    "content_type": content_type,
                    "response_text": response.text[:1000],
                    "headers": dict(response.headers)
                }
                if i < retries - 1:
                    await sleep(RETRY_DELAY * (i + 1))
                    continue
            else:
                last_error = {
                    "status": response.status_code,
                    "headers": dict(response.headers)
                }
                break

        except Exception as error:
            last_error = error
            if i < retries - 1:
                await sleep(RETRY_DELAY * (i + 1))
                continue

    raise Exception(json.dumps({
        "error": True,
        "message": "All retry attempts failed",
        "last_error": str(last_error),
        "retries": retries
    }))


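# Parse a single SSE "data:" payload. The upstream may send cumulative content in
# each delta, so only the newly appended suffix is forwarded downstream.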
async def process_line(line: str, previous_content: str) -> tuple[str, Optional[dict]]:
    try:
        data = json.loads(line[6:])
        if data.get("choices") and data["choices"][0].get("delta"):
            delta = data["choices"][0]["delta"]
            current_content = delta.get("content", "")

            if previous_content and current_content:
                if current_content.startswith(previous_content):
                    new_content = current_content[len(previous_content):]
                else:
                    new_content = current_content
            else:
                new_content = current_content

            new_data = {
                "choices": [{
                    "delta": {
                        "role": delta.get("role", "assistant"),
                        "content": new_content
                    }
                }]
            }

            return current_content, new_data
        return previous_content, data
    except Exception:
        return previous_content, None


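# Convert the upstream byte stream into OpenAI-style SSE lines, buffering partial
# lines across chunks.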
async def stream_generator(response: httpx.Response):
    buffer = ""
    previous_content = ""

    try:
        async for chunk in response.aiter_bytes():
            chunk_text = chunk.decode()
            buffer += chunk_text

            # Hold back the last, possibly incomplete line until more data arrives.
            lines = buffer.split("\n")
            buffer = lines.pop() if lines else ""

            for line in lines:
                line = line.strip()
                if line.startswith("data: "):
                    previous_content, data = await process_line(line, previous_content)
                    if data:
                        yield f"data: {json.dumps(data)}\n\n"

        if buffer:
            previous_content, data = await process_line(buffer, previous_content)
            if data:
                yield f"data: {json.dumps(data)}\n\n"

        yield "data: [DONE]\n\n"
    finally:
        # Release the upstream connection whether the stream finishes or the client disconnects.
        await response.aclose()


@app.get("/healthz")
async def health_check():
    return {"status": "ok"}


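# Proxy the upstream model list, caching the raw response for CACHE_TTL seconds.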
@app.get("/api/models")
async def get_models(request: Request):
    global cached_models, cached_models_timestamp

    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return Response(status_code=401, content="Unauthorized")

    now = time.time()
    if cached_models and now - cached_models_timestamp < CACHE_TTL:
        return Response(
            content=cached_models,
            media_type="application/json"
        )

    try:
        response = await fetch_with_retry(
            QWEN_MODELS_URL,
            {"headers": {"Authorization": auth_header}}
        )

        cached_models = response.text
        cached_models_timestamp = now

        return Response(
            content=cached_models,
            media_type="application/json"
        )
    except Exception as error:
        return Response(
            content=json.dumps({"error": True, "message": str(error)}),
            status_code=500,
            media_type="application/json"
        )


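# Main proxy endpoint: validate the request, upload any inline images, then forward
# to Qwen and either stream or return the full response.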
@app.post("/api/chat/completions")
async def chat_completions(request: Request):
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return Response(status_code=401, content="Unauthorized")

    request_data = await request.json()
    messages = request_data.get("messages")
    stream = request_data.get("stream", False)
    model = request_data.get("model")
    max_tokens = request_data.get("max_tokens")

    if not model:
        return Response(
            content=json.dumps({"error": True, "message": "Model parameter is required"}),
            status_code=400,
            media_type="application/json"
        )

    try:
        # Replace inline base64 images with uploaded file IDs before forwarding.
        processed_messages = await process_messages(messages, auth_header)

        qwen_request = {
            "model": model,
            "messages": processed_messages,
            "stream": stream
        }
        if max_tokens is not None:
            qwen_request["max_tokens"] = max_tokens

        headers = {
            "Content-Type": "application/json",
            "Authorization": auth_header
        }

        if stream:
            # Open the upstream request in streaming mode so chunks are forwarded
            # as they arrive instead of buffering the whole response first.
            upstream_request = client.build_request(
                "POST", QWEN_API_URL, headers=headers, json=qwen_request
            )
            upstream_response = await client.send(upstream_request, stream=True)
            return StreamingResponse(
                stream_generator(upstream_response),
                media_type="text/event-stream"
            )

        response = await client.post(
            QWEN_API_URL,
            headers=headers,
            json=qwen_request
        )
        return Response(
            content=response.text,
            media_type="application/json"
        )

    except Exception as error:
        return Response(
            content=json.dumps({"error": True, "message": str(error)}),
            status_code=500,
            media_type="application/json"
        )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)