File size: 5,497 Bytes
b480321
 
3ddc773
 
 
 
b480321
79ceb52
bc83352
b480321
3ddc773
b480321
 
 
08e4afd
 
b480321
350e55d
08e4afd
 
3ddc773
 
 
 
79ceb52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b480321
bc83352
 
 
 
 
 
 
 
08e4afd
79ceb52
 
 
 
 
 
 
 
b480321
 
 
08e4afd
b480321
bc83352
 
b480321
 
 
bc83352
3b6af07
 
 
 
 
 
 
 
79ceb52
3ddc773
08e4afd
3b6af07
3ddc773
 
79ceb52
3ddc773
79ceb52
3ddc773
3b6af07
08e4afd
79ceb52
3ddc773
 
 
79ceb52
 
 
08e4afd
3ddc773
 
 
 
08e4afd
 
3ddc773
 
08e4afd
3ddc773
 
08e4afd
3ddc773
 
08e4afd
3ddc773
 
 
 
 
 
 
 
b480321
79ceb52
 
 
 
 
 
 
 
 
 
3ddc773
08e4afd
 
3ddc773
 
 
 
 
 
 
79ceb52
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import contextlib
import logging
import os
import re
import sqlite3
import urllib.request
from typing import Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import openai
import pandas as pd
import seaborn as sns

# OpenRouter credentials are read from the environment so the secret never
# lives in source control.
# NOTE(review): the key previously hard-coded here was committed in plain text
# and should be revoked on OpenRouter.
# An empty key is passed through (rather than None) so client construction
# below does not fail at import time; requests will instead fail and be
# reported by the caller.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = os.environ.get(
    "OPENROUTER_MODEL", "sophosympatheia/rogue-rose-103b-v0.2:free"
)

# Hugging Face Space path
DB_PATH = "ecommerce.db"
# Placeholder: replace with the actual dataset link.
DATASET_URL = "https://your-dataset-link.com/ecommerce.db"

# Ensure dataset exists. urlretrieve opens the URL before creating DB_PATH, so
# an unreachable URL or HTTP error does not leave an empty, table-less file
# behind (as `wget -O` did).
if not os.path.exists(DB_PATH):
    try:
        urllib.request.urlretrieve(DATASET_URL, DB_PATH)
    except OSError as exc:  # URLError / HTTPError are OSError subclasses
        logging.getLogger(__name__).warning(
            "Could not download dataset from %s: %s", DATASET_URL, exc
        )

# Initialize OpenAI client (pointed at OpenRouter's OpenAI-compatible endpoint)
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")

# Function: Fetch database schema
def fetch_schema(db_path: str) -> str:
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    schema = ""
    for table in tables:
        table_name = table[0]
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()
        schema += f"Table: {table_name}\n"
        for column in columns:
            schema += f"  Column: {column[1]}, Type: {column[2]}\n"
    conn.close()
    return schema

# Function: Extract SQL query from LLM response
def extract_sql_query(response: str) -> str:
    # Use regex to find content between ```sql and ```
    match = re.search(r"```sql(.*?)```", response, re.DOTALL)
    if match:
        return match.group(1).strip()  # Extract and return the SQL query
    return response  # Fallback: return the entire response if no SQL block is found

# Function: Convert text to SQL
def text_to_sql(query: str, schema: str) -> str:
    """Ask the configured OpenRouter model to translate *query* into SQL.

    The prompt names the SQLite dialect explicitly because the generated SQL is
    executed against a SQLite database.

    Args:
        query: The user's natural-language question.
        schema: Text description of the database tables and columns.

    Returns:
        The extracted SQL statement. On failure it returns a string starting
        with ``"Error: "``: when the API call fails, the reply has an unexpected
        shape, or the model returns no content.
    """
    prompt = (
        "You are an SQL expert. Given the following SQLite database schema:\n\n"
        f"{schema}\n\n"
        "Convert the following query into a single SQLite-compatible SQL query. "
        "Return only the SQL.\n\n"
        f"Query: {query}\n"
        "SQL:"
    )
    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQL expert."},
                {"role": "user", "content": prompt},
            ],
        )
        content = response.choices[0].message.content
    except Exception as e:  # network/auth/rate-limit/malformed reply -> shown in the UI
        return f"Error: {e}"
    # The message content may be None (e.g. a filtered or empty completion).
    if not content or not content.strip():
        return "Error: the model returned an empty response"
    return extract_sql_query(content.strip())  # Extract SQL query from the response


# MySQL-style date-part functions -> strftime format extracting the same field.
_DATE_PART_FORMATS = {"YEAR": "%Y", "MONTH": "%m", "DAY": "%d"}
# Matches e.g. YEAR(order_date) or month( o.created_at ). Only bare, optionally
# table-qualified, column references are rewritten.
_DATE_PART_RE = re.compile(r"\b(YEAR|MONTH|DAY)\s*\(\s*([\w.]+)\s*\)", re.IGNORECASE)


def preprocess_sql_for_sqlite(sql_query: str) -> str:
    """
    Replace non-SQLite functions with SQLite-compatible equivalents.

    ``YEAR(x)``, ``MONTH(x)`` and ``DAY(x)`` (any letter case) become
    ``CAST(strftime('<fmt>', x) AS INTEGER)``. The CAST is essential because
    strftime() returns text, and SQLite does not convert that text when it is
    compared with an integer literal. Without the CAST, ``WHERE YEAR(d) = 2018``
    would match no rows. The integer result also matches what MySQL's
    functions return.
    """
    def _to_strftime(match: "re.Match[str]") -> str:
        fmt = _DATE_PART_FORMATS[match.group(1).upper()]
        return f"CAST(strftime('{fmt}', {match.group(2)}) AS INTEGER)"

    return _DATE_PART_RE.sub(_to_strftime, sql_query)

