import datetime
import json
import os
import re
from pathlib import Path
from typing import Iterator, Optional

import cloudscraper
import httpx
import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
from pydantic import BaseModel
# Load variables from a local .env file (if present) into os.environ so the
# os.getenv() lookups below also see values that were not exported in the shell.
load_dotenv()
app = FastAPI()

# Get API keys and secret endpoints from environment variables.
api_keys_str = os.getenv('API_KEYS')  # deprecated
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3')  # endpoint used by /searchgpt
image_endpoint = os.getenv("IMAGE_ENDPOINT")  # optional; checked per request by the image route

# Fail fast at import time if any required upstream endpoint is missing.
# HTTPException only makes sense inside a request handler, so a plain
# RuntimeError is raised here, naming exactly which variables are unset.
_required_endpoints = {
    'SECRET_API_ENDPOINT': secret_api_endpoint,
    'SECRET_API_ENDPOINT_2': secret_api_endpoint_2,
    'SECRET_API_ENDPOINT_3': secret_api_endpoint_3,
}
_missing_endpoints = [name for name, value in _required_endpoints.items() if not value]
if _missing_endpoints:
    raise RuntimeError(
        "API endpoint(s) are not configured in environment variables: "
        + ", ".join(_missing_endpoints)
    )

# Models that are routed to the secondary endpoint (SECRET_API_ENDPOINT_2).
alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
# Filled from models.json by the startup hook; requests for other models are rejected.
available_model_ids = []
class Payload(BaseModel):
    """Request body accepted by the chat-completions endpoints.

    The validated model is converted to a dict and forwarded to the upstream
    API as JSON.
    """

    model: str  # model id; validated against the ids loaded from models.json
    messages: list  # chat history, forwarded upstream unchanged
    # Optional, as in the OpenAI chat API: clients that omit it previously got
    # a 422 validation error instead of a non-streaming completion.
    stream: bool = False
@app.get("/favicon.ico")
async def favicon():
    """Serve ``favicon.ico`` from the directory containing this module.

    Raises:
        HTTPException: 404 if the icon file does not exist. Without this check,
            FileResponse would fail with a server error while sending.
    """
    favicon_path = Path(__file__).parent / "favicon.ico"
    if not favicon_path.is_file():
        raise HTTPException(status_code=404, detail="favicon.ico not found")
    return FileResponse(favicon_path, media_type="image/x-icon")
    
def _search_delta_content(event) -> str:
    """Return ``choices[0].delta.content`` from a parsed stream event, or "" if absent/null."""
    if not isinstance(event, dict):
        return ""
    choices = event.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    first_choice = choices[0]
    delta = first_choice.get("delta") if isinstance(first_choice, dict) else None
    content = delta.get("content") if isinstance(delta, dict) else None
    # A JSON null content (common in final stream chunks) is treated as empty.
    return content if isinstance(content, str) else ""


def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool = True) -> Iterator[str]:
    """Query the searchgpt upstream (SECRET_API_ENDPOINT_3) and yield its reply.

    The upstream answers with ``data: {...}`` lines. Each event that carries
    delta content is re-wrapped as a minimal ``chat.completion``-style object.

    Args:
        query: The user's question, sent as the single user message.
        systemprompt: Optional system prompt; defaults to "Be Helpful and Friendly".
        stream: If True, yield one SSE frame (``data: <json>`` plus a blank
            line) per content chunk. If False, yield exactly one string holding
            all content concatenated.

    Yields:
        SSE frames when ``stream`` is True, otherwise the full response text.

    Raises:
        requests.exceptions.RequestException: if the upstream request fails.
    """
    system_message = systemprompt or "Be Helpful and Friendly"
    prompt = [
        {"content": system_message, "role": "system"},
        {"role": "user", "content": query},
    ]
    payload = {
        "is_vscode_extension": True,
        "message_history": prompt,
        "requested_model": "searchgpt",
        "user_input": query,
    }

    collected = []
    # `with` releases the connection even if the consumer stops iterating early
    # (e.g. the client disconnects and the generator is closed).
    with requests.post(
        secret_api_endpoint_3, headers={"User-Agent": ""}, json=payload, stream=True
    ) as response:
        for raw_line in response.iter_lines():
            # Decode as UTF-8 explicitly. requests falls back to ISO-8859-1 for
            # text/* responses without a charset, which garbles non-ASCII text.
            line = raw_line.decode("utf-8", errors="replace")
            if not line.startswith("data: "):
                continue
            try:
                event = json.loads(line[6:])
            except json.JSONDecodeError:
                # Non-JSON payloads (such as a "[DONE]" marker) are skipped.
                continue
            content = _search_delta_content(event)
            # Whitespace-only chunks (e.g. "\n") are real output and are kept so
            # the concatenated text preserves its formatting.
            if not content:
                continue
            if stream:
                cleaned_response = {
                    "created": event.get("created"),
                    "id": event.get("id"),
                    "model": "searchgpt",
                    "object": "chat.completion",
                    "choices": [{"message": {"content": content}}],
                }
                yield f"data: {json.dumps(cleaned_response)}\n\n"
            collected.append(content)

    if not stream:
        yield "".join(collected)
