dongsiqie committed on
Commit
948c356
·
verified ·
1 Parent(s): ab87c2d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Models: https://github.com/abacusai/api-python/blob/main/abacusai/api_class/enums.py
# Maps a short alias (substring-matched against the caller's model name) to the
# Abacus LLM enum name. NOTE: insertion order matters — map_model() returns the
# FIRST key that occurs as a substring, so e.g. "sonnet" must precede
# "_legacy_sonnet" and "3.5" must precede "turbo" for names containing both.
model_mapping = {
    "sonnet": "CLAUDE_V3_5_SONNET",
    "4o": "OPENAI_GPT4O",
    "32": "OPENAI_GPT4_32K",
    "turbo": "OPENAI_GPT4_128K_LATEST",
    "vision": "OPENAI_GPT4_VISION",
    "3.5": "OPENAI_GPT3_5",
    "opus": "CLAUDE_V3_OPUS",
    "haiku": "CLAUDE_V3_HAIKU",
    "claude-2": "CLAUDE_V2_1",
    "pro": "GEMINI_1_5_PRO",
    "palm": "PALM",
    "llama": "LLAMA3_LARGE_CHAT",
    "_legacy_sonnet": "CLAUDE_V3_SONNET",
    "_legacy_gemini": "GEMINI_PRO",
    "_legacy_palm": "PALM_TEXT"
}
19
+
20
# requirements: fastapi, curl_cffi, cachetools, websockets, orjson, uvicorn, uvloop, slowapi

import os

# Reads an env var (with default), records it in `environment_variables` for the
# startup table, and returns its value. The lambda body references
# `environment_variables`, defined on the next line — safe because the lambda is
# only *called* after that assignment runs.
set_env = lambda var_name, default=None: environment_variables.update({var_name: os.getenv(var_name, default)}) or os.getenv(var_name, default)
environment_variables = {}

# Define your environment variables using the set_env function
FALLBACK_MODEL = set_env("FALLBACK_LLM", "CLAUDE_V3_5_SONNET")
RATE_LIMIT = set_env("RATE_LIMIT", "1/4 second")
LOG_LEVEL = set_env("LOG_LEVEL", "INFO")  # must be a valid logging level name; getattr below raises otherwise
PORT = int(set_env("PORT", "8000"))
BASE_HOST = set_env("BASE_HOST", "apps.abacus.ai")

DEPLOYMENT_CACHE_TTL = 3600 * 24  # 24 hours
IMPERSONATE_BASE = "chrome"       # curl_cffi browser fingerprint to impersonate
CURL_MAX_CLIENTS = 300

import asyncio
import json
import uuid
import random
import logging
from typing import Dict, Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from curl_cffi import requests, CurlOpt, CurlHttpVersion

from cachetools import TTLCache
# apikey -> deployment details, expiring after DEPLOYMENT_CACHE_TTL.
deployment_cache = TTLCache(maxsize=300, ttl=DEPLOYMENT_CACHE_TTL)
cache_lock = asyncio.Lock()

import websockets

# Prefer orjson for speed; it shadows the stdlib `json` name. orjson.dumps
# returns bytes, hence the .decode wrapper so jsonDumps always yields str.
try:
    import orjson as json
    jsonDumps = lambda text: json.dumps(text).decode('utf-8')
except ImportError:
    import json
    jsonDumps = json.dumps
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# Keep-alive tuning for the shared curl session.
CURL_OPTS = {
    CurlOpt.TCP_NODELAY: 1, CurlOpt.FORBID_REUSE: 0, CurlOpt.FRESH_CONNECT: 0, CurlOpt.TCP_KEEPALIVE: 1, CurlOpt.MAXAGE_CONN: 30
}
# One process-wide HTTP client reused by every request handler.
client = requests.AsyncSession(
    impersonate=IMPERSONATE_BASE, default_headers=True, max_clients=CURL_MAX_CLIENTS, curl_options=CURL_OPTS, http_version=CurlHttpVersion.V2_PRIOR_KNOWLEDGE
)

from rich.logging import RichHandler
from rich.console import Console
from rich.table import Table

# Setup logger with RichHandler for better logging output
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL),
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()]
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Per-client-IP rate limiting via slowapi.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
87
+
88
def convert_unicode_escape(s):
    """Interpret literal ``\\uXXXX`` escape sequences embedded in *s*.

    Used on API keys so a caller may pass escaped characters in the
    Authorization header.
    """
    raw = s.encode('utf-8')
    return raw.decode('unicode-escape')
