Spaces:
Running
Running
import datetime
import html
import json
import mimetypes
import os
import re
from pathlib import Path  # Import Path from pathlib
from typing import Iterator, Optional
from urllib.parse import quote

import cloudscraper
import httpx
import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
from pydantic import BaseModel

from usage_tracker import UsageTracker
# Process-wide singletons: the request statistics recorder and the ASGI app.
usage_tracker = UsageTracker()
load_dotenv()  # Load variables from a local .env file into the environment
app = FastAPI()

# --- Configuration read from environment variables ---------------------------
api_keys_str = os.getenv('API_KEYS')  # deprecated
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3')  # searchgpt endpoint
image_endpoint = os.getenv("IMAGE_ENDPOINT")
ENDPOINT_ORIGIN = os.getenv('ENDPOINT_ORIGIN')

# All three upstream endpoints are mandatory; refuse to start without them.
# NOTE(review): this raises HTTPException at import time, outside any request;
# a plain RuntimeError would be the more conventional choice - confirm before changing.
if not all((secret_api_endpoint, secret_api_endpoint_2, secret_api_endpoint_3)):
    raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")

# Models routed to secret_api_endpoint_2 instead of secret_api_endpoint.
alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}

# Filled in by startup_event() from models.json.
available_model_ids = []
class Payload(BaseModel):
    """Request body accepted by the chat-completions handler.

    The validated payload is forwarded to the upstream endpoint as JSON.
    """

    # Requested model id; checked against available_model_ids by the handler.
    model: str
    # Conversation history, forwarded upstream unchanged.
    messages: list
    # Whether the upstream should stream its answer. Defaults to False so
    # clients that omit the field are not rejected with a validation error.
    stream: bool = False
async def favicon():
    """Return the favicon.ico file located next to this module."""
    icon_file = Path(__file__).with_name("favicon.ico")
    return FileResponse(icon_file, media_type="image/x-icon")
def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool = True) -> Iterator[str]:
    """Query the searchgpt upstream and yield its answer.

    Posts a two-message history (system + user) to ``secret_api_endpoint_3``
    and parses the ``data: `` lines of the streamed reply.

    Args:
        query: The user's question.
        systemprompt: Optional system message; defaults to
            "Be Helpful and Friendly" when empty or ``None``.
        stream: When true, yield one ``data: {...}\\n\\n`` line per content
            chunk. When false, yield a single string holding the
            concatenation of all content chunks.

    Yields:
        Either server-sent-event lines or one plain string, depending on
        ``stream``.
    """
    headers = {"User-Agent": ""}
    system_message = systemprompt or "Be Helpful and Friendly"

    # System message first, then the user query.
    prompt = [
        {"content": system_message, "role": "system"},
        {"role": "user", "content": query},
    ]
    payload = {
        "is_vscode_extension": True,
        "message_history": prompt,
        "requested_model": "searchgpt",
        "user_input": prompt[-1]["content"],
    }

    collected = []
    # The context manager releases the connection even if the consumer stops
    # iterating early or an error is raised mid-stream.
    with requests.post(secret_api_endpoint_3, headers=headers, json=payload, stream=True) as response:
        # Event streams are UTF-8; without this, requests falls back to
        # ISO-8859-1 for text/* responses that carry no charset and garbles
        # non-ASCII characters.
        response.encoding = "utf-8"
        for value in response.iter_lines(decode_unicode=True):
            if not value.startswith("data: "):
                continue
            try:
                event = json.loads(value[6:])
            except json.JSONDecodeError:
                # Non-JSON data lines are skipped.
                continue
            if not isinstance(event, dict):
                continue

            # Tolerate empty "choices" lists and null "delta"/"content" values.
            choices = event.get("choices") or [{}]
            first_choice = choices[0] if isinstance(choices[0], dict) else {}
            content = (first_choice.get("delta") or {}).get("content") or ""
            # Only truly empty chunks are skipped: whitespace-only chunks
            # (spaces, newlines) are part of the generated text.
            if not content:
                continue

            if stream:
                cleaned_response = {
                    "created": event.get("created"),
                    "id": event.get("id"),
                    "model": "searchgpt",
                    "object": "chat.completion",
                    "choices": [
                        {
                            "message": {
                                "content": content
                            }
                        }
                    ],
                }
                yield f"data: {json.dumps(cleaned_response)}\n\n"
            collected.append(content)

    if not stream:
        yield "".join(collected)