@app.get("/searchgpt")
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
    """Answer a search-style question via the searchgpt upstream.

    Args:
        q: The query text (required, must be non-empty).
        stream: If true, respond with ``text/event-stream`` SSE frames;
            otherwise respond with JSON ``{"response": <full text>}``.
        systemprompt: Optional system prompt forwarded to ``generate_search``.

    Raises:
        HTTPException: 400 if ``q`` is empty.
    """
    from fastapi.concurrency import run_in_threadpool

    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")

    if stream:
        # StreamingResponse iterates a sync generator in a worker thread, so
        # the blocking upstream reads do not stall the event loop.
        return StreamingResponse(
            generate_search(q, systemprompt=systemprompt, stream=True),
            media_type="text/event-stream"
        )

    # generate_search does blocking network I/O; collect it in the threadpool
    # rather than directly on the event loop.
    response_text = await run_in_threadpool(
        lambda: "".join(generate_search(q, systemprompt=systemprompt, stream=False))
    )
    return JSONResponse(content={"response": response_text})
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve ``index.html`` from the directory containing this module.

    Returns a small HTML page with status 404 if the file is missing.
    """
    # Resolve relative to this module (like favicon/get_models), not the CWD.
    file_path = Path(__file__).parent / "index.html"
    try:
        html_content = file_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
async def get_models():
    """Load and return the parsed ``models.json`` that sits beside this module.

    Raises:
        HTTPException: 404 when the file is absent, 500 when it is not valid JSON.
    """
    models_file = Path(__file__).parent.joinpath('models.json')
    try:
        with models_file.open('r') as handle:
            return json.load(handle)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="models.json not found")
    except json.JSONDecodeError:
        raise HTTPException(status_code=500, detail="Error decoding models.json")
@app.get("/models")
async def fetch_models():
    """Expose the raw model catalogue (contents of ``models.json``) to clients."""
    catalogue = await get_models()
    return catalogue
# Maintenance switch checked by the chat-completions handler. When False,
# requests are answered with HTTP 503 and the upstream API is not contacted.
server_status = True  # True: serving requests normally
# Upstream status codes that are mapped to a client-facing error message.
_UPSTREAM_ERROR_DETAILS = {
    400: "Bad request. Verify input data.",
    403: "Forbidden. You do not have access to this resource.",
    404: "The requested resource was not found.",
    422: "Unprocessable entity. Check your payload.",
}


def _raise_for_upstream_status(status_code: int) -> None:
    """Raise the HTTPException matching an upstream error status, if any.

    Statuses not listed above and below 500 (including 2xx) are left alone,
    so their body is relayed to the client unchanged.
    """
    if status_code in _UPSTREAM_ERROR_DETAILS:
        raise HTTPException(status_code=status_code, detail=_UPSTREAM_ERROR_DETAILS[status_code])
    if status_code >= 500:
        raise HTTPException(status_code=500, detail="Server error. Try again later.")


@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def get_completion(payload: Payload, request: Request):
    """Proxy an OpenAI-style chat completion request to the configured upstream.

    Models in ``alternate_models`` go to SECRET_API_ENDPOINT_2; all others go
    to SECRET_API_ENDPOINT. The upstream body is relayed line by line.

    Returns:
        A 503 JSON response while ``server_status`` is False, otherwise a
        StreamingResponse with the upstream lines.

    Raises:
        HTTPException: 400 for a model not in ``available_model_ids``. For an
            upstream 400/403/404/422 or 5xx, the mapped error. 500 if the
            upstream request itself fails.
    """
    from fastapi.concurrency import run_in_threadpool

    if not server_status:
        return JSONResponse(
            status_code=503,
            content={"message": "Server is down. Please try again later."}
        )

    model_to_use = payload.model if payload.model else "gpt-4o-mini"

    if model_to_use not in available_model_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
        )

    payload_dict = payload.dict()
    payload_dict["model"] = model_to_use

    endpoint = secret_api_endpoint_2 if model_to_use in alternate_models else secret_api_endpoint
    # Request log timestamp in UTC+05:30.
    current_time = (
        datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=5, minutes=30)
    ).strftime("%Y-%m-%d %I:%M:%S %p")
    # request.client can be None (e.g. some ASGI servers/test clients).
    client_host = request.client.host if request.client else "unknown"
    print(f"Time: {current_time}, {client_host}")
    print(payload_dict)

    scraper = cloudscraper.create_scraper()  # requests.Session-compatible CloudScraper session

    # Open the upstream request BEFORE building the StreamingResponse so that
    # errors become proper HTTP error responses. Once streaming has started,
    # the status line has already been sent. scraper.post blocks, so it runs
    # in the threadpool.
    try:
        response = await run_in_threadpool(
            scraper.post, f"{endpoint}/v1/chat/completions", json=payload_dict, stream=True
        )
    except requests.exceptions.RequestException as req_err:
        print(f"Upstream request failed: {req_err}")
        scraper.close()
        raise HTTPException(status_code=500, detail=f"Request failed: {req_err}") from req_err
    except Exception as e:
        print(f"Unexpected upstream error: {e}")
        scraper.close()
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    try:
        _raise_for_upstream_status(response.status_code)
    except HTTPException:
        print(response.text)  # upstream error body, for debugging
        response.close()
        scraper.close()
        raise

    def stream_generator():
        """Yield non-empty upstream lines, UTF-8 decoded and newline-terminated."""
        try:
            for line in response.iter_lines():
                if line:
                    yield line.decode('utf-8') + "\n"
        finally:
            # Runs on normal completion, upstream errors, and client disconnects.
            response.close()
            scraper.close()

    # A sync generator is iterated by Starlette in a worker thread, keeping the
    # blocking upstream reads off the event loop.
    return StreamingResponse(stream_generator(), media_type="application/json")
# Image generation: one endpoint serves both GET (prompt in the query string) and POST (prompt in the JSON body).
@app.api_route("/images/generations", methods=["GET", "POST"])  # Support both GET and POST
async def generate_image(
    prompt: Optional[str] = None,
    model: str = "flux",  # Default model
    seed: Optional[int] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    nologo: Optional[bool] = True,
    private: Optional[bool] = None,
    enhance: Optional[bool] = None,
    request: Request = None,  # Access raw POST data
):
    """Generate an image using the upstream image service (IMAGE_ENDPOINT).

    For GET the prompt comes from the ``prompt`` query parameter. For POST it
    comes from the ``"prompt"`` field of the JSON body. The remaining options
    are always read from the query string and forwarded as query parameters.

    Returns:
        The upstream image bytes with the upstream image content type.

    Raises:
        HTTPException:
            - 400 for a missing/empty prompt or an invalid JSON body.
            - 400/404/429, or the upstream status, when the service rejects the request.
            - 500 if IMAGE_ENDPOINT is unset, the service is unreachable, or a
              non-image payload is returned.
            - 504 on timeout.
    """
    from urllib.parse import quote

    from fastapi.responses import Response

    if not image_endpoint:
        raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")

    if request.method == "POST":
        try:
            body = await request.json()
        except Exception as exc:  # malformed (or undecodable) JSON body
            raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
        body_prompt = body.get("prompt", "") if isinstance(body, dict) else None
        if not isinstance(body_prompt, str):
            raise HTTPException(status_code=400, detail="Invalid JSON payload")
        prompt = body_prompt.strip()
        # Checked outside the try above so this specific message reaches the client.
        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    elif request.method == "GET":
        if not prompt or not prompt.strip():
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        prompt = prompt.strip()

    # Percent-encode the prompt as a single path segment. safe="" also encodes
    # "/", so characters like "/", "?" and "#" cannot change the URL structure.
    encoded_prompt = quote(prompt, safe="")
    base_url = image_endpoint.rstrip('/')  # Remove trailing slash if present
    url = f"{base_url}/{encoded_prompt}"

    # Only forward options that pass basic validation.
    params = {}
    if model and isinstance(model, str):
        params['model'] = model
    if seed is not None and isinstance(seed, int):
        params['seed'] = seed
    if width is not None and isinstance(width, int) and 64 <= width <= 2048:
        params['width'] = width
    if height is not None and isinstance(height, int) and 64 <= height <= 2048:
        params['height'] = height
    if nologo is not None:
        params['nologo'] = str(nologo).lower()
    if private is not None:
        params['private'] = str(private).lower()
    if enhance is not None:
        params['enhance'] = str(enhance).lower()

    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
            response = await client.get(url, params=params, follow_redirects=True)
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Image generation request timed out") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Failed to contact image service: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error during image generation: {str(e)}") from e

    # These checks sit outside the try so their HTTPExceptions reach the client
    # as-is instead of being re-wrapped by the generic 500 handler.
    if response.status_code == 404:
        raise HTTPException(status_code=404, detail="Image generation service not found")
    if response.status_code == 400:
        raise HTTPException(status_code=400, detail="Invalid parameters provided to image service")
    if response.status_code == 429:
        raise HTTPException(status_code=429, detail="Too many requests to image service")
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Image generation failed with status code {response.status_code}"
        )

    content_type = response.headers.get('content-type', '')
    if not content_type.startswith('image/'):
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected content type received: {content_type}"
        )

    # client.get() (without stream) has already buffered the whole body, so
    # return it directly; this also sets an accurate Content-Length.
    return Response(
        content=response.content,
        media_type=content_type,
        headers={
            'Cache-Control': 'no-cache',
            'Pragma': 'no-cache'
        }
    )
@app.get("/playground", response_class=HTMLResponse)
async def playground():
    """Serve ``playground.html`` from the directory containing this module.

    Returns a small HTML page with status 404 if the file is missing.
    """
    # Resolve relative to this module (like favicon/get_models), not the CWD.
    file_path = Path(__file__).parent / "playground.html"
    try:
        html_content = file_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
        
def load_model_ids(json_file_path):
    """Return the ``id`` of every model entry in a models JSON file.

    The file is expected to contain a JSON array of objects. Entries that are
    not objects, or that lack an ``"id"`` key, are skipped.

    Args:
        json_file_path: Path (str or os.PathLike) of the JSON file to read.

    Returns:
        The model ids in file order. Returns ``[]`` (after printing an error)
        if the file is missing, is not valid UTF-8 JSON, or its top level is
        not an array.
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            models_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: {json_file_path} file not found.")
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"Error: Invalid JSON format in {json_file_path}.")
        return []

    if not isinstance(models_data, list):
        print(f"Error: expected a JSON array of models in {json_file_path}.")
        return []
    # isinstance guard: `'id' in "some string"` is a substring test, and
    # indexing a string with 'id' would raise TypeError.
    return [model['id'] for model in models_data if isinstance(model, dict) and 'id' in model]
@app.on_event("startup")
async def startup_event():
    """Load the allowed model ids and print the exposed routes at startup.

    ``models.json`` is read from the directory containing this module, which
    is the same file served by GET /models. It is not read from the process's
    current working directory, so validation and the published list cannot
    diverge.
    """
    global available_model_ids
    available_model_ids = load_model_ids(Path(__file__).parent / "models.json")
    print(f"Loaded model IDs: {available_model_ids}")
    print("API endpoints:")
    print("GET /")
    print("GET /models")
    print("GET /searchgpt")
    print("POST /chat/completions")
    print("POST /v1/chat/completions")
    print("GET /images/generations")
    print("POST /images/generations")
    print("GET /playground")
    
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 8000.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)