|
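"""OpenAI-compatible chat proxy for Abacus.ai ChatLLM.

Exposes POST /hf/v1/chat/completions, translates OpenAI-style requests into
Abacus deployment conversations, and relays the WebSocket reply either as an
SSE stream or as a single chat.completion JSON body.
"""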
|
|
model_mapping = {
    "sonnet": "CLAUDE_V3_5_SONNET",
    "4o": "OPENAI_GPT4O",
    "32": "OPENAI_GPT4_32K",
    "turbo": "OPENAI_GPT4_128K_LATEST",
    "vision": "OPENAI_GPT4_VISION",
    "3.5": "OPENAI_GPT3_5",
    "opus": "CLAUDE_V3_OPUS",
    "haiku": "CLAUDE_V3_HAIKU",
    "claude-2": "CLAUDE_V2_1",
    "pro": "GEMINI_1_5_PRO",
    "palm": "PALM",
    "llama": "LLAMA3_LARGE_CHAT",
    "_legacy_sonnet": "CLAUDE_V3_SONNET",
    "_legacy_gemini": "GEMINI_PRO",
    "_legacy_palm": "PALM_TEXT",
    "gpt-4o": "OPENAI_GPT4O",
    "gpt-4-turbo": "OPENAI_GPT4_128K_LATEST",
    "claude-3-opus": "CLAUDE_V3_OPUS",
}
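# Keys are matched as substrings of the requested model name (see map_model
# below): e.g. "gpt-4o-mini" contains "4o" and resolves to OPENAI_GPT4O.
# Names that match no key fall back to FALLBACK_MODEL.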
|
|
|
|
|
import os

environment_variables = {}

def set_env(var_name, default=None):
    # Read an environment variable with a fallback and record it for the startup table.
    value = os.getenv(var_name, default)
    environment_variables[var_name] = value
    return value
|
|
|
|
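# These settings can be overridden at launch, e.g. (assuming this file is
# saved as app.py):
#   PORT=9000 LOG_LEVEL=DEBUG FALLBACK_LLM=CLAUDE_V3_OPUS python app.py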
|
FALLBACK_MODEL = set_env("FALLBACK_LLM", "CLAUDE_V3_5_SONNET")
RATE_LIMIT = set_env("RATE_LIMIT", "1/4 second")
LOG_LEVEL = set_env("LOG_LEVEL", "INFO")
PORT = int(set_env("PORT", "8000"))
BASE_HOST = set_env("BASE_HOST", "apps.abacus.ai")

DEPLOYMENT_CACHE_TTL = 3600 * 24
IMPERSONATE_BASE = "chrome"
CURL_MAX_CLIENTS = 300
|
|
|
import asyncio
import json
import logging
import random
import uuid
from typing import Any

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from curl_cffi import requests, CurlOpt, CurlHttpVersion
from cachetools import TTLCache

# Cache deployment details per apikey so listExternalApplications is hit at
# most once per key per TTL window; the lock prevents duplicate lookups.
deployment_cache = TTLCache(maxsize=300, ttl=DEPLOYMENT_CACHE_TTL)
cache_lock = asyncio.Lock()
|
|
|
import websockets

# Prefer orjson for speed when it is installed; fall back to the stdlib json module.
try:
    import orjson as json

    def jsonDumps(obj):
        # orjson.dumps returns bytes; decode so callers always receive str.
        return json.dumps(obj).decode('utf-8')
except ImportError:
    import json

    jsonDumps = json.dumps

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
|
|
|
# Tune curl for low latency and aggressive connection reuse.
CURL_OPTS = {
    CurlOpt.TCP_NODELAY: 1,
    CurlOpt.FORBID_REUSE: 0,
    CurlOpt.FRESH_CONNECT: 0,
    CurlOpt.TCP_KEEPALIVE: 1,
    CurlOpt.MAXAGE_CONN: 30,
}

client = requests.AsyncSession(
    impersonate=IMPERSONATE_BASE,
    default_headers=True,
    max_clients=CURL_MAX_CLIENTS,
    curl_options=CURL_OPTS,
    http_version=CurlHttpVersion.V2_PRIOR_KNOWLEDGE,
)
|
|
|
from rich.logging import RichHandler
from rich.console import Console
from rich.table import Table

logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()]
)
logger = logging.getLogger(__name__)
|
|
|
app = FastAPI()

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
|
|
def convert_unicode_escape(s):
    # Decode literal escape sequences pasted into a key, e.g. the six
    # characters "\u0061" become "a".
    return s.encode('utf-8').decode('unicode-escape')
|
|
|
async def make_request(method: str, url: str, headers: dict, data: dict):
    try:
        response = await client.request(method=method, url=url, headers=headers, json=data)
        status = response.status_code
        if status == 200:
            return response
        elif status in (401, 403):
            raise HTTPException(status_code=401, detail="Invalid authorization info")
        else:
            raise HTTPException(status_code=status, detail=f"Network issue: {response.text}")
    except HTTPException:
        # Propagate our own status errors instead of wrapping them as 500s.
        raise
    except Exception as e:
        logger.error(f"Request error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Request error: {str(e)}")
|
|
|
def map_model(requestModel):
    model = requestModel.lower()

    # 'adv/<name>' requests bypass the mapping table and use <name> verbatim
    # (case preserved, since upstream LLM names are case-sensitive).
    if model.startswith('adv/'):
        model = requestModel[4:]
        return model if model else FALLBACK_MODEL

    # Otherwise the first mapping key found as a substring of the name wins.
    return next((value for key, value in model_mapping.items() if key in model), FALLBACK_MODEL)
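# Mapping examples:
#   map_model("gpt-4o")             -> "OPENAI_GPT4O"    (substring "4o")
#   map_model("adv/CLAUDE_V3_OPUS") -> "CLAUDE_V3_OPUS"  (verbatim pass-through)
#   map_model("does-not-exist")     -> FALLBACK_MODEL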
|
|
|
async def get_deployment_details(apikey: str) -> dict:
    # Fast path: no lock needed when the entry is already cached.
    if apikey in deployment_cache:
        return deployment_cache[apikey]

    async with cache_lock:
        # Double-check under the lock in case another task just populated it.
        if apikey in deployment_cache:
            return deployment_cache[apikey]

        headers = {
            'apiKey': apikey,
            'accept': '*/*',
        }

        response = await make_request(
            method="GET",
            url=f"https://{BASE_HOST}/api/listExternalApplications",
            headers=headers,
            data={}
        )

        result = response.json()
        logger.debug(f"List external applications result: {result}")

        if result.get("success") and result.get("result"):
            deployment_details = result["result"][0]
            deployment_cache[apikey] = deployment_details
            logger.info(f"#{deployment_details['deploymentId']} - Access granted successfully")
            return deployment_details
        else:
            raise HTTPException(status_code=500, detail="Failed to retrieve deployment info")
|
|
|
async def create_conversation(apikey: str) -> tuple:
    deployment_details = await get_deployment_details(apikey)

    payload = {
        "deploymentId": deployment_details["deploymentId"],
        "name": "New Chat",
        "externalApplicationId": deployment_details["externalApplicationId"]
    }
    try:
        headers = {
            'Content-Type': 'application/json',
            'apiKey': apikey,
            'REAI-UI': '1',
            'X-Abacus-Org-Host': 'apps'
        }
        response = await make_request(
            method="POST",
            url=f"https://{BASE_HOST}/api/createDeploymentConversation",
            headers=headers,
            data=payload
        )
        result = response.json()
        logger.debug(f"Create conversation result: {result}")

        if 'result' not in result or 'deploymentConversationId' not in result['result']:
            raise HTTPException(status_code=401, detail="Invalid Abacus apikey")

        return result["result"]["deploymentConversationId"], deployment_details["deploymentId"]
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating conversation: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error creating conversation: {str(e)}")
|
|
|
def serialize_openai_messages(messages):
    def get_content(message):
        try:
            if 'content' not in message:
                return ''
            if not isinstance(message['content'], list):
                return message['content']
            # Multi-part content: use the text of the first part.
            return message['content'][0]['text']
        except (KeyError, IndexError, TypeError):
            raise HTTPException(status_code=400, detail="Invalid request body")

    serialized_messages = [
        f"{msg['role'].capitalize()}: {get_content(msg)}"
        for msg in messages
    ]

    result = "\n\n".join(serialized_messages)

    # Cue the model to produce the next assistant turn.
    result += "\n\nAssistant: {...}\n\n"

    return result.strip()
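# Example: [{"role": "user", "content": "Hello"}] serializes to:
#   "User: Hello\n\nAssistant: {...}"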
|
|
|
CHAT_OUTPUT_PREFIX = 'data: {"id":"0","object":"0","created":0,"model":"0","choices":[{"index":0,"delta":{"content":'
CHAT_OUTPUT_SUFFIX = '}}]}\n\n'
ENDING_CHUNK = 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\ndata: [DONE]\n\n'

NS_PREFIX = '{"id":"chatcmpl-123","object":"chat.completion","created":1694268190,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"'
NS_SUFFIX = '"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},"system_fingerprint":"0"}\n\n'
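# Framing of the relayed reply: in SSE mode each WebSocket segment is wrapped
# as its own chat.completion.chunk, so the segment "Hi" goes out as
#   data: {...,"delta":{"content":"Hi"}}]}
# followed eventually by ENDING_CHUNK. In non-stream mode the JSON-escaped
# segments are concatenated between NS_PREFIX and NS_SUFFIX to form one
# chat.completion body.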
|
|
|
async def stream_chat(apikey: str, conversation_id: str, body: Any, sse_flag=True):
    model = body["model"]
    messages = body["messages"]

    request_id = str(uuid.uuid4())
    ws_url = f"wss://{BASE_HOST}/api/ws/chatLLMSendMessage?requestId={request_id}&docInfos=%5B%5D&deploymentConversationId={conversation_id}&llmName={model}&orgHost=apps"

    headers = {
        "apiKey": apikey,
        "Origin": f"https://{BASE_HOST}",
    }

    if sse_flag:
        # SSE mode: wrap each segment as an OpenAI chat.completion.chunk event.
        data_prefix, data_suffix = CHAT_OUTPUT_PREFIX, CHAT_OUTPUT_SUFFIX
        _Jd = jsonDumps
    else:
        # Non-stream mode: build one chat.completion body. Each segment is
        # JSON-escaped, then its surrounding quotes are stripped so the
        # fragments concatenate inside the NS_PREFIX content string.
        data_prefix, data_suffix = "", ""
        _Jd = lambda x: jsonDumps(x)[1:-1]
        yield NS_PREFIX

    try:
        async with websockets.connect(ws_url, extra_headers=headers) as websocket:
            serialized_msgs = serialize_openai_messages(messages)
            await websocket.send(jsonDumps({"message": serialized_msgs}))
            logger.debug(f"Sent message to WebSocket: {serialized_msgs}")

            async for response in websocket:
                logger.debug(f"Received WebSocket response: {response}")
                data = json.loads(response)
|
|
|
                if "segment" in data:
                    segment = data['segment']
                    if data['type'] == "image_url":
                        # Image segments arrive as URLs; surface them as inline Markdown.
                        segment = f"![image]({segment})\n"
                    yield data_prefix
                    yield _Jd(segment)
                    yield data_suffix
                elif data.get("end", False):
                    break

        yield (ENDING_CHUNK if sse_flag else NS_SUFFIX)
    except Exception as e:
        logger.error(f"Error in WebSocket communication: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"WebSocket error: {str(e)}")
|
|
|
async def handle_chat_completion(request: Request):
    try:
        body = await request.json()
        logger.debug(f"Received request body: {body}")

        auth_header = request.headers.get("Authorization")

        if not auth_header or not auth_header.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Invalid Authorization header")

        abacus_token = auth_header[7:]

        if not abacus_token:
            raise HTTPException(status_code=401, detail="Empty Authorization token")

        # Several keys may be supplied as "key1|key2|..."; pick one at random
        # to spread load across them.
        apikey = random.choice(abacus_token.split("|")) if "|" in abacus_token else abacus_token

        apikey = convert_unicode_escape(apikey.strip())
        logger.debug(f"Parsed apikey: {apikey}")

        conversation_id, deployment_id = await create_conversation(apikey)
        logger.debug(f"Created conversation with ID: {conversation_id}")

        # Default to streaming unless the requested model looks like a 3.5 variant.
        sse_flag = body.get("stream", "3.5" not in body.get("model", ""))

        llm_name = map_model(body.get("model", ""))
        body["model"] = llm_name
        logger.info(f'#{deployment_id} - Querying {llm_name} in {("stream" if sse_flag else "non-stream")} mode')

        return StreamingResponse(
            stream_chat(apikey, conversation_id, body, sse_flag),
            media_type=("text/event-stream" if sse_flag else "application/json") + ";charset=UTF-8"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in chat_completions: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
@app.post("/hf/v1/chat/completions")
@limiter.limit(RATE_LIMIT)
async def chat_completions(request: Request) -> StreamingResponse:
    return await handle_chat_completion(request)
|
|
|
def print_startup_info():
    console = Console()
    table = Table(title="Environment Variables & Available Models")

    table.add_column("Category", style="green")
    table.add_column("Key", style="cyan")
    table.add_column("Value", style="magenta")

    table.add_row("[bold]Environment Variables[/bold]", "", "")
    for key, value in environment_variables.items():
        table.add_row("", key, str(value))

    table.add_row("", "", "")

    table.add_row("[bold]Available Models[/bold]", "", "")
    for short_name, full_name in model_mapping.items():
        table.add_row("", short_name, full_name)

    console.print(table)
|
|
|
if __name__ == "__main__":
    # Use uvloop's faster event loop when it is installed.
    try:
        import uvloop
        uvloop.install()
    except ImportError:
        pass

    print_startup_info()

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT, access_log=False)
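# Example request against a local instance (placeholder key shown):
#   curl -N http://localhost:8000/hf/v1/chat/completions \
#     -H "Authorization: Bearer <ABACUS_API_KEY>" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'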