90
+
91
async def make_request(method: str, url: str, headers: dict, data: dict):
    """Issue an HTTP request through the shared curl session.

    Returns the response on HTTP 200. Raises HTTPException(401) on 401/403,
    HTTPException(<status>) on any other non-200, and HTTPException(500) on
    transport-level errors.
    """
    try:
        response = await client.request(method=method, url=url, headers=headers, json=data)
        status = response.status_code
        if status == 200:
            return response
        elif status in (401, 403):
            raise HTTPException(status_code=401, detail="Invalid authorization info")
        else:
            raise HTTPException(status_code=status, detail=f"Network issue: {response.text}")
    except HTTPException:
        # Bug fix: previously the blanket handler below swallowed these and
        # re-raised everything as 500, losing the intended 401/status code.
        raise
    except Exception as e:
        logger.error(f"Request error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Request error: {str(e)}")
104
+
105
def map_model(requestModel):
    """Translate a caller-supplied model name into an Abacus LLM enum name.

    A name prefixed with ``adv/`` is passed through verbatim (minus the
    prefix); otherwise the first ``model_mapping`` key found as a substring
    wins, falling back to ``FALLBACK_MODEL``.
    """
    name = requestModel.lower()

    # Advanced mode: "adv/<NAME>" forwards <NAME> as-is (empty -> fallback).
    if name.startswith('adv/'):
        passthrough = name[4:]
        if passthrough:
            return passthrough
        return FALLBACK_MODEL

    # First alias that occurs as a substring wins (dict order matters).
    for alias, enum_name in model_mapping.items():
        if alias in name:
            return enum_name
    return FALLBACK_MODEL
113
+
114
async def get_deployment_details(apikey: str) -> Dict[str, Any]:
    """Fetch (and cache) the first external application for *apikey*.

    Returns the deployment-details dict (keyed by e.g. 'deploymentId',
    'externalApplicationId'). Raises HTTPException(500) when the listing call
    succeeds but returns no usable result. Entries expire via the TTL cache.
    """
    # Fast path: lock-free cache hit.
    if apikey in deployment_cache:
        return deployment_cache[apikey]

    # Double-checked locking: re-test under the lock so concurrent coroutines
    # for the same key make at most one upstream call.
    async with cache_lock:
        if apikey in deployment_cache:
            return deployment_cache[apikey]

        headers = {
            'apiKey': apikey,
            'accept': '*/*',
        }

        response = await make_request(
            method="GET",
            url=f"https://{BASE_HOST}/api/listExternalApplications",
            headers=headers,
            data={}
        )

        result = response.json()
        logger.debug(f"List external applications result: {result}")

        if result.get("success") and result.get("result"):
            # Only the first listed application is ever used.
            deployment_details = result["result"][0]
            deployment_cache[apikey] = deployment_details
            logger.info(f"#{deployment_details['deploymentId']} - Access granted successfully")
            return deployment_details
        else:
            raise HTTPException(status_code=500, detail="Failed to retrieve deployment info")
144
+
145
async def create_conversation(apikey: str) -> tuple:
    """Create a new deployment conversation for *apikey*.

    Returns ``(deploymentConversationId, deploymentId)``. Raises
    HTTPException(401) when the upstream response lacks a conversation id
    (taken to mean a bad key), HTTPException(500) on other failures.
    """
    deployment_details = await get_deployment_details(apikey)

    payload = {
        "deploymentId": deployment_details["deploymentId"],
        "name": "New Chat",
        "externalApplicationId": deployment_details["externalApplicationId"]
    }
    try:
        headers = {
            'Content-Type': 'application/json',
            'apiKey': apikey,
            'REAI-UI': '1',
            'X-Abacus-Org-Host': 'apps'
        }
        response = await make_request(
            method="POST",
            url=f"https://{BASE_HOST}/api/createDeploymentConversation",
            headers=headers,
            data=payload
        )
        result = response.json()
        logger.debug(f"Create conversation result: {result}")

        if 'result' not in result or 'deploymentConversationId' not in result['result']:
            # Bug fix: this line previously read `l#ogger.error(...)` — a typo
            # that left a bare `l` expression and raised NameError at runtime.
            logger.error(f"Unexpected response structure: {result}")
            raise HTTPException(status_code=401, detail="Invalid Abacus apikey")

        return result["result"]["deploymentConversationId"], deployment_details["deploymentId"]
    except HTTPException:
        # Bug fix: let deliberate HTTP errors (e.g. the 401 above) propagate
        # instead of being re-wrapped as 500 by the handler below.
        raise
    except Exception as e:
        logger.error(f"Error creating conversation: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error creating conversation: {str(e)}")