async def ping():
    """Answer with "pong" and the time spent inside this handler.

    The two timestamps are taken back to back, so the reported duration
    only covers the handler body itself.
    """
    began = datetime.datetime.now()
    elapsed = datetime.datetime.now() - began
    return {
        "message": "pong",
        "response_time": f"{elapsed.total_seconds():.6f} seconds",
    }
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
    """Answer a search query through generate_search().

    Args:
        q: The query text; must be non-empty.
        stream: When true, respond with a ``text/event-stream`` of chunks;
            otherwise respond with ``{"response": <full text>}``.
        systemprompt: Optional system message passed to generate_search().

    Raises:
        HTTPException: 400 when ``q`` is empty.
    """
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")

    usage_tracker.record_request(endpoint="/searchgpt")

    if stream:
        # A synchronous generator handed to StreamingResponse is iterated in
        # a worker thread, so the blocking upstream read stays off the loop.
        return StreamingResponse(
            generate_search(q, systemprompt=systemprompt, stream=True),
            media_type="text/event-stream"
        )

    def collect_text() -> str:
        return "".join(generate_search(q, systemprompt=systemprompt, stream=False))

    # generate_search() performs blocking HTTP I/O; draining it directly in
    # this coroutine would stall the event loop for the whole upstream call.
    response_text = await run_in_threadpool(collect_text)
    return JSONResponse(content={"response": response_text})
async def root():
    """Serve index.html, or a 404 page when the file does not exist.

    The path is relative, so it is resolved against the current working
    directory.
    """
    try:
        page = Path("index.html").read_text()
    except FileNotFoundError:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return HTMLResponse(content=page)
async def get_models():
    """Return the parsed contents of models.json stored next to this module.

    Raises:
        HTTPException: 404 when the file is missing, 500 when it is not
            valid JSON.
    """
    models_file = Path(__file__).parent / 'models.json'
    try:
        with models_file.open('r') as handle:
            catalog = json.load(handle)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="models.json not found")
    except json.JSONDecodeError:
        raise HTTPException(status_code=500, detail="Error decoding models.json")
    return catalog
async def fetch_models():
    """Return the model catalogue by delegating to get_models()."""
    catalog = await get_models()
    return catalog
# Maintenance switch: while False, chat completions are answered with 503.
server_status = True


async def get_completion(payload: Payload, request: Request):
    """Proxy a chat-completion request to the upstream and stream the reply.

    The upstream is contacted *before* the streaming response is created, so
    upstream failures are reported to the client with a real HTTP status
    instead of a broken 200 stream.

    Args:
        payload: Validated request body.
        request: Incoming request; only used to log the client address.

    Returns:
        A 503 JSONResponse while ``server_status`` is false, otherwise a
        StreamingResponse relaying the upstream body line by line.

    Raises:
        HTTPException: 400 for a model not in ``available_model_ids``;
            422/400/403/404 mirrored from the upstream; 500 for upstream
            5xx answers and for request failures.
    """
    if not server_status:
        return JSONResponse(
            status_code=503,
            content={"message": "Server is under maintenance. Please try again later."}
        )

    model_to_use = payload.model if payload.model else "gpt-4o-mini"
    if model_to_use not in available_model_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
        )

    usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")

    payload_dict = payload.dict()
    payload_dict["model"] = model_to_use

    # Models listed in alternate_models go to the secondary upstream.
    endpoint = secret_api_endpoint_2 if model_to_use in alternate_models else secret_api_endpoint

    # Log the request time (UTC shifted by +5:30) and the client address.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    current_time = (now_utc + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
    # request.client is None when the server cannot determine the peer.
    aaip = request.client.host if request.client else "unknown"
    print(f"Time: {current_time}, {aaip}")

    scraper = cloudscraper.create_scraper()
    # Origin/Referer headers are currently not sent.
    custom_headers = {
        'DNT': '1',
        'Priority': 'u=1, i',
    }

    try:
        # The scraper is a blocking client; run it in a worker thread so the
        # event loop keeps serving other requests.
        response = await run_in_threadpool(
            scraper.post,
            f"{endpoint}/v1/chat/completions",
            json=payload_dict,
            headers=custom_headers,
            stream=True,
        )
    except requests.exceptions.RequestException as req_err:
        raise HTTPException(status_code=500, detail=f"Request failed: {req_err}") from req_err
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    upstream_errors = {
        422: "Unprocessable entity. Check your payload.",
        400: "Bad request. Verify input data.",
        403: "Forbidden. You do not have access to this resource.",
        404: "The requested resource was not found.",
    }
    status = response.status_code
    if status in upstream_errors or status >= 500:
        # Log the upstream body for diagnosis, then release the connection.
        try:
            print(await run_in_threadpool(getattr, response, "text"))
        finally:
            response.close()
        if status >= 500:
            raise HTTPException(status_code=500, detail="Server error. Try again later.")
        raise HTTPException(status_code=status, detail=upstream_errors[status])

    def stream_generator():
        # Synchronous on purpose: StreamingResponse iterates sync generators
        # in a worker thread, which is where blocking reads belong.
        try:
            for line in response.iter_lines():
                if line:
                    yield line.decode('utf-8') + "\n"
        finally:
            response.close()

    return StreamingResponse(stream_generator(), media_type="application/json")
