import gradio as gr
import openai
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from typing import Optional, Tuple
import re

# OpenRouter credentials and model selection.
# SECURITY: the API key is read from the OPENROUTER_API_KEY environment
# variable instead of being hard-coded. Any key that was ever committed to
# source control must be treated as leaked and revoked. With no key set the
# client is still constructed, and API calls fail at request time.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
# The model can be overridden through the environment; the default is unchanged.
OPENROUTER_MODEL = os.environ.get("OPENROUTER_MODEL", "sophosympatheia/rogue-rose-103b-v0.2:free")

# Hugging Face Space path of the SQLite database file
DB_PATH = "ecommerce.db"

# Ensure dataset exists: best-effort download when the file is missing.
# The exit status of wget is not checked, so a failed download is silent.
if not os.path.exists(DB_PATH):
    os.system("wget https://your-dataset-link.com/ecommerce.db -O ecommerce.db")  # Replace with actual dataset link

# Initialize OpenAI client, pointed at the OpenRouter endpoint
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")

# Function: Fetch database schema
def fetch_schema(db_path: str) -> str:
    """Return a plain-text description of every table in a SQLite database.

    Args:
        db_path: Filesystem path of the SQLite database to inspect.

    Returns:
        One ``Table: <name>`` line per table listed in ``sqlite_master``,
        each followed by one indented ``Column: <name>, Type: <type>`` line
        per column. An empty string when the database has no tables.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        lines = []
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so the table
            # name is quoted as an identifier (embedded quotes doubled).
            # This keeps names containing spaces or keywords from breaking
            # the statement.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name});")
            columns = cursor.fetchall()
            lines.append(f"Table: {table_name}\n")
            # Each table_info row carries the column name at index 1 and the
            # declared type at index 2.
            lines.extend(
                f"  Column: {column[1]}, Type: {column[2]}\n" for column in columns
            )
        return "".join(lines)
    finally:
        # Close the connection even when one of the queries raises.
        conn.close()

# Function: Extract SQL query from LLM response
def extract_sql_query(response: str) -> str:
    """Pull the SQL statement out of an LLM reply.

    The reply is searched for a Markdown code fence. A fence tagged ``sql``
    (in any letter case) is preferred; if none is present, the first fenced
    block of any kind is used.

    Args:
        response: Raw text returned by the model.

    Returns:
        The stripped contents of the selected fenced block, or ``response``
        unchanged when it contains no fenced block at all.
    """
    flags = re.DOTALL | re.IGNORECASE
    # Ordered by preference: an explicitly SQL-tagged fence wins over a bare
    # fence, so replies that previously matched keep selecting an SQL block.
    for fence_pattern in (r"```sql(.*?)```", r"```(.*?)```"):
        match = re.search(fence_pattern, response, flags)
        if match:
            return match.group(1).strip()
    # Fallback: no fence found, assume the whole reply is the query.
    return response

# Function: Convert text to SQL
def text_to_sql(query: str, schema: str) -> str:
    """Ask the configured LLM to translate a natural-language query into SQL.

    Args:
        query: The user's request in plain language.
        schema: Text description of the database schema, embedded in the prompt.

    Returns:
        The SQL extracted from the model's reply via ``extract_sql_query``.
        If anything raises while calling the model or reading its reply, the
        string ``"Error: <exception>"`` is returned instead of raising.
    """
    user_prompt = "".join(
        [
            "You are an SQL expert. Given the following database schema:\n\n",
            f"{schema}\n\n",
            "Convert the following query into SQL:\n\n",
            f"Query: {query}\n",
            "SQL:",
        ]
    )
    conversation = [
        {"role": "system", "content": "You are an SQL expert."},
        {"role": "user", "content": user_prompt},
    ]
    try:
        completion = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=conversation,
        )
        raw_answer = completion.choices[0].message.content
        # Strip surrounding whitespace first, then pull out the fenced SQL.
        return extract_sql_query(raw_answer.strip())
    except Exception as exc:
        # Failures are reported to the caller as text, never raised.
        return f"Error: {exc}"


def preprocess_sql_for_sqlite(sql_query: str) -> str:
    """Rewrite ``MONTH(col)`` / ``YEAR(col)`` calls into SQLite equivalents.

    SQLite has no ``MONTH``/``YEAR`` functions, so each call whose argument
    is a plain (optionally dotted) column reference is replaced with a
    ``strftime`` expression. The function name is matched in any letter
    case. Calls with more complex arguments are left untouched.

    Args:
        sql_query: SQL text, possibly using ``MONTH(...)`` / ``YEAR(...)``.

    Returns:
        The SQL text with those calls replaced; other text is unchanged.
    """
    # strftime() yields TEXT ('03', '2018'), and SQLite never treats such a
    # value as equal to an integer literal, so a bare strftime would make
    # `YEAR(d) = 2018` silently match nothing. Casting to INTEGER restores
    # the numeric result that MONTH()/YEAR() have in other dialects.
    replacements = (
        (r"\bMONTH\s*\(\s*([\w.]+)\s*\)", r"CAST(strftime('%m', \1) AS INTEGER)"),
        (r"\bYEAR\s*\(\s*([\w.]+)\s*\)", r"CAST(strftime('%Y', \1) AS INTEGER)"),
    )
    for pattern, replacement in replacements:
        # SQL function names are case-insensitive (e.g. `month(...)`).
        sql_query = re.sub(pattern, replacement, sql_query, flags=re.IGNORECASE)
    return sql_query

def execute_sql(sql_query: str) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Run a query against the database at ``DB_PATH`` and return its rows.

    The query is first passed through ``preprocess_sql_for_sqlite``.

    Args:
        sql_query: SQL text to execute.

    Returns:
        ``(DataFrame, None)`` on success, or
        ``(None, "SQL Execution Error: <exception>")`` if anything raises.
        This function never raises.
    """
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        # The SQL originates from an LLM driven by free-form user text, so
        # the connection is made read-only: statements that would modify
        # the database fail instead of being executed.
        conn.execute("PRAGMA query_only = ON;")
        sql_query = preprocess_sql_for_sqlite(sql_query)  # Convert to SQLite-compatible SQL
        df = pd.read_sql_query(sql_query, conn)
        return df, None
    except Exception as e:
        return None, f"SQL Execution Error: {e}"
    finally:
        # Close the connection on failure as well as on success.
        if conn is not None:
            conn.close()