177
+
178
def serialize_openai_messages(messages):
    """Flatten an OpenAI-style messages list into one prompt string.

    Each message becomes ``Role: content``; messages are joined with blank
    lines and a trailing ``Assistant: {...}`` cue is appended so the upstream
    model continues as the assistant. Raises HTTPException(400) when a
    list-shaped content entry is malformed.
    """
    def get_content(message):
        try:
            # Missing 'content' is tolerated and rendered as empty text.
            if 'content' not in message:
                return ''
            # Plain-string content is used as-is.
            if not isinstance(message['content'], list):
                return message['content']
            # OpenAI "content parts" form: use the first part's text.
            return message['content'][0]['text']
        except (KeyError, IndexError, TypeError):
            # Broadened from KeyError alone: an empty parts list or a
            # non-dict part previously escaped as an unhandled exception.
            raise HTTPException(status_code=400, detail="Invalid request body")

    serialized_messages = [
        f"{msg['role'].capitalize()}: {get_content(msg)}"
        for msg in messages
    ]

    result = "\n\n".join(serialized_messages)

    # Bug fix: the assistant cue was appended without a separator, fusing it
    # onto the last message ("User: hiAssistant: {...}").
    result += "\n\nAssistant: {...}\n\n"

    return result.strip()
200
+
201
# Pre-built fragments of OpenAI-compatible responses. stream_chat() emits
# CHAT_OUTPUT_PREFIX + <json-encoded text> + CHAT_OUTPUT_SUFFIX per segment,
# so concatenation must form a valid SSE "data:" chunk.
CHAT_OUTPUT_PREFIX = 'data: {"id":"0","object":"0","created":0,"model":"0","choices":[{"index":0,"delta":{"content":'
CHAT_OUTPUT_SUFFIX = '}}]}\n\n'
# Final SSE chunk: finish_reason "stop" followed by the [DONE] sentinel.
ENDING_CHUNK = 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\ndata: [DONE]\n\n'

# Non-streaming mode: NS_PREFIX + escaped content fragments + NS_SUFFIX form
# one chat.completion JSON object (the id/model/usage fields are placeholders).
NS_PREFIX = '{"id":"chatcmpl-123","object":"chat.completion","created":1694268190,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"'
NS_SUFFIX = '"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},"system_fingerprint":"0"}\n\n'
207
+
208
async def stream_chat(apikey: str, conversation_id: str, body: Any, sse_flag=True):
    """Relay one chat turn over the Abacus websocket, yielding response text.

    In SSE mode (sse_flag=True) each segment is wrapped in an OpenAI-style
    "data:" chunk and the stream ends with ENDING_CHUNK. Otherwise the yields
    assemble a single chat.completion JSON object (NS_PREFIX ... NS_SUFFIX)
    with each segment JSON-escaped and spliced into the "content" string.
    """
    model = body["model"]
    messages = body["messages"]

    request_id = str(uuid.uuid4())
    ws_url = f"wss://{BASE_HOST}/api/ws/chatLLMSendMessage?requestId={request_id}&docInfos=%5B%5D&deploymentConversationId={conversation_id}&llmName={model}&orgHost=apps"

    headers = {
        "apiKey": apikey,
        "Origin": f"https://{BASE_HOST}",
    }

    if sse_flag:
        data_prefix, data_suffix = CHAT_OUTPUT_PREFIX, CHAT_OUTPUT_SUFFIX
        _Jd = jsonDumps
    else:
        data_prefix, data_suffix = "", ""
        # Strip the surrounding quotes so the fragment splices into the
        # NS_PREFIX "content" string literal.
        _Jd = lambda x: jsonDumps(x)[1:-1]
        yield NS_PREFIX

    try:
        async with websockets.connect(ws_url, extra_headers=headers) as websocket:
            serialized_msgs = serialize_openai_messages(messages)
            await websocket.send(jsonDumps({"message": serialized_msgs}))
            logger.debug(f"Sent message to WebSocket: {serialized_msgs}")

            async for response in websocket:
                logger.debug(f"Received WebSocket response: {response}")
                data = json.loads(response)

                if "segment" in data:
                    segment = data['segment']
                    # Image segments are rewritten as markdown image links.
                    if data['type'] == "image_url":
                        segment = f"\n![Image]({segment})"
                    yield data_prefix
                    yield _Jd(segment)
                    yield data_suffix
                elif data.get("end", False):
                    break

        yield (ENDING_CHUNK if sse_flag else NS_SUFFIX)
    except Exception as e:
        logger.error(f"Error in WebSocket communication: {str(e)}", exc_info=True)
        # NOTE(review): raising HTTPException from a generator after chunks
        # have been yielded cannot change the HTTP status already sent —
        # it only aborts the stream; confirm this is the intended behavior.
        raise HTTPException(status_code=500, detail=f"WebSocket error: {str(e)}")