# Single handler for image generation; supports both GET and POST.
async def generate_image(
    prompt: Optional[str] = None,
    model: str = "flux",  # Default model
    seed: Optional[int] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    nologo: Optional[bool] = True,
    private: Optional[bool] = None,
    enhance: Optional[bool] = None,
    request: Request = None,  # Access raw POST data
):
    """Generate an image using the Image Generation API.

    For POST requests the prompt is read from the JSON body (key
    ``"prompt"``); otherwise it comes from the ``prompt`` parameter. The
    remaining options are forwarded to the image service as query
    parameters; ``width``/``height`` are only forwarded when within 64-2048.

    Returns:
        A StreamingResponse carrying the image bytes with the content type
        reported by the image service.

    Raises:
        HTTPException: 500 when ``image_endpoint`` is not configured or the
            service answers with a non-image body; 400 for an invalid JSON
            body or an empty prompt; 404/400/429 mirrored from the service;
            any other non-200 status passed through; 504 on timeout.
    """
    if not image_endpoint:
        raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")

    usage_tracker.record_request(endpoint="/images/generations")

    if request.method == "POST":
        # Only the JSON parsing is guarded, so the empty-prompt error below
        # is not misreported as an invalid payload.
        try:
            body = await request.json()
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="Invalid JSON payload")
        prompt = body.get("prompt", "")
        if not isinstance(prompt, str):
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

    if not prompt or not prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    prompt = prompt.strip()

    # The prompt becomes a single path segment, so every reserved character
    # (including "/", "?" and "#") must be percent-encoded.
    encoded_prompt = quote(prompt, safe="")
    base_url = image_endpoint.rstrip('/')  # Remove trailing slash if present
    url = f"{base_url}/{encoded_prompt}"

    # Forward only the options that were supplied and pass validation.
    params = {}
    if model and isinstance(model, str):
        params['model'] = model
    if seed is not None and isinstance(seed, int):
        params['seed'] = seed
    if width is not None and isinstance(width, int) and 64 <= width <= 2048:
        params['width'] = width
    if height is not None and isinstance(height, int) and 64 <= height <= 2048:
        params['height'] = height
    if nologo is not None:
        params['nologo'] = str(nologo).lower()
    if private is not None:
        params['private'] = str(private).lower()
    if enhance is not None:
        params['enhance'] = str(enhance).lower()

    # Only the network call is guarded; the status-specific errors raised
    # further down must reach the client unchanged.
    try:
        timeout = httpx.Timeout(60.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(url, params=params, follow_redirects=True)
    except httpx.TimeoutException as exc:
        raise HTTPException(status_code=504, detail="Image generation request timed out") from exc
    except httpx.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Failed to contact image service: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error during image generation: {str(e)}") from e

    if response.status_code == 404:
        raise HTTPException(status_code=404, detail="Image generation service not found")
    if response.status_code == 400:
        raise HTTPException(status_code=400, detail="Invalid parameters provided to image service")
    if response.status_code == 429:
        raise HTTPException(status_code=429, detail="Too many requests to image service")
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Image generation failed with status code {response.status_code}"
        )

    content_type = response.headers.get('content-type', '')
    if not content_type.startswith('image/'):
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected content type received: {content_type}"
        )

    # The body was fully read by client.get(), so iterating it after the
    # client has been closed is safe.
    return StreamingResponse(
        response.iter_bytes(),
        media_type=content_type,
        headers={
            'Cache-Control': 'no-cache',
            'Pragma': 'no-cache'
        }
    )
