# Models: https://github.com/abacusai/api-python/blob/main/abacusai/api_class/enums.py
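# Short keys below are aliases; map_model() matches them as substrings of the
# requested model name, first match wins.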
model_mapping = {
"sonnet": "CLAUDE_V3_5_SONNET",
"4o": "OPENAI_GPT4O",
"32": "OPENAI_GPT4_32K",
"turbo": "OPENAI_GPT4_128K_LATEST",
"vision": "OPENAI_GPT4_VISION",
"3.5": "OPENAI_GPT3_5",
"opus": "CLAUDE_V3_OPUS",
"haiku": "CLAUDE_V3_HAIKU",
"claude-2": "CLAUDE_V2_1",
"pro": "GEMINI_1_5_PRO",
"palm": "PALM",
"llama": "LLAMA3_LARGE_CHAT",
"_legacy_sonnet": "CLAUDE_V3_SONNET",
"_legacy_gemini": "GEMINI_PRO",
"_legacy_palm": "PALM_TEXT",
"gpt-4o": "OPENAI_GPT4O",
"gpt-4-turbo": "OPENAI_GPT4_128K_LATEST",
"claude-3-opus": "CLAUDE_V3_OPUS"
}
# requirements: fastapi, curl_cffi, cachetools, websockets, orjson, uvicorn, uvloop, slowapi, rich
import os
environment_variables = {}
def set_env(var_name, default=None):
    # Read an env var, record it for the startup table, and return its value.
    value = os.getenv(var_name, default)
    environment_variables[var_name] = value
    return value
# Define your environment variables using the set_env function
FALLBACK_MODEL = set_env("FALLBACK_LLM", "CLAUDE_V3_5_SONNET")
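# slowapi/limits notation: "1/4 second" means 1 request per 4 seconds, keyed by client IP.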
RATE_LIMIT = set_env("RATE_LIMIT", "1/4 second")
LOG_LEVEL = set_env("LOG_LEVEL", "INFO")
PORT = int(set_env("PORT", "8000"))
BASE_HOST = set_env("BASE_HOST", "apps.abacus.ai")
DEPLOYMENT_CACHE_TTL = 3600 * 24 # 24 hours
IMPERSONATE_BASE = "chrome"
CURL_MAX_CLIENTS = 300
import asyncio
import json
import uuid
import random
import logging
from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from curl_cffi import requests, CurlOpt, CurlHttpVersion
from cachetools import TTLCache
deployment_cache = TTLCache(maxsize=300, ttl=DEPLOYMENT_CACHE_TTL)
cache_lock = asyncio.Lock()
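# cache_lock guards deployment_cache: the first lookup is lock-free, and the
# re-check under the lock (double-checked locking, see get_deployment_details)
# keeps concurrent cold requests for the same key from hitting the API twice.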
import websockets
try:
    # Deliberately shadow the stdlib json imported above so json.loads() below
    # uses orjson when available; orjson serializes to bytes, hence the decode.
    import orjson as json
    jsonDumps = lambda text: json.dumps(text).decode('utf-8')
except ImportError:
    import json
    jsonDumps = json.dumps
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
CURL_OPTS = {
CurlOpt.TCP_NODELAY: 1, CurlOpt.FORBID_REUSE: 0, CurlOpt.FRESH_CONNECT: 0, CurlOpt.TCP_KEEPALIVE: 1, CurlOpt.MAXAGE_CONN: 30
}
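# One shared session: curl keeps sockets warm (TCP keepalive, 30 s connection
# cache) and speaks HTTP/2 with prior knowledge behind a Chrome TLS fingerprint.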
client = requests.AsyncSession(
impersonate=IMPERSONATE_BASE, default_headers=True, max_clients=CURL_MAX_CLIENTS, curl_options=CURL_OPTS, http_version=CurlHttpVersion.V2_PRIOR_KNOWLEDGE
)
from rich.logging import RichHandler
from rich.console import Console
from rich.table import Table
# Setup logger with RichHandler for better logging output
logging.basicConfig(
level=getattr(logging, LOG_LEVEL),
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler()]
)
logger = logging.getLogger(__name__)
app = FastAPI()
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
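# Decode literal escape sequences (e.g. "\\u0041") that survive copy-pasting a
# key. Caveat: the utf-8 -> unicode-escape round trip mangles genuine non-ASCII
# input, so this assumes keys are effectively ASCII.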
def convert_unicode_escape(s):
return s.encode('utf-8').decode('unicode-escape')
async def make_request(method: str, url: str, headers: dict, data: dict):
try:
response = await client.request(method=method, url=url, headers=headers, json=data)
status = response.status_code
if status == 200:
return response
elif status in (401, 403):
raise HTTPException(status_code=401, detail="Invalid authorization info")
else:
raise HTTPException(status_code=status, detail=f"Network issue: {response.text}")
    except HTTPException:
        # Let deliberate HTTP errors propagate instead of being rewrapped as 500s.
        raise
    except Exception as e:
logger.error(f"Request error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Request error: {str(e)}")
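# Examples: "gpt-4o" -> "OPENAI_GPT4O" (substring match); an "adv/" prefix
# strips and passes the (lowercased) remainder through unmapped, for models
# not listed in model_mapping.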
def map_model(requestModel):
model = requestModel.lower()
if model.startswith('adv/'):
model = model[4:]
return model if model else FALLBACK_MODEL
return next((value for key, value in model_mapping.items() if key in model), FALLBACK_MODEL)
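# Resolve and cache the first external application registered for this API key;
# the returned dict carries deploymentId and externalApplicationId.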
async def get_deployment_details(apikey: str) -> dict:
if apikey in deployment_cache:
return deployment_cache[apikey]
async with cache_lock:
if apikey in deployment_cache:
return deployment_cache[apikey]
headers = {
'apiKey': apikey,
'accept': '*/*',
}
response = await make_request(
method="GET",
url=f"https://{BASE_HOST}/api/listExternalApplications",
headers=headers,
data={}
)
result = response.json()
logger.debug(f"List external applications result: {result}")
if result.get("success") and result.get("result"):
deployment_details = result["result"][0]
deployment_cache[apikey] = deployment_details
logger.info(f"#{deployment_details['deploymentId']} - Access granted successfully")
return deployment_details
else:
raise HTTPException(status_code=500, detail="Failed to retrieve deployment info")
async def create_conversation(apikey: str) -> tuple:
deployment_details = await get_deployment_details(apikey)
payload = {
"deploymentId": deployment_details["deploymentId"],
"name": "New Chat",
"externalApplicationId": deployment_details["externalApplicationId"]
}
try:
headers = {
'Content-Type': 'application/json',
'apiKey': apikey,
'REAI-UI': '1',
'X-Abacus-Org-Host': 'apps'
}
response = await make_request(
method="POST",
url=f"https://{BASE_HOST}/api/createDeploymentConversation",
headers=headers,
data=payload
)
result = response.json()
logger.debug(f"Create conversation result: {result}")
if 'result' not in result or 'deploymentConversationId' not in result['result']:
            logger.error(f"Unexpected response structure: {result}")
raise HTTPException(status_code=401, detail="Invalid Abacus apikey")
return result["result"]["deploymentConversationId"], deployment_details["deploymentId"]
    except HTTPException:
        raise
    except Exception as e:
logger.error(f"Error creating conversation: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error creating conversation: {str(e)}")
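# The Abacus websocket takes a single message string, so OpenAI-style messages
# are flattened into a "Role: content" transcript. For multimodal content
# lists, only the first text part is kept.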
def serialize_openai_messages(messages):
def get_content(message):
try:
# Check if the 'content' key exists in message
if 'content' not in message:
return ''
if not isinstance(message['content'], list):
return message['content']
return message['content'][0]['text']
        except (KeyError, IndexError, TypeError):
            raise HTTPException(status_code=400, detail="Invalid request body")
serialized_messages = [
f"{msg['role'].capitalize()}: {get_content(msg)}"
for msg in messages
]
result = "\n\n".join(serialized_messages)
result += "Assistant: {...}\n\n"
return result.strip()
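# Hand-rolled OpenAI-compatible framing. Stream mode wraps each segment in a
# chat.completion.chunk SSE event; non-stream mode splices raw segments into a
# single chat.completion JSON body between NS_PREFIX and NS_SUFFIX.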
CHAT_OUTPUT_PREFIX = 'data: {"id":"0","object":"0","created":0,"model":"0","choices":[{"index":0,"delta":{"content":'
CHAT_OUTPUT_SUFFIX = '}}]}\n\n'
ENDING_CHUNK = 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\ndata: [DONE]\n\n'
NS_PREFIX = '{"id":"chatcmpl-123","object":"chat.completion","created":1694268190,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"'
NS_SUFFIX = '"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},"system_fingerprint":"0"}\n\n'
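# Relay loop: send the flattened transcript over the chatLLMSendMessage
# websocket, forward each "segment" frame as it arrives (image segments become
# markdown images), and stop on the "end" frame.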
async def stream_chat(apikey: str, conversation_id: str, body: Any, sse_flag=True):
model = body["model"]
messages = body["messages"]
request_id = str(uuid.uuid4())
ws_url = f"wss://{BASE_HOST}/api/ws/chatLLMSendMessage?requestId={request_id}&docInfos=%5B%5D&deploymentConversationId={conversation_id}&llmName={model}&orgHost=apps"
headers = {
"apiKey": apikey,
"Origin": f"https://{BASE_HOST}",
}
if sse_flag:
data_prefix, data_suffix = CHAT_OUTPUT_PREFIX, CHAT_OUTPUT_SUFFIX
_Jd = jsonDumps
    else:
        data_prefix, data_suffix = "", ""
        # Strip the surrounding quotes from each JSON-encoded segment so the
        # pieces concatenate into one "content" string inside NS_PREFIX/NS_SUFFIX.
        _Jd = lambda x: jsonDumps(x)[1:-1]
        yield NS_PREFIX
try:
async with websockets.connect(ws_url, extra_headers=headers) as websocket:
serialized_msgs = serialize_openai_messages(messages)
await websocket.send(jsonDumps({"message": serialized_msgs}))
logger.debug(f"Sent message to WebSocket: {serialized_msgs}")
async for response in websocket:
logger.debug(f"Received WebSocket response: {response}")
data = json.loads(response)
if "segment" in data:
segment = data['segment']
if data['type'] == "image_url":
segment = f"\n![Image]({segment})"
yield data_prefix
yield _Jd(segment)
yield data_suffix
elif data.get("end", False):
break
yield (ENDING_CHUNK if sse_flag else NS_SUFFIX)
    except Exception as e:
        # By this point the response may already be streaming, so the client
        # sees a broken stream rather than a clean HTTP error.
        logger.error(f"Error in WebSocket communication: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"WebSocket error: {str(e)}")
async def handle_chat_completion(request: Request):
try:
body = await request.json()
logger.debug(f"Received request body: {body}")
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid Authorization header")
abacus_token = auth_header[7:] # Remove "Bearer " prefix
if not abacus_token:
raise HTTPException(status_code=401, detail="Empty Authorization token")
        # The Bearer token may hold several keys separated by "|"; pick one at random.
        apikey = random.choice(abacus_token.split("|")) if "|" in abacus_token else abacus_token
apikey = convert_unicode_escape(apikey.strip())
logger.debug(f"Parsed apikey: {apikey}")
conversation_id, deployment_id = await create_conversation(apikey)
logger.debug(f"Created conversation with ID: {conversation_id}")
        sse_flag = body.get("stream", "3.5" not in body.get("model", ""))
llm_name = map_model(body.get("model", ""))
body["model"] = llm_name
logger.info(f'#{deployment_id} - Querying {llm_name} in {("stream" if sse_flag else "non-stream")} mode')
return StreamingResponse(stream_chat(apikey, conversation_id, body, sse_flag),
media_type=("text/event-stream" if sse_flag else "application/json") + \
";charset=UTF-8")
    except HTTPException:
        raise
    except Exception as e:
logger.error(f"Error in chat_completions: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post("/hf/v1/chat/completions")
@limiter.limit(RATE_LIMIT)
async def chat_completions(request: Request) -> StreamingResponse:
return await handle_chat_completion(request)
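# Example call (a sketch; host, port and the key are placeholders):
#   curl -N http://localhost:8000/hf/v1/chat/completions \
#     -H "Authorization: Bearer YOUR_ABACUS_APIKEY" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "sonnet", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'
# Several keys may be supplied as "key1|key2"; one is chosen at random per request.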
def print_startup_info():
console = Console()
table = Table(title="Environment Variables & Available Models")
# Set up columns
table.add_column("Category", style="green")
table.add_column("Key", style="cyan")
table.add_column("Value", style="magenta")
# Add environment variables to the table
table.add_row("[bold]Environment Variables[/bold]", "", "")
for key, value in environment_variables.items():
table.add_row("", key, str(value))
# Add a separator row between the sections
table.add_row("", "", "")
# Add model mapping to the table
table.add_row("[bold]Available Models[/bold]", "", "")
for short_name, full_name in model_mapping.items():
table.add_row("", short_name, full_name)
# Print the table to the console
console.print(table)
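# uvloop is optional; fall back to asyncio's default event loop when it is not installed.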
if __name__ == "__main__":
try:
import uvloop
except ImportError:
uvloop = None
if uvloop:
uvloop.install()
print_startup_info()
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=PORT, access_log=False)