# Models: https://github.com/abacusai/api-python/blob/main/abacusai/api_class/enums.py
model_mapping = {
    "sonnet": "CLAUDE_V3_5_SONNET",
    "4o": "OPENAI_GPT4O",
    "32": "OPENAI_GPT4_32K",
    "turbo": "OPENAI_GPT4_128K_LATEST",
    "vision": "OPENAI_GPT4_VISION",
    "3.5": "OPENAI_GPT3_5",
    "opus": "CLAUDE_V3_OPUS",
    "haiku": "CLAUDE_V3_HAIKU",
    "claude-2": "CLAUDE_V2_1",
    "pro": "GEMINI_1_5_PRO",
    "palm": "PALM",
    "llama": "LLAMA3_LARGE_CHAT",
    "_legacy_sonnet": "CLAUDE_V3_SONNET",
    "_legacy_gemini": "GEMINI_PRO",
    "_legacy_palm": "PALM_TEXT",
    "gpt-4o": "OPENAI_GPT4O",
    "gpt-4-turbo": "OPENAI_GPT4_128K_LATEST",
    "claude-3-opus": "CLAUDE_V3_OPUS",
}

# requirements: fastapi, curl_cffi, cachetools, websockets, orjson, uvicorn, uvloop, slowapi

import os

environment_variables = {}

def set_env(var_name, default=None):
    """Read an environment variable, record it for the startup table, and return it."""
    value = os.getenv(var_name, default)
    environment_variables[var_name] = value
    return value

# Define your environment variables using the set_env function
FALLBACK_MODEL = set_env("FALLBACK_LLM", "CLAUDE_V3_5_SONNET")
RATE_LIMIT = set_env("RATE_LIMIT", "1/4 second")
LOG_LEVEL = set_env("LOG_LEVEL", "INFO")
PORT = int(set_env("PORT", "8000"))
BASE_HOST = set_env("BASE_HOST", "apps.abacus.ai")

DEPLOYMENT_CACHE_TTL = 3600 * 24  # 24 hours
IMPERSONATE_BASE = "chrome"
CURL_MAX_CLIENTS = 300

import asyncio
import uuid
import random
import logging
from typing import Any

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from curl_cffi import requests, CurlOpt, CurlHttpVersion
from cachetools import TTLCache
import websockets

deployment_cache = TTLCache(maxsize=300, ttl=DEPLOYMENT_CACHE_TTL)
cache_lock = asyncio.Lock()

# Prefer orjson for speed and fall back to the stdlib; either way json.loads()
# is available and jsonDumps() returns a str.
try:
    import orjson as json
    jsonDumps = lambda obj: json.dumps(obj).decode("utf-8")
except ImportError:
    import json
    jsonDumps = json.dumps

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

CURL_OPTS = {
    CurlOpt.TCP_NODELAY: 1,
    CurlOpt.FORBID_REUSE: 0,
    CurlOpt.FRESH_CONNECT: 0,
    CurlOpt.TCP_KEEPALIVE: 1,
    CurlOpt.MAXAGE_CONN: 30,
}

client = requests.AsyncSession(
    impersonate=IMPERSONATE_BASE,
    default_headers=True,
    max_clients=CURL_MAX_CLIENTS,
    curl_options=CURL_OPTS,
    http_version=CurlHttpVersion.V2_PRIOR_KNOWLEDGE,
)

from rich.logging import RichHandler
from rich.console import Console
from rich.table import Table

# Setup logger with RichHandler for better logging output
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL),
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()],
)
logger = logging.getLogger(__name__)

app = FastAPI()
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

def convert_unicode_escape(s):
    # Decode literal escape sequences (e.g. "\\uXXXX") that may appear in pasted keys.
    return s.encode("utf-8").decode("unicode-escape")

async def make_request(method: str, url: str, headers: dict, data: dict):
    try:
        response = await client.request(method=method, url=url, headers=headers, json=data)
        status = response.status_code
        if status == 200:
            return response
        elif status in (401, 403):
            raise HTTPException(status_code=401, detail="Invalid authorization info")
        else:
            raise HTTPException(status_code=status, detail=f"Network issue: {response.text}")
    except HTTPException:
        # Preserve the status codes raised above instead of wrapping them in a 500.
        raise
    except Exception as e:
        logger.error(f"Request error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Request error: {str(e)}")

def map_model(requestModel):
    model = requestModel.lower()
    if model.startswith("adv/"):
        # "adv/" bypasses the alias table: the remainder is passed through as a raw LLM name.
        model = model[4:]
        return model if model else FALLBACK_MODEL
    return next((value for key, value in model_mapping.items() if key in model), FALLBACK_MODEL)
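# Illustrative calls (the first substring hit in dict order wins, so the order of
# model_mapping above matters; fallback values assume the default FALLBACK_LLM):
#   map_model("gpt-4o")        -> "OPENAI_GPT4O"        (matched via the "4o" key)
#   map_model("my-haiku-bot")  -> "CLAUDE_V3_HAIKU"     (substring match, not exact)
#   map_model("adv/some_name") -> "some_name"           (raw pass-through, lowercased)
#   map_model("unknown")       -> "CLAUDE_V3_5_SONNET"  (no key matches; fallback)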
async def get_deployment_details(apikey: str) -> dict:
    # Fast path: cached entries skip the lock entirely.
    if apikey in deployment_cache:
        return deployment_cache[apikey]
    async with cache_lock:
        # Re-check under the lock in case another task filled the cache first.
        if apikey in deployment_cache:
            return deployment_cache[apikey]
        headers = {
            'apiKey': apikey,
            'accept': '*/*',
        }
        response = await make_request(
            method="GET",
            url=f"https://{BASE_HOST}/api/listExternalApplications",
            headers=headers,
            data={}
        )
        result = response.json()
        logger.debug(f"List external applications result: {result}")
        if result.get("success") and result.get("result"):
            deployment_details = result["result"][0]
            deployment_cache[apikey] = deployment_details
            logger.info(f"#{deployment_details['deploymentId']} - Access granted successfully")
            return deployment_details
        else:
            raise HTTPException(status_code=500, detail="Failed to retrieve deployment info")

async def create_conversation(apikey: str) -> tuple:
    deployment_details = await get_deployment_details(apikey)
    payload = {
        "deploymentId": deployment_details["deploymentId"],
        "name": "New Chat",
        "externalApplicationId": deployment_details["externalApplicationId"]
    }
    try:
        headers = {
            'Content-Type': 'application/json',
            'apiKey': apikey,
            'REAI-UI': '1',
            'X-Abacus-Org-Host': 'apps'
        }
        response = await make_request(
            method="POST",
            url=f"https://{BASE_HOST}/api/createDeploymentConversation",
            headers=headers,
            data=payload
        )
        result = response.json()
        logger.debug(f"Create conversation result: {result}")
        if 'result' not in result or 'deploymentConversationId' not in result['result']:
            logger.error(f"Unexpected response structure: {result}")
            raise HTTPException(status_code=401, detail="Invalid Abacus apikey")
        return result["result"]["deploymentConversationId"], deployment_details["deploymentId"]
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating conversation: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error creating conversation: {str(e)}")

def serialize_openai_messages(messages):
    def get_content(message):
        try:
            # Messages may carry a plain string or a list of content parts.
            if 'content' not in message:
                return ''
            if not isinstance(message['content'], list):
                return message['content']
            return message['content'][0]['text']
        except KeyError:
            raise HTTPException(status_code=400, detail="Invalid request body")

    serialized_messages = [
        f"{msg['role'].capitalize()}: {get_content(msg)}"
        for msg in messages
    ]
    result = "\n\n".join(serialized_messages)
    result += "\n\nAssistant: {...}"
    return result.strip()

CHAT_OUTPUT_PREFIX = 'data: {"id":"0","object":"0","created":0,"model":"0","choices":[{"index":0,"delta":{"content":'
CHAT_OUTPUT_SUFFIX = '}}]}\n\n'
ENDING_CHUNK = (
    'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,'
    '"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\n'
    'data: [DONE]\n\n'
)
NS_PREFIX = (
    '{"id":"chatcmpl-123","object":"chat.completion","created":1694268190,"model":"gpt-4",'
    '"choices":[{"index":0,"message":{"role":"assistant","content":"'
)
NS_SUFFIX = (
    '"},"logprobs":null,"finish_reason":"stop"}],'
    '"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},"system_fingerprint":"0"}\n\n'
)
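# For reference, with the framing above a single streamed segment "Hi" becomes
# (SSE mode):
#   data: {"id":"0","object":"0","created":0,"model":"0","choices":[{"index":0,"delta":{"content":"Hi"}}]}
# and every stream is closed by ENDING_CHUNK (a finish_reason "stop" chunk plus
# "data: [DONE]"). In non-stream mode the segments are concatenated between
# NS_PREFIX and NS_SUFFIX to form one chat.completion JSON body.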
async def stream_chat(apikey: str, conversation_id: str, body: Any, sse_flag=True):
    model = body["model"]
    messages = body["messages"]
    request_id = str(uuid.uuid4())
    ws_url = (
        f"wss://{BASE_HOST}/api/ws/chatLLMSendMessage"
        f"?requestId={request_id}&docInfos=%5B%5D"
        f"&deploymentConversationId={conversation_id}&llmName={model}&orgHost=apps"
    )
    headers = {
        "apiKey": apikey,
        "Origin": f"https://{BASE_HOST}",
    }
    if sse_flag:
        data_prefix, data_suffix = CHAT_OUTPUT_PREFIX, CHAT_OUTPUT_SUFFIX
        _Jd = jsonDumps
    else:
        data_prefix, data_suffix = "", ""
        # Strip the surrounding quotes so the escaped segments concatenate into
        # a single JSON string between NS_PREFIX and NS_SUFFIX.
        _Jd = lambda x: jsonDumps(x)[1:-1]
        yield NS_PREFIX
    try:
        # Note: extra_headers targets the legacy websockets client
        # (renamed additional_headers in websockets >= 14).
        async with websockets.connect(ws_url, extra_headers=headers) as websocket:
            serialized_msgs = serialize_openai_messages(messages)
            await websocket.send(jsonDumps({"message": serialized_msgs}))
            logger.debug(f"Sent message to WebSocket: {serialized_msgs}")
            async for response in websocket:
                logger.debug(f"Received WebSocket response: {response}")
                data = json.loads(response)
                if "segment" in data:
                    segment = data['segment']
                    if data['type'] == "image_url":
                        segment = f"\n![Image]({segment})"
                    yield data_prefix
                    yield _Jd(segment)
                    yield data_suffix
                elif data.get("end", False):
                    break
        yield (ENDING_CHUNK if sse_flag else NS_SUFFIX)
    except Exception as e:
        logger.error(f"Error in WebSocket communication: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"WebSocket error: {str(e)}")

async def handle_chat_completion(request: Request):
    try:
        body = await request.json()
        logger.debug(f"Received request body: {body}")

        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Invalid Authorization header")
        abacus_token = auth_header[7:]  # Remove "Bearer " prefix
        if not abacus_token:
            raise HTTPException(status_code=401, detail="Empty Authorization token")

        # Several keys may be joined with "|"; pick one at random per request.
        apikey = random.choice(abacus_token.split("|")) if "|" in abacus_token else abacus_token
        apikey = convert_unicode_escape(apikey.strip())
        logger.debug(f"Parsed apikey: {apikey}")

        conversation_id, deployment_id = await create_conversation(apikey)
        logger.debug(f"Created conversation with ID: {conversation_id}")

        # Default to streaming unless the requested model contains "3.5",
        # which defaults to non-stream; an explicit "stream" field wins.
        sse_flag = body.get("stream", "3.5" not in body.get("model", ""))
        llm_name = map_model(body.get("model", ""))
        body["model"] = llm_name
        logger.info(f'#{deployment_id} - Querying {llm_name} in {"stream" if sse_flag else "non-stream"} mode')

        return StreamingResponse(
            stream_chat(apikey, conversation_id, body, sse_flag),
            media_type=("text/event-stream" if sse_flag else "application/json") + ";charset=UTF-8"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in chat_completions: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/hf/v1/chat/completions")
@limiter.limit(RATE_LIMIT)
async def chat_completions(request: Request) -> StreamingResponse:
    return await handle_chat_completion(request)

def print_startup_info():
    console = Console()
    table = Table(title="Environment Variables & Available Models")

    # Set up columns
    table.add_column("Category", style="green")
    table.add_column("Key", style="cyan")
    table.add_column("Value", style="magenta")

    # Add environment variables to the table
    table.add_row("[bold]Environment Variables[/bold]", "", "")
    for key, value in environment_variables.items():
        table.add_row("", key, str(value))

    # Add a separator row between the sections
    table.add_row("", "", "")

    # Add model mapping to the table
    table.add_row("[bold]Available Models[/bold]", "", "")
    for short_name, full_name in model_mapping.items():
        table.add_row("", short_name, full_name)

    # Print the table to the console
    console.print(table)
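# Example request against a locally running instance (placeholder key; several
# keys may be joined with "|" to rotate between them):
#
#   curl -N http://localhost:8000/hf/v1/chat/completions \
#     -H "Authorization: Bearer <ABACUS_API_KEY>" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "sonnet", "stream": true, "messages": [{"role": "user", "content": "Hello"}]}'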
if __name__ == "__main__":
    # uvloop is optional; fall back to asyncio's default event loop if it is missing.
    try:
        import uvloop
    except ImportError:
        uvloop = None
    if uvloop:
        uvloop.install()
    print_startup_info()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT, access_log=False)
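# Because the endpoint mimics OpenAI's chat-completions API, any OpenAI-compatible
# client should work. A minimal sketch using the official `openai` package (not a
# dependency of this file; the base_url and key below are placeholders):
#
#   from openai import OpenAI
#   oai = OpenAI(base_url="http://localhost:8000/hf/v1", api_key="<ABACUS_API_KEY>")
#   stream = oai.chat.completions.create(
#       model="sonnet",
#       messages=[{"role": "user", "content": "Hello"}],
#       stream=True,
#   )
#   for chunk in stream:
#       print(chunk.choices[0].delta.content or "", end="")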