Spaces:
Sleeping
Sleeping
# main.py
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
import psycopg2 | |
from psycopg2.extras import RealDictCursor | |
import os | |
from dotenv import load_dotenv | |
import google.generativeai as genai | |
# Load environment variables (via python-dotenv) first, then pass the Google
# API key found in the environment to the Gemini client library.
load_dotenv()
genai.configure(api_key=os.environ.get('GOOGLE_API_KEY'))

# The FastAPI application object for this module.
app = FastAPI()
class Query(BaseModel):
    """Request model holding a single natural-language question."""

    # The question text; passed to the language model by process_query.
    question: str
def get_gemini_response(question, prompt):
    """Send *prompt* followed by *question* to Gemini and return the reply text.

    The two strings are passed as a two-element content list (prompt first).
    The returned text has leading and trailing whitespace stripped.
    """
    gemini = genai.GenerativeModel('gemini-1.5-pro')
    parts = [prompt, question]
    reply = gemini.generate_content(parts)
    answer_text = reply.text
    # Trim surrounding whitespace so callers receive only the answer itself.
    return answer_text.strip()
# Instruction prompt sent to the model ahead of the user's question. It names
# the target database, lists the main tables with their columns, and gives two
# worked question/answer examples. The text must stay free of stray markup:
# every character here is part of what the model reads.
sql_prompt = """
Convert the following English question to a PostgreSQL query for the Pagila DVD rental database.
Only return the SQL query without any markdown formatting or explanations.
The database has these main tables:
- actor (actor_id, first_name, last_name)
- film (film_id, title, description, release_year, rental_rate, length, rating)
- category (category_id, name)
- film_category (film_id, category_id)
- inventory (inventory_id, film_id, store_id)
- rental (rental_id, rental_date, inventory_id, customer_id, return_date, staff_id)
- customer (customer_id, first_name, last_name, email)
- payment (payment_id, customer_id, staff_id, rental_id, amount, payment_date)
Example queries:
Q: List all actors
A: SELECT * FROM actor;
Q: Show top 10 most rented movies
A: SELECT f.title, COUNT(r.rental_id) as rental_count FROM film f JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id GROUP BY f.title ORDER BY rental_count DESC LIMIT 10;
"""
def execute_sql_query(query):
    """Run *query* against the configured PostgreSQL database and return its rows.

    Connection settings are read from the ``DB_NAME``, ``DB_USER``,
    ``DB_PASSWORD``, ``DB_HOST`` and ``DB_PORT`` (default ``'5432'``)
    environment variables. A new connection is opened on every call and is
    always closed before the function exits.

    Args:
        query: SQL text to execute. In this file it is generated by the
            language model from user input, so it is treated as untrusted.

    Returns:
        The rows produced by the statement as a list of dict-like objects
        (``RealDictCursor``), or an empty list when the statement produces no
        result set.

    Raises:
        HTTPException: status 400 with a ``"Database Error: ..."`` detail when
            connecting to the database or executing the query fails.
    """
    conn = None
    try:
        conn = psycopg2.connect(
            dbname=os.getenv('DB_NAME'),
            user=os.getenv('DB_USER'),
            password=os.getenv('DB_PASSWORD'),
            host=os.getenv('DB_HOST'),
            port=os.getenv('DB_PORT', '5432'),
        )
        # Defence in depth: the SQL comes from a language model, and this
        # function never commits, so nothing here is meant to write. Asking
        # for a read-only transaction makes the server reject data-modifying
        # statements outright. This is not a complete sandbox; the database
        # role should also be restricted to SELECT privileges.
        conn.set_session(readonly=True)
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query)
            # description is None when the statement returned no result set;
            # calling fetchall() in that state raises ProgrammingError.
            if cursor.description is None:
                return []
            return cursor.fetchall()
    except Exception as e:
        # Boundary translation: every failure is reported to the client as a
        # 400, keeping the original exception as the cause for server logs.
        raise HTTPException(status_code=400, detail=f"Database Error: {str(e)}") from e
    finally:
        if conn is not None:
            conn.close()
async def process_query(query: Query):
    """Turn a natural-language question into SQL, execute it, and return both.

    Args:
        query: Request model whose ``question`` field holds the user's text.

    Returns:
        A dict with ``"query"`` (the SQL text that was executed, with Markdown
        code-fence markers removed) and ``"result"`` (whatever
        ``execute_sql_query`` returned for it).

    Raises:
        HTTPException: status 400. Errors that are already ``HTTPException``
            instances (such as those raised by ``execute_sql_query``) are
            passed through unchanged; any other failure is wrapped with the
            exception message as the detail.

    NOTE(review): no route decorator is visible on this function in this view
    of the file; confirm it is registered on ``app`` elsewhere.
    NOTE(review): the model call and the database call below are synchronous,
    so they occupy the event loop while they run; consider a thread pool.
    """
    try:
        sql_query = get_gemini_response(query.question, sql_prompt)
        # Remove any SQL code block markers if present
        sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
        result = execute_sql_query(sql_query)
    except HTTPException:
        # Already a well-formed HTTP error. Re-wrapping it via str() would
        # replace its detail with the exception's string form and lose the
        # original message structure.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"query": sql_query, "result": result}