def execute_sql(sql_query: str) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Run *sql_query* against ``DB_PATH`` and return the result set.

    The query first goes through ``preprocess_sql_for_sqlite``. The SQL comes
    from a language model driven by user input, so the connection is put in
    ``PRAGMA query_only`` mode. Statements that would modify the database
    (DROP, DELETE, UPDATE, ...) then fail instead of being applied.

    Returns:
        ``(DataFrame, None)`` on success, or
        ``(None, "SQL Execution Error: ...")`` on any failure.
    """
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        conn.execute("PRAGMA query_only = ON")  # refuse writes from generated SQL
        sql_query = preprocess_sql_for_sqlite(sql_query)  # Convert to SQLite-compatible SQL
        df = pd.read_sql_query(sql_query, conn)
        return df, None
    except Exception as e:  # any DB/pandas failure is reported to the UI as text
        return None, f"SQL Execution Error: {e}"
    finally:
        # Always release the connection; previously it leaked when the query failed.
        if conn is not None:
            conn.close()


# Function: Generate Dynamic Visualization
def visualize_data(df: pd.DataFrame) -> Optional[str]:
    """Render a chart for a query result and save it as ``chart.png``.

    Chart selection:
      * The first column is non-numeric (a label column):
        - a pie chart of the first numeric column, when there are fewer than
          10 rows and the values are non-negative with a positive total;
        - otherwise a bar chart with one bar per row.
      * Otherwise, with exactly one numeric column: a histogram of it.
      * Otherwise (two or more numeric columns): a scatter plot of the first two.

    Args:
        df: Query result to plot.

    Returns:
        ``"chart.png"`` after saving the figure. Returns ``None`` when the frame
        is empty, has fewer than two columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None

    # Detect numeric columns before creating a figure, so early returns
    # don't leave an open figure behind.
    numeric_cols = df.select_dtypes(include=["number"]).columns
    if len(numeric_cols) < 1:
        return None

    sns.set_theme(style="darkgrid")
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        label_col = df.columns[0]
        value_col = numeric_cols[0]
        if label_col not in numeric_cols:
            # Category/label column followed by values (typical GROUP BY result).
            labels = df[label_col].astype(str)
            values = df[value_col]
            # Pie wedges must be non-negative and need a positive total.
            if len(df) < 10 and (values >= 0).all() and values.sum() > 0:
                ax.pie(values, labels=labels, autopct="%1.1f%%",
                       colors=sns.color_palette("pastel", len(df)))
                ax.set_title(f"Proportion of {value_col}")
            else:
                # One bar per row by position, so duplicate labels are not merged.
                positions = list(range(len(df)))
                ax.bar(positions, values, color=sns.color_palette("coolwarm", len(df)))
                ax.set_xticks(positions)
                ax.set_xticklabels(labels, rotation=45, ha="right")
                ax.set_title(f"{label_col} vs {value_col}")
        elif len(numeric_cols) == 1:
            sns.histplot(df[value_col], bins=10, kde=True, color="teal", ax=ax)
            ax.set_title(f"Distribution of {value_col}")
        else:
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue", ax=ax)
            ax.set_title(f"{numeric_cols[0]} vs {numeric_cols[1]}")

        fig.tight_layout()
        fig.savefig("chart.png")
    finally:
        # Close the figure so repeated requests don't accumulate open figures.
        plt.close(fig)
    return "chart.png"

# Gradio UI
def gradio_ui(query: str) -> Tuple[str, str, Optional[str]]:
    """Gradio callback: natural-language query -> (SQL, result text, chart path).

    Args:
        query: Text typed by the user.

    Returns:
        A tuple ``(sql_query, results_or_error_message, chart_path_or_None)``.
        On failure, the middle element carries the error message and the chart
        is ``None``.
    """
    if not query or not query.strip():
        return "", "Please enter a query.", None
    try:
        schema = fetch_schema(DB_PATH)
    except sqlite3.Error as e:
        return "", f"Schema Error: {e}", None
    sql_query = text_to_sql(query, schema)
    # text_to_sql reports failures as "Error: ..." text; never execute that as SQL.
    if sql_query.startswith("Error:"):
        return sql_query, sql_query, None
    df, error = execute_sql(sql_query)
    if error:
        return sql_query, error, None
    try:
        visualization = visualize_data(df)
    except Exception:
        # Charting is best-effort: still show the query results if plotting fails.
        logging.getLogger(__name__).exception("Visualization failed")
        visualization = None
    return sql_query, df.to_string(index=False), visualization

# Build the Gradio App
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
    query_input = gr.Textbox(label="Enter your query", placeholder="e.g., Show all products sold in 2018.")
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    outputs = [sql_output, table_output, chart_output]
    submit_btn.click(gradio_ui, inputs=[query_input], outputs=outputs)
    # Pressing Enter in the query box runs the same pipeline as the button.
    query_input.submit(gradio_ui, inputs=[query_input], outputs=outputs)

# Start the server only when run as a script, so importing this module
# (e.g. from tests) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()