ParthSadaria commited on
Commit
3109050
·
verified ·
1 Parent(s): a68045e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +16 -18
main.py CHANGED
@@ -4,18 +4,16 @@ from fastapi import FastAPI, HTTPException, Request
4
  from fastapi.responses import StreamingResponse, HTMLResponse
5
  from pydantic import BaseModel
6
  import httpx
 
7
 
8
- # Load environment variables from .env file
9
  load_dotenv()
10
 
11
  app = FastAPI()
12
 
13
- # Get API keys and secret endpoint from environment variables
14
  api_keys_str = os.getenv('API_KEYS')
15
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
16
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
17
 
18
- # Check if the endpoint is set in the environment
19
  if not secret_api_endpoint:
20
  raise HTTPException(status_code=500, detail="API endpoint is not configured in environment variables.")
21
 
@@ -50,8 +48,6 @@ async def root():
50
  """
51
  return HTMLResponse(content=html_content)
52
 
53
- # Remove cache from get_models
54
- @app.get("/v1/models")
55
  async def get_models():
56
  async with httpx.AsyncClient() as client:
57
  try:
@@ -61,33 +57,36 @@ async def get_models():
61
  except httpx.RequestError as e:
62
  raise HTTPException(status_code=500, detail=f"Request failed: {e}")
63
 
 
 
 
 
64
  @app.post("/v1/chat/completions")
65
  async def get_completion(payload: Payload, request: Request):
66
  api_key = request.headers.get("Authorization")
67
 
68
- # Validate API key
69
  if api_key not in valid_api_keys:
70
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
71
 
72
- # Prepare the payload for streaming
73
  payload_dict = {**payload.dict(), "stream": True}
74
 
75
- # Define an asynchronous generator to stream the response line by line
76
- async def stream_generator():
77
  async with httpx.AsyncClient() as client:
78
  try:
79
- async with client.stream("POST", secret_api_endpoint, json=payload_dict, timeout=10) as response:
80
  response.raise_for_status()
81
  async for line in response.aiter_lines():
82
  if line:
83
- yield f"{line}\n" # Add a newline to distinguish each line
84
- except httpx.RequestError as e:
85
- raise HTTPException(status_code=500, detail=f"Streaming failed: {e}")
86
-
87
- # Return the streaming response
88
- return StreamingResponse(stream_generator(), media_type="application/json")
 
 
 
89
 
90
- # Log the API endpoints
91
  @app.on_event("startup")
92
  async def startup_event():
93
  print("API endpoints:")
@@ -95,7 +94,6 @@ async def startup_event():
95
  print("GET /models")
96
  print("POST /v1/chat/completions")
97
 
98
- # Run the server with Uvicorn using the 'main' module
99
  if __name__ == "__main__":
100
  import uvicorn
101
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
4
  from fastapi.responses import StreamingResponse, HTMLResponse
5
  from pydantic import BaseModel
6
  import httpx
7
+ from functools import lru_cache
8
 
 
9
  load_dotenv()
10
 
11
  app = FastAPI()
12
 
 
13
  api_keys_str = os.getenv('API_KEYS')
14
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
15
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
16
 
 
17
  if not secret_api_endpoint:
18
  raise HTTPException(status_code=500, detail="API endpoint is not configured in environment variables.")
19
 
 
48
  """
49
  return HTMLResponse(content=html_content)
50
 
 
 
51
  async def get_models():
52
  async with httpx.AsyncClient() as client:
53
  try:
 
57
  except httpx.RequestError as e:
58
  raise HTTPException(status_code=500, detail=f"Request failed: {e}")
59
 
60
+ @app.get("/v1/models")
61
+ async def fetch_models():
62
+ return await get_models()
63
+
64
  @app.post("/v1/chat/completions")
65
  async def get_completion(payload: Payload, request: Request):
66
  api_key = request.headers.get("Authorization")
67
 
 
68
  if api_key not in valid_api_keys:
69
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
70
 
 
71
  payload_dict = {**payload.dict(), "stream": True}
72
 
73
+ async def stream_generator(payload_dict):
 
74
  async with httpx.AsyncClient() as client:
75
  try:
76
+ async with client.stream("POST", f"{secret_api_endpoint}/v1/chat/completions", json=payload_dict, timeout=10) as response:
77
  response.raise_for_status()
78
  async for line in response.aiter_lines():
79
  if line:
80
+ yield f"{line}\n"
81
+ except httpx.HTTPStatusError as status_err:
82
+ raise HTTPException(status_code=status_err.response.status_code, detail=f"HTTP error: {status_err}")
83
+ except httpx.RequestError as req_err:
84
+ raise HTTPException(status_code=500, detail=f"Streaming failed: {req_err}")
85
+ except Exception as e:
86
+ raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
87
+
88
+ return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
89
 
 
90
  @app.on_event("startup")
91
  async def startup_event():
92
  print("API endpoints:")
 
94
  print("GET /models")
95
  print("POST /v1/chat/completions")
96
 
 
97
  if __name__ == "__main__":
98
  import uvicorn
99
  uvicorn.run(app, host="0.0.0.0", port=8000)