# main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import psycopg2
from psycopg2.extras import RealDictCursor
import os
from dotenv import load_dotenv
import google.generativeai as genai

# FastAPI application instance; the /query route below is registered on it.
app = FastAPI()

# Load variables from a .env file into the process environment first, so that
# GOOGLE_API_KEY (and the DB_* settings read later) can be resolved via
# os.getenv, then configure the Gemini client with that API key.
load_dotenv()
genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))

class Query(BaseModel):
    """Request body accepted by the POST /query endpoint."""

    # Natural-language question that is handed to Gemini for SQL translation.
    question: str

def get_gemini_response(question, prompt):
    """Ask the 'gemini-1.5-pro' model to answer *question* under *prompt*.

    Args:
        question: The user's text, sent as the second content part.
        prompt: Instruction text, sent as the first content part.

    Returns:
        The text of the model's response with leading and trailing
        whitespace removed.
    """
    gemini = genai.GenerativeModel('gemini-1.5-pro')
    # Instructions go first, followed by the question they apply to.
    reply = gemini.generate_content([prompt, question])
    answer_text = reply.text
    return answer_text.strip()

# Instruction text sent to Gemini ahead of the user's question (it is passed
# as the `prompt` argument of get_gemini_response by the /query handler).
# It lists the main Pagila tables with their columns and gives two example
# question -> SQL pairs. The text is part of runtime behaviour: editing it
# changes what SQL the model produces.
sql_prompt = """
Convert the following English question to a PostgreSQL query for the Pagila DVD rental database.
Only return the SQL query without any markdown formatting or explanations.
The database has these main tables:
- actor (actor_id, first_name, last_name)
- film (film_id, title, description, release_year, rental_rate, length, rating)
- category (category_id, name)
- film_category (film_id, category_id)
- inventory (inventory_id, film_id, store_id)
- rental (rental_id, rental_date, inventory_id, customer_id, return_date, staff_id)
- customer (customer_id, first_name, last_name, email)
- payment (payment_id, customer_id, staff_id, rental_id, amount, payment_date)

Example queries:
Q: List all actors
A: SELECT * FROM actor;

Q: Show top 10 most rented movies
A: SELECT f.title, COUNT(r.rental_id) as rental_count FROM film f JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id GROUP BY f.title ORDER BY rental_count DESC LIMIT 10;
"""

def execute_sql_query(query):
    """Run *query* against the PostgreSQL database and return its rows.

    Connection settings are read from the DB_NAME, DB_USER, DB_PASSWORD,
    DB_HOST and DB_PORT environment variables (DB_PORT defaults to '5432').
    A new connection is opened for each call and is always closed before
    returning.

    Args:
        query: SQL text to execute.

    Returns:
        A list of dict-like rows (one per result row, keyed by column name),
        or an empty list when the statement produces no result set.

    Raises:
        HTTPException: With status 400 and a "Database Error: ..." detail if
            connecting or executing the query fails for any reason.
    """
    conn = None
    try:
        conn = psycopg2.connect(
            dbname=os.getenv('DB_NAME'),
            user=os.getenv('DB_USER'),
            password=os.getenv('DB_PASSWORD'),
            host=os.getenv('DB_HOST'),
            port=os.getenv('DB_PORT', '5432')
        )
        # The SQL is generated by an LLM from end-user text, so it must be
        # treated as untrusted input. Run it in a read-only transaction so
        # that INSERT/UPDATE/DELETE/DDL statements are rejected by the server.
        # NOTE(review): this is defence in depth, not a sandbox -- the DB role
        # used here should also be limited to SELECT privileges.
        conn.set_session(readonly=True)
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query)
            # description is None when the statement returned no result set;
            # fetchall() would raise ProgrammingError in that case.
            if cursor.description is None:
                return []
            return cursor.fetchall()
    except Exception as e:
        # Keep the original exception as __cause__ so logs show the real
        # failure, while the client still receives the same 400 payload.
        raise HTTPException(status_code=400, detail=f"Database Error: {str(e)}") from e
    finally:
        if conn is not None:
            conn.close()

@app.post("/query")
async def process_query(query: Query):
    """Translate a natural-language question to SQL, run it, and return both.

    The question is sent to Gemini together with ``sql_prompt``; any markdown
    code-fence markers are stripped from the reply, and the resulting SQL is
    executed through ``execute_sql_query``.

    Args:
        query: Request body carrying the user's question.

    Returns:
        A dict with the generated SQL under "query" and the fetched rows
        under "result".

    Raises:
        HTTPException: Status 400. Errors already raised as HTTPException by
            the database layer are propagated unchanged; any other failure
            is wrapped with its message as the detail.
    """
    # Local import so this handler stays self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Both the Gemini call and the psycopg2 query are blocking; run them
        # in the thread pool so they do not stall the event loop for other
        # requests while this one waits.
        sql_query = await run_in_threadpool(
            get_gemini_response, query.question, sql_prompt
        )
        # Remove any SQL code block markers if present
        sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
        result = await run_in_threadpool(execute_sql_query, sql_query)
        return {"query": sql_query, "result": result}
    except HTTPException:
        # Already a well-formed HTTP error (e.g. "Database Error: ...");
        # re-wrapping it via str(e) would mangle or lose its detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e