# Function: Generate Dynamic Visualization
def visualize_data(df: pd.DataFrame) -> Optional[str]:
    """Render a chart for a query result and save it as ``chart.png``.

    The chart type is chosen from the number of numeric columns:
    one gives a histogram, two give a scatter plot, and more than two give
    a pie chart (fewer than 10 rows) or a bar chart, both using the first
    column as labels.

    Args:
        df: Query result to plot.

    Returns:
        The path ``"chart.png"`` when a chart was written, or ``None`` when
        ``df`` is empty, has fewer than two columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None

    # Detect numeric columns before allocating a figure, so that results
    # that cannot be plotted do not create (and leak) a figure.
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) < 1:
        return None

    fig = plt.figure(figsize=(6, 4))
    try:
        sns.set_theme(style="darkgrid")

        # Choose visualization type dynamically.
        # NOTE(review): the pie and bar branches are only reachable with three
        # or more numeric columns; a "category + value" result with a single
        # numeric column is drawn as a histogram. Confirm this is intended.
        if len(numeric_cols) == 1:  # Single numeric column, assume it's a count metric
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:  # Two numeric columns, assume X-Y plot
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:  # If rows are few, prefer pie chart
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:  # Default: Bar chart for categories + values
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")

        plt.tight_layout()
        plt.savefig("chart.png")
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each request would leave one more figure in memory.
        plt.close(fig)
    return "chart.png"

# Gradio UI
def gradio_ui(query: str) -> Tuple[str, str, Optional[str]]:
    """Handle one UI request: generate SQL, run it, and chart the result.

    Args:
        query: The user's request in plain language.

    Returns:
        A 3-tuple of (generated SQL, result text, chart path). When execution
        fails, the second item is the error message and the third is ``None``.
        Otherwise the second item is the result table rendered without its
        index, and the third is whatever ``visualize_data`` returned.
    """
    generated_sql = text_to_sql(query, fetch_schema(DB_PATH))
    result_frame, failure = execute_sql(generated_sql)

    # Guard clause: surface the execution error in the results box.
    if failure:
        return generated_sql, failure, None

    chart_path = None if result_frame is None else visualize_data(result_frame)
    return generated_sql, result_frame.to_string(index=False), chart_path

# Launch Gradio App
# Builds a Blocks UI with a heading, one query textbox, a submit button and
# three output components (generated SQL, query results as text, chart image).
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
    query_input = gr.Textbox(label="Enter your query", placeholder="e.g., Show all products sold in 2018.")
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    # The three outputs line up positionally with the 3-tuple returned by
    # gradio_ui: (SQL text, result/error text, chart path or None).
    submit_btn.click(gradio_ui, inputs=[query_input], outputs=[sql_output, table_output, chart_output])

# NOTE(review): launch() runs at module import time because there is no
# `if __name__ == "__main__":` guard. Confirm whether any tooling imports this
# file and relies on that before adding one.
demo.launch()