Create app.py
app.py
ADDED
# Models: https://github.com/abacusai/api-python/blob/main/abacusai/api_class/enums.py
model_mapping = {
    "sonnet": "CLAUDE_V3_5_SONNET",
    "4o": "OPENAI_GPT4O",
    "32": "OPENAI_GPT4_32K",
    "turbo": "OPENAI_GPT4_128K_LATEST",
    "vision": "OPENAI_GPT4_VISION",
    "3.5": "OPENAI_GPT3_5",
    "opus": "CLAUDE_V3_OPUS",
    "haiku": "CLAUDE_V3_HAIKU",
    "claude-2": "CLAUDE_V2_1",
    "pro": "GEMINI_1_5_PRO",
    "palm": "PALM",
    "llama": "LLAMA3_LARGE_CHAT",
    "_legacy_sonnet": "CLAUDE_V3_SONNET",
    "_legacy_gemini": "GEMINI_PRO",
    "_legacy_palm": "PALM_TEXT"
}

# requirements: fastapi, curl_cffi, cachetools, websockets, orjson, uvicorn, uvloop, slowapi
import os

environment_variables = {}

def set_env(var_name, default=None):
    # Record the resolved value for the startup table, then return it.
    value = os.getenv(var_name, default)
    environment_variables[var_name] = value
    return value

# Define your environment variables using the set_env function
FALLBACK_MODEL = set_env("FALLBACK_LLM", "CLAUDE_V3_5_SONNET")
RATE_LIMIT = set_env("RATE_LIMIT", "1/4 second")
LOG_LEVEL = set_env("LOG_LEVEL", "INFO")
PORT = int(set_env("PORT", "8000"))
BASE_HOST = set_env("BASE_HOST", "apps.abacus.ai")

DEPLOYMENT_CACHE_TTL = 3600 * 24  # 24 hours
IMPERSONATE_BASE = "chrome"
CURL_MAX_CLIENTS = 300

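# For illustration (not in the original file): the defaults above can be
# overridden through the environment before launch, e.g.
#   FALLBACK_LLM=CLAUDE_V3_OPUS RATE_LIMIT="2/second" PORT=7860 python app.py
# RATE_LIMIT uses slowapi's rate-string syntax ("<count>/<period>").
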
import asyncio
import uuid
import random
import logging
from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from curl_cffi import requests, CurlOpt, CurlHttpVersion

from cachetools import TTLCache
deployment_cache = TTLCache(maxsize=300, ttl=DEPLOYMENT_CACHE_TTL)
cache_lock = asyncio.Lock()

import websockets

# Prefer orjson for speed; fall back to the stdlib json module.
# Either way `json` exposes loads(), and jsonDumps() always returns a str.
try:
    import orjson as json
    jsonDumps = lambda text: json.dumps(text).decode('utf-8')  # orjson.dumps returns bytes
except ImportError:
    import json
    jsonDumps = json.dumps
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

CURL_OPTS = {
    CurlOpt.TCP_NODELAY: 1,
    CurlOpt.FORBID_REUSE: 0,
    CurlOpt.FRESH_CONNECT: 0,
    CurlOpt.TCP_KEEPALIVE: 1,
    CurlOpt.MAXAGE_CONN: 30,
}
client = requests.AsyncSession(
    impersonate=IMPERSONATE_BASE,
    default_headers=True,
    max_clients=CURL_MAX_CLIENTS,
    curl_options=CURL_OPTS,
    http_version=CurlHttpVersion.V2_PRIOR_KNOWLEDGE,
)

from rich.logging import RichHandler
from rich.console import Console
from rich.table import Table

# Setup logger with RichHandler for better logging output
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL),
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()]
)
logger = logging.getLogger(__name__)

app = FastAPI()

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

def convert_unicode_escape(s):
    return s.encode('utf-8').decode('unicode-escape')

async def make_request(method: str, url: str, headers: dict, data: dict):
    try:
        response = await client.request(method=method, url=url, headers=headers, json=data)
        status = response.status_code
        if status == 200:
            return response
        elif status in (401, 403):
            raise HTTPException(status_code=401, detail="Invalid authorization info")
        else:
            raise HTTPException(status_code=status, detail=f"Network issue: {response.text}")
    except HTTPException:
        # Let deliberate HTTP errors through instead of rewrapping them as 500s.
        raise
    except Exception as e:
        logger.error(f"Request error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Request error: {str(e)}")

def map_model(requestModel):
    model = requestModel.lower()

    # An "adv/" prefix bypasses the alias table and passes the (lowercased)
    # name through verbatim.
    if model.startswith('adv/'):
        model = model[4:]
        return model if model else FALLBACK_MODEL

    # Otherwise the first alias that occurs as a substring of the requested
    # name wins; unmatched names fall back to FALLBACK_MODEL.
    return next((value for key, value in model_mapping.items() if key in model), FALLBACK_MODEL)

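# A few illustrative resolutions (not in the original file):
#   map_model("gpt-4o-mini")           -> "OPENAI_GPT4O"      (matches the "4o" alias)
#   map_model("claude-3-opus")         -> "CLAUDE_V3_OPUS"    (matches "opus")
#   map_model("adv/LLAMA3_LARGE_CHAT") -> "llama3_large_chat" (passed through, lowercased)
#   map_model("unknown-model")         -> FALLBACK_MODEL
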
async def get_deployment_details(apikey: str) -> dict:
    if apikey in deployment_cache:
        return deployment_cache[apikey]

    async with cache_lock:
        # Re-check inside the lock in case another task populated the cache first.
        if apikey in deployment_cache:
            return deployment_cache[apikey]

        headers = {
            'apiKey': apikey,
            'accept': '*/*',
        }

        response = await make_request(
            method="GET",
            url=f"https://{BASE_HOST}/api/listExternalApplications",
            headers=headers,
            data={}
        )

        result = response.json()
        logger.debug(f"List external applications result: {result}")

        if result.get("success") and result.get("result"):
            deployment_details = result["result"][0]
            deployment_cache[apikey] = deployment_details
            logger.info(f"#{deployment_details['deploymentId']} - Access granted successfully")
            return deployment_details
        else:
            raise HTTPException(status_code=500, detail="Failed to retrieve deployment info")

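# For reference, only this subset of the listExternalApplications response is
# used above (field names inferred from the accesses in this file; the real
# payload may contain more):
#   {"success": true, "result": [{"deploymentId": "...", "externalApplicationId": "...", ...}]}
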
async def create_conversation(apikey: str) -> tuple:
    deployment_details = await get_deployment_details(apikey)

    payload = {
        "deploymentId": deployment_details["deploymentId"],
        "name": "New Chat",
        "externalApplicationId": deployment_details["externalApplicationId"]
    }
    try:
        headers = {
            'Content-Type': 'application/json',
            'apiKey': apikey,
            'REAI-UI': '1',
            'X-Abacus-Org-Host': 'apps'
        }
        response = await make_request(
            method="POST",
            url=f"https://{BASE_HOST}/api/createDeploymentConversation",
            headers=headers,
            data=payload
        )
        result = response.json()
        logger.debug(f"Create conversation result: {result}")

        if 'result' not in result or 'deploymentConversationId' not in result['result']:
            logger.error(f"Unexpected response structure: {result}")
            raise HTTPException(status_code=401, detail="Invalid Abacus apikey")

        return result["result"]["deploymentConversationId"], deployment_details["deploymentId"]
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating conversation: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error creating conversation: {str(e)}")

def serialize_openai_messages(messages):
    def get_content(message):
        try:
            # Content may be a plain string or an OpenAI-style list of parts.
            if 'content' not in message:
                return ''
            if not isinstance(message['content'], list):
                return message['content']
            return message['content'][0]['text']
        except KeyError:
            raise HTTPException(status_code=400, detail="Invalid request body")

    serialized_messages = [
        f"{msg['role'].capitalize()}: {get_content(msg)}"
        for msg in messages
    ]

    result = "\n\n".join(serialized_messages)

    # Append a final assistant cue, separated from the last message.
    result += "\n\nAssistant: {...}\n\n"

    return result.strip()

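# Example of the flattened transcript this produces (illustrative):
#   System: You are helpful.
#
#   User: Hi
#
#   Assistant: {...}
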
# Hand-rolled OpenAI-compatible response fragments: streamed chunks are built
# as CHAT_OUTPUT_PREFIX + <JSON-encoded segment> + CHAT_OUTPUT_SUFFIX, and a
# non-stream response as NS_PREFIX + <escaped content> + NS_SUFFIX.
CHAT_OUTPUT_PREFIX = 'data: {"id":"0","object":"0","created":0,"model":"0","choices":[{"index":0,"delta":{"content":'
CHAT_OUTPUT_SUFFIX = '}}]}\n\n'
ENDING_CHUNK = 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\ndata: [DONE]\n\n'

NS_PREFIX = '{"id":"chatcmpl-123","object":"chat.completion","created":1694268190,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"'
NS_SUFFIX = '"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},"system_fingerprint":"0"}\n\n'

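# An assembled stream chunk, for illustration (segment = "Hi"):
#   data: {"id":"0","object":"0","created":0,"model":"0","choices":[{"index":0,"delta":{"content":"Hi"}}]}
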
async def stream_chat(apikey: str, conversation_id: str, body: Any, sse_flag=True):
    model = body["model"]
    messages = body["messages"]

    request_id = str(uuid.uuid4())
    ws_url = f"wss://{BASE_HOST}/api/ws/chatLLMSendMessage?requestId={request_id}&docInfos=%5B%5D&deploymentConversationId={conversation_id}&llmName={model}&orgHost=apps"

    headers = {
        "apiKey": apikey,
        "Origin": f"https://{BASE_HOST}",
    }

    if sse_flag:
        data_prefix, data_suffix = CHAT_OUTPUT_PREFIX, CHAT_OUTPUT_SUFFIX
        _Jd = jsonDumps
    else:
        data_prefix, data_suffix = "", ""
        # Strip the surrounding quotes so each escaped segment can be spliced
        # into the single "content" string opened by NS_PREFIX.
        _Jd = lambda x: jsonDumps(x)[1:-1]
        yield NS_PREFIX

    try:
        async with websockets.connect(ws_url, extra_headers=headers) as websocket:
            serialized_msgs = serialize_openai_messages(messages)
            await websocket.send(jsonDumps({"message": serialized_msgs}))
            logger.debug(f"Sent message to WebSocket: {serialized_msgs}")

            async for response in websocket:
                logger.debug(f"Received WebSocket response: {response}")
                data = json.loads(response)

                if "segment" in data:
                    segment = data['segment']
                    if data['type'] == "image_url":
                        # NOTE: the markdown image syntax here was lost in the
                        # rendered diff; this is a best-guess reconstruction.
                        segment = f"![image]({segment})\n"
                    yield data_prefix
                    yield _Jd(segment)
                    yield data_suffix
                elif data.get("end", False):
                    break

            yield (ENDING_CHUNK if sse_flag else NS_SUFFIX)
    except Exception as e:
        logger.error(f"Error in WebSocket communication: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"WebSocket error: {str(e)}")

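# For reference, the loop above only relies on frames shaped like the
# following (inferred from the fields accessed; real frames may carry more):
#   {"segment": "partial text", "type": "text"}
#   {"segment": "https://...", "type": "image_url"}
#   {"end": true}
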
async def handle_chat_completion(request: Request):
    try:
        body = await request.json()
        logger.debug(f"Received request body: {body}")

        auth_header = request.headers.get("Authorization")

        if not auth_header or not auth_header.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Invalid Authorization header")

        abacus_token = auth_header[7:]  # Remove "Bearer " prefix

        if not abacus_token:
            raise HTTPException(status_code=401, detail="Empty Authorization token")

        # Several keys may be supplied as "key1|key2|..."; pick one at random.
        apikey = random.choice(abacus_token.split("|")) if "|" in abacus_token else abacus_token

        apikey = convert_unicode_escape(apikey.strip())
        logger.debug(f"Parsed apikey: {apikey}")

        conversation_id, deployment_id = await create_conversation(apikey)
        logger.debug(f"Created conversation with ID: {conversation_id}")

        # Default to streaming unless the client asked otherwise; GPT-3.5
        # requests default to non-stream mode.
        sse_flag = body.get("stream", "3.5" not in body.get("model", ""))

        llm_name = map_model(body.get("model", ""))
        body["model"] = llm_name
        logger.info(f'#{deployment_id} - Querying {llm_name} in {("stream" if sse_flag else "non-stream")} mode')

        return StreamingResponse(
            stream_chat(apikey, conversation_id, body, sse_flag),
            media_type=("text/event-stream" if sse_flag else "application/json") + ";charset=UTF-8"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in chat_completions: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/hf/v1/chat/completions")
@limiter.limit(RATE_LIMIT)
async def chat_completions(request: Request) -> StreamingResponse:
    return await handle_chat_completion(request)

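# Example request against the endpoint above, assuming a local run on the
# default port (illustrative; substitute a real Abacus API key for <APIKEY>):
#   curl http://localhost:8000/hf/v1/chat/completions \
#     -H "Authorization: Bearer <APIKEY>" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "sonnet", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'
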
def print_startup_info():
    console = Console()
    table = Table(title="Environment Variables & Available Models")

    # Set up columns
    table.add_column("Category", style="green")
    table.add_column("Key", style="cyan")
    table.add_column("Value", style="magenta")

    # Add environment variables to the table
    table.add_row("[bold]Environment Variables[/bold]", "", "")
    for key, value in environment_variables.items():
        table.add_row("", key, str(value))

    # Add a separator row between the sections
    table.add_row("", "", "")

    # Add model mapping to the table
    table.add_row("[bold]Available Models[/bold]", "", "")
    for short_name, full_name in model_mapping.items():
        table.add_row("", short_name, full_name)

    # Print the table to the console
    console.print(table)

if __name__ == "__main__":
    # Use uvloop for the event loop when available.
    try:
        import uvloop
    except ImportError:
        uvloop = None
    if uvloop:
        uvloop.install()

    print_startup_info()

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT, access_log=False)