async def playground():
    """Serve playground.html, or a 404 page when the file does not exist.

    The path is relative, so it is resolved against the current working
    directory.
    """
    try:
        page = Path("playground.html").read_text()
    except FileNotFoundError:
        return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
    return HTMLResponse(content=page)
def load_model_ids(json_file_path):
    """Return the ``id`` of every model object listed in a JSON file.

    The file is expected to hold a JSON array of objects. Entries that are
    not objects, or that have no ``"id"`` key, are skipped.

    Args:
        json_file_path: Path of the JSON file to read.

    Returns:
        The list of ids in file order. An empty list is returned (and a
        message printed) when the file is missing, is not valid JSON, or
        does not contain an array at the top level.
    """
    try:
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(json_file_path, 'r', encoding='utf-8') as f:
            models_data = json.load(f)
    except FileNotFoundError:
        print("Error: models.json file not found.")
        return []
    except json.JSONDecodeError:
        print("Error: Invalid JSON format in models.json.")
        return []

    if not isinstance(models_data, list):
        print("Error: Invalid JSON format in models.json.")
        return []

    # The isinstance guard keeps `'id' in model` from becoming a substring
    # test on strings or a TypeError on numbers.
    return [model['id'] for model in models_data if isinstance(model, dict) and 'id' in model]
async def get_usage(days: int = 7):
    """Return the usage summary, passing ``days`` to the usage tracker."""
    summary = usage_tracker.get_usage_summary(days)
    return summary
async def usage_page():
    """Serve an HTML page showing usage statistics.

    Reads the summary from the usage tracker (default period) and renders
    three tables: per-model usage, per-endpoint usage and daily usage.
    Every interpolated value is HTML-escaped before it is placed in the
    page.
    """

    def table_row(*cells):
        # One <tr> with one escaped <td> per cell.
        tds = "".join(f"<td>{html.escape(str(cell))}</td>\n" for cell in cells)
        return f"\n<tr>\n{tds}</tr>\n"

    usage_data = usage_tracker.get_usage_summary()

    # Model usage table rows
    model_usage_rows = "\n".join(
        table_row(
            model,
            model_data['total_requests'],
            model_data['first_used'],
            model_data['last_used'],
        )
        for model, model_data in usage_data['models'].items()
    )

    # API endpoint usage table rows
    api_usage_rows = "\n".join(
        table_row(
            endpoint,
            endpoint_data['total_requests'],
            endpoint_data['first_used'],
            endpoint_data['last_used'],
        )
        for endpoint, endpoint_data in usage_data['api_endpoints'].items()
    )

    # Daily usage table rows: one row per (date, entity) pair.
    # `request_count` avoids shadowing the imported `requests` module.
    daily_usage_rows = "\n".join(
        "\n".join(
            table_row(date, entity, request_count)
            for entity, request_count in date_data.items()
        )
        for date, date_data in usage_data['recent_daily_usage'].items()
    )

    total_requests = html.escape(str(usage_data['total_requests']))

    html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Lokiai AI - Usage Statistics</title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