252
+
253
async def handle_chat_completion(request: Request):
    """Validate the request, resolve the API key, and start the chat stream.

    Expects a Bearer token in Authorization (optionally several keys joined by
    '|'), creates an upstream conversation, and returns a StreamingResponse in
    SSE or single-JSON mode. Raises HTTPException(401) for auth problems and
    HTTPException(500) for unexpected failures.
    """
    try:
        body = await request.json()
        logger.debug(f"Received request body: {body}")

        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Invalid Authorization header")

        abacus_token = auth_header[7:]  # Remove "Bearer " prefix
        if not abacus_token:
            raise HTTPException(status_code=401, detail="Empty Authorization token")

        # A '|'-separated token is a key pool; pick one at random.
        # (Removed dead `or [abacus_token]`: split('|') can never be falsy here.)
        apikey = random.choice(abacus_token.split("|")) if "|" in abacus_token else abacus_token

        apikey = convert_unicode_escape(apikey.strip())
        logger.debug(f"Parsed apikey: {apikey}")

        conversation_id, deployment_id = await create_conversation(apikey)
        logger.debug(f"Created conversation with ID: {conversation_id}")

        # Bug fix: read the model once with .get — the old code used
        # body["model"] here but body.get("model", "") below, so a request
        # without "model" crashed with KeyError instead of using the fallback.
        requested_model = body.get("model", "")

        # Default: stream unless a 3.5-class model was requested.
        sse_flag = body.get("stream", "3.5" not in requested_model)

        llm_name = map_model(requested_model)
        body["model"] = llm_name
        logger.info(f'#{deployment_id} - Querying {llm_name} in {("stream" if sse_flag else "non-stream")} mode')

        return StreamingResponse(
            stream_chat(apikey, conversation_id, body, sse_flag),
            media_type=("text/event-stream" if sse_flag else "application/json") + ";charset=UTF-8"
        )
    except HTTPException:
        # Bug fix: preserve deliberate statuses (401 etc.) instead of
        # re-wrapping every HTTPException as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error in chat_completions: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
290
+
291
@app.post("/hf/v1/chat/completions")
@limiter.limit(RATE_LIMIT)
async def chat_completions(request: Request) -> StreamingResponse:
    """OpenAI-compatible chat endpoint; rate-limited, delegates all work."""
    return await handle_chat_completion(request)
295
+
296
def print_startup_info():
    """Print a startup table of configured env vars and known model aliases."""
    console = Console()
    table = Table(title="Environment Variables & Available Models")

    # Three columns: section label, key, value.
    for heading, style in (("Category", "green"), ("Key", "cyan"), ("Value", "magenta")):
        table.add_column(heading, style=style)

    # Section 1: environment variables captured by set_env().
    table.add_row("[bold]Environment Variables[/bold]", "", "")
    for var_name, var_value in environment_variables.items():
        table.add_row("", var_name, str(var_value))

    # Blank row as a visual separator.
    table.add_row("", "", "")

    # Section 2: model alias -> Abacus enum name.
    table.add_row("[bold]Available Models[/bold]", "", "")
    for alias, enum_name in model_mapping.items():
        table.add_row("", alias, enum_name)

    console.print(table)
320
+
321
if __name__ == "__main__":
    # Use uvloop's faster event loop when it is installed; otherwise keep
    # asyncio's default silently.
    try:
        import uvloop
        uvloop.install()
    except ImportError:
        pass

    print_startup_info()

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT, access_log=False)