<style>
:root {{
--bg-dark: #0f1011;
--bg-darker: #070708;
--text-primary: #e6e6e6;
--text-secondary: #8c8c8c;
--border-color: #2c2c2c;
--accent-color: #3a6ee0;
--accent-hover: #4a7ef0;
}}
body {{
font-family: 'Inter', sans-serif;
background-color: var(--bg-dark);
color: var(--text-primary);
max-width: 1200px;
margin: 0 auto;
padding: 40px 20px;
line-height: 1.6;
}}
.logo {{
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 30px;
}}
.logo h1 {{
font-weight: 600;
font-size: 2.5em;
color: var(--text-primary);
margin-left: 15px;
}}
.logo img {{
width: 60px;
height: 60px;
border-radius: 10px;
}}
.container {{
background-color: var(--bg-darker);
border-radius: 12px;
padding: 30px;
box-shadow: 0 15px 40px rgba(0,0,0,0.3);
border: 1px solid var(--border-color);
}}
h2, h3 {{
color: var(--text-primary);
border-bottom: 2px solid var(--border-color);
padding-bottom: 10px;
font-weight: 500;
}}
.total-requests {{
background-color: var(--accent-color);
color: white;
text-align: center;
padding: 15px;
border-radius: 8px;
margin-bottom: 30px;
font-weight: 600;
letter-spacing: -0.5px;
}}
table {{
width: 100%;
border-collapse: separate;
border-spacing: 0;
margin-bottom: 30px;
background-color: var(--bg-dark);
border-radius: 8px;
overflow: hidden;
}}
th, td {{
border: 1px solid var(--border-color);
padding: 12px;
text-align: left;
transition: background-color 0.3s ease;
}}
th {{
background-color: #1e1e1e;
color: var(--text-primary);
font-weight: 600;
text-transform: uppercase;
font-size: 0.9em;
}}
tr:nth-child(even) {{
background-color: rgba(255,255,255,0.05);
}}
tr:hover {{
background-color: rgba(62,100,255,0.1);
}}
@media (max-width: 768px) {{
.container {{
padding: 15px;
}}
table {{
font-size: 0.9em;
}}
}}
</style>
</head>
<body>
<div class="container">
<div class="logo">
<img src="" alt="Lokai AI Logo">
<h1>Lokiai AI</h1>
</div>
<div class="total-requests">
Total API Requests: {total_requests}
</div>
<h2>Model Usage</h2>
<table>
<tr>
<th>Model</th>
<th>Total Requests</th>
<th>First Used</th>
<th>Last Used</th>
</tr>
{model_usage_rows}
</table>
<h2>API Endpoint Usage</h2>
<table>
<tr>
<th>Endpoint</th>
<th>Total Requests</th>
<th>First Used</th>
<th>Last Used</th>
</tr>
{api_usage_rows}
</table>
<h2>Daily Usage (Last 7 Days)</h2>
<table>
<tr>
<th>Date</th>
<th>Entity</th>
<th>Requests</th>
</tr>
{daily_usage_rows}
</table>
</div>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
async def get_meme():
    """Fetch a random meme from meme-api.com and stream the image back.

    Returns:
        A StreamingResponse relaying the image bytes. The media type is
        guessed from the image URL's extension, falling back to
        ``image/png``.

    Raises:
        HTTPException: 404 when the API answer carries no image URL; 500
            when the API cannot be reached or its answer cannot be parsed.
    """
    try:
        # requests is a blocking client; keep it off the event loop.
        response = await run_in_threadpool(requests.get, "https://meme-api.com/gimme")
        response_data = response.json()
        meme_url = response_data.get("url")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Raised outside the try block so the 404 is not rewrapped as a 500.
    if not meme_url:
        raise HTTPException(status_code=404, detail="No mimi found :(")

    def stream_image():
        # Synchronous generator: StreamingResponse iterates it in a worker
        # thread, and the context manager releases the connection at the end.
        with requests.get(meme_url, stream=True) as image_response:
            for chunk in image_response.iter_content(chunk_size=1024):
                yield chunk

    guessed_type = mimetypes.guess_type(meme_url)[0]
    if guessed_type and guessed_type.startswith("image/"):
        media_type = guessed_type
    else:
        media_type = "image/png"
    return StreamingResponse(stream_image(), media_type=media_type)
async def startup_event():
    """Load the list of valid model ids and print the endpoint overview.

    Populates the module-level ``available_model_ids`` that the
    chat-completions handler validates against.
    """
    global available_model_ids
    # Read the same file get_models() serves (next to this module), so the
    # validated ids always match the published list regardless of the
    # current working directory.
    available_model_ids = load_model_ids(Path(__file__).parent / "models.json")
    print(f"Loaded model IDs: {available_model_ids}")
    print("API endpoints:")
    print("GET /")
    print("GET /models")
    print("GET /searchgpt")
    print("POST /chat/completions")
    print("GET /images/generations")
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)