import os
import sqlite3
import random
import json
import openai
from datetime import timedelta
from faker import Faker
from langgraph.graph import StateGraph, START
from typing import TypedDict, Optional
import gradio as gr

openai.api_key = os.getenv("OPENAI_API_KEY")


def init_db():
    fake = Faker()
    conn = sqlite3.connect("complex_test_db.sqlite", timeout=20)
    cursor = conn.cursor()

    # Drop any existing tables so the script can be re-run from a clean slate.
    cursor.executescript("""
        DROP TABLE IF EXISTS order_items;
        DROP TABLE IF EXISTS orders;
        DROP TABLE IF EXISTS products;
        DROP TABLE IF EXISTS customers;
        DROP TABLE IF EXISTS payments;
        DROP TABLE IF EXISTS shipment;
    """)

    cursor.execute("""
        CREATE TABLE customers (
            customer_id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            email TEXT UNIQUE NOT NULL,
            phone TEXT NOT NULL,
            address TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
    """)

    cursor.execute("""
        CREATE TABLE products (
            product_id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            category TEXT NOT NULL,
            price DECIMAL(10,2) NOT NULL,
            stock_quantity INTEGER NOT NULL
        );
    """)

    cursor.execute("""
        CREATE TABLE orders (
            order_id INTEGER PRIMARY KEY AUTOINCREMENT,
            customer_id INTEGER NOT NULL,
            order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            total_amount DECIMAL(10,2) NOT NULL,
            status TEXT CHECK(status IN ('Pending', 'Shipped', 'Delivered', 'Cancelled')) NOT NULL,
            FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
        );
    """)

    cursor.execute("""
        CREATE TABLE order_items (
            order_item_id INTEGER PRIMARY KEY AUTOINCREMENT,
            order_id INTEGER NOT NULL,
            product_id INTEGER NOT NULL,
            quantity INTEGER NOT NULL,
            subtotal DECIMAL(10,2) NOT NULL,
            FOREIGN KEY (order_id) REFERENCES orders(order_id),
            FOREIGN KEY (product_id) REFERENCES products(product_id)
        );
    """)

    cursor.execute("""
        CREATE TABLE payments (
            payment_id INTEGER PRIMARY KEY AUTOINCREMENT,
            order_id INTEGER UNIQUE NOT NULL,
            payment_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            amount DECIMAL(10,2) NOT NULL,
            payment_method TEXT CHECK(payment_method IN ('Credit Card', 'Debit Card', 'PayPal', 'Bank Transfer')) NOT NULL,
            status TEXT CHECK(status IN ('Success', 'Failed', 'Pending')) NOT NULL,
            FOREIGN KEY (order_id) REFERENCES orders(order_id)
        );
    """)

    cursor.execute("""
        CREATE TABLE shipment (
            shipment_id INTEGER PRIMARY KEY AUTOINCREMENT,
            order_id INTEGER UNIQUE NOT NULL,
            shipment_date TIMESTAMP,
            delivery_date TIMESTAMP,
            carrier TEXT NOT NULL,
            tracking_number TEXT UNIQUE,
            status TEXT CHECK(status IN ('Processing', 'Shipped', 'Delivered')) NOT NULL,
            FOREIGN KEY (order_id) REFERENCES orders(order_id)
        );
    """)

    conn.commit()

    # Volume of synthetic data to generate.
    NUM_CUSTOMERS = 1000
    NUM_PRODUCTS = 500
    NUM_ORDERS = 2000
    NUM_ORDER_ITEMS = 5000
    NUM_PAYMENTS = NUM_ORDERS
    NUM_SHIPMENTS = int(NUM_ORDERS * 0.8)

    # Customers: keep generating until we have NUM_CUSTOMERS unique email addresses.
    customers = []
    unique_emails = set()
    while len(customers) < NUM_CUSTOMERS:
        name = fake.name()
        email = fake.email()
        phone = fake.phone_number()
        address = fake.address().replace("\n", ", ")
        if email not in unique_emails:
            customers.append((name, email, phone, address))
            unique_emails.add(email)

    cursor.executemany("""
        INSERT INTO customers (name, email, phone, address)
        VALUES (?, ?, ?, ?);
    """, customers)

    # Products: random names, categories, prices, and stock levels.
    products = []
    categories = ["Electronics", "Clothing", "Books", "Home Appliances", "Toys"]
    for _ in range(NUM_PRODUCTS):
        products.append((
            fake.word().capitalize(),
            random.choice(categories),
            round(random.uniform(5, 500), 2),
            random.randint(10, 500),
        ))

    cursor.executemany("""
        INSERT INTO products (name, category, price, stock_quantity)
        VALUES (?, ?, ?, ?);
    """, products)

    # Fetch the generated primary keys so foreign keys can reference real rows.
    cursor.execute("SELECT customer_id FROM customers;")
    customer_ids = [row[0] for row in cursor.fetchall()]

    cursor.execute("SELECT product_id FROM products;")
    product_ids = [row[0] for row in cursor.fetchall()]

    # Orders: each order belongs to a random customer with a random date in the past year.
    orders = []
    statuses = ["Pending", "Shipped", "Delivered", "Cancelled"]
    for _ in range(NUM_ORDERS):
        customer_id = random.choice(customer_ids)
        total_amount = round(random.uniform(20, 2000), 2)
        status = random.choice(statuses)
        order_date = fake.date_time_between(start_date="-1y", end_date="now")
        orders.append((customer_id, order_date, total_amount, status))

    cursor.executemany("""
        INSERT INTO orders (customer_id, order_date, total_amount, status)
        VALUES (?, ?, ?, ?);
    """, orders)

    cursor.execute("SELECT order_id FROM orders;")
    order_ids = [row[0] for row in cursor.fetchall()]

    # Order items: random products and quantities attached to random orders.
    order_items = []
    for _ in range(NUM_ORDER_ITEMS):
        order_id = random.choice(order_ids)
        product_id = random.choice(product_ids)
        quantity = random.randint(1, 5)
        subtotal = round(quantity * random.uniform(5, 500), 2)
        order_items.append((order_id, product_id, quantity, subtotal))

    cursor.executemany("""
        INSERT INTO order_items (order_id, product_id, quantity, subtotal)
        VALUES (?, ?, ?, ?);
    """, order_items)

    # Payments: one payment per order.
    payment_methods = ["Credit Card", "Debit Card", "PayPal", "Bank Transfer"]
    payment_statuses = ["Success", "Failed", "Pending"]
    payments = []
    for order_id in order_ids[:NUM_PAYMENTS]:
        amount = round(random.uniform(20, 2000), 2)
        payment_method = random.choice(payment_methods)
        status = random.choice(payment_statuses)
        payments.append((
            order_id,
            fake.date_time_between(start_date="-1y", end_date="now"),
            amount,
            payment_method,
            status,
        ))

    cursor.executemany("""
        INSERT INTO payments (order_id, payment_date, amount, payment_method, status)
        VALUES (?, ?, ?, ?, ?);
    """, payments)

    # Shipments: roughly 80% of orders get a shipment record.
    carriers = ["FedEx", "UPS", "DHL", "USPS"]
    shipments = []
    for order_id in order_ids[:NUM_SHIPMENTS]:
        shipment_date = fake.date_time_between(start_date="-1y", end_date="now")
        delivery_date = shipment_date + timedelta(days=random.randint(1, 10))
        tracking_number = fake.uuid4()
        carrier = random.choice(carriers)
        status = random.choice(["Processing", "Shipped", "Delivered"])
        shipments.append((order_id, shipment_date, delivery_date, carrier, tracking_number, status))

    cursor.executemany("""
        INSERT INTO shipment (order_id, shipment_date, delivery_date, carrier, tracking_number, status)
        VALUES (?, ?, ?, ?, ?, ?);
    """, shipments)

    conn.commit()
    cursor.close()
    conn.close()
    print("✅ Database setup complete with complex relationships and substantial data!")


init_db()
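
# Optional sanity check (an illustrative sketch, not part of the original pipeline):
# after seeding, it can be handy to confirm each table received rows. The table
# names below match the schema created above; the helper itself is hypothetical.
def _print_row_counts(db_path="complex_test_db.sqlite"):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    for table in ["customers", "products", "orders", "order_items", "payments", "shipment"]:
        cursor.execute(f"SELECT COUNT(*) FROM {table};")
        print(f"{table}: {cursor.fetchone()[0]} rows")
    cursor.close()
    conn.close()

# _print_row_counts()  # uncomment to verify the generated data volumes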


# Shared state passed between the LangGraph agents.
class SQLExecutionState(TypedDict):
    sql_query: str
    structured_metadata: Optional[dict]
    validation_result: Optional[dict]
    optimized_sql: Optional[str]
    execution_result: Optional[dict]


graph = StateGraph(state_schema=SQLExecutionState)


def query_understanding_agent(state: SQLExecutionState) -> SQLExecutionState:
    natural_language_query = state["sql_query"]
    prompt = f"""
    Convert the following natural language query into **structured SQL metadata** based on the database schema.
    If you cannot generate a query that adheres strictly to the schema, return:
    {{ "error": "Invalid query: Tables or columns do not match schema" }}

    **Query:** "{natural_language_query}"

    **Database Schema:**
    - **orders** (order_id, customer_id, order_date, total_amount, status)
    - **order_items** (order_item_id, order_id, product_id, quantity, subtotal)
    - **products** (product_id, name, category, price, stock_quantity)
    - **customers** (customer_id, name, email, phone, address, created_at)
    - **payments** (payment_id, order_id, payment_date, amount, payment_method, status)

    **Rules:**
    - Use only the provided tables.
    - Ensure correct column names.
    - Return output strictly in JSON format.
    - Group by relevant fields when necessary.

    **Example Output Format:**
    {json.dumps({
        "operation": "SELECT",
        "columns": ["customer_id", "SUM(total_amount) AS total_spent"],
        "table": "orders",
        "conditions": ["order_date BETWEEN '2024-01-01' AND '2024-12-31'"],
        "group_by": ["customer_id"],
        "order_by": ["total_spent DESC"],
        "limit": 5
    }, indent=4)}

    **DO NOT return explanations. Only return valid JSON.**
    """
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}]
    )
    try:
        metadata = json.loads(response["choices"][0]["message"]["content"])
        return {"structured_metadata": metadata}
    except json.JSONDecodeError:
        return {"structured_metadata": {"error": "Invalid JSON response from OpenAI"}}


graph.add_node("Query Understanding", query_understanding_agent)


def query_validation_agent(state: SQLExecutionState) -> SQLExecutionState:
    sql_metadata = state.get("structured_metadata", {})
    if "error" in sql_metadata:
        return {"validation_result": {"error": sql_metadata["error"]}}
    query = sql_metadata.get("operation", "")
    restricted_keywords = ["DROP", "DELETE", "TRUNCATE", "ALTER"]
    if any(keyword in query.upper() for keyword in restricted_keywords):
        return {"validation_result": {"error": "Potentially harmful SQL operation detected!"}}
    return {"validation_result": {"valid": True}}


graph.add_node("Query Validation", query_validation_agent)


def query_optimization_agent(state: SQLExecutionState) -> SQLExecutionState:
    sql_metadata = state.get("structured_metadata", {})
    prompt = f"""
    Optimize the following SQL query for performance while ensuring that the output includes only the required columns and necessary joins.
    Do not include any extra columns, unnecessary joins, or records that are not required to answer the query.

    Here is the original SQL metadata:
    {json.dumps(sql_metadata, indent=4)}

    Output only the final optimized SQL query in plain text without any markdown formatting or explanations.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )
    optimized_query = response["choices"][0]["message"]["content"].strip()
    # Strip Markdown code fences in case the model wraps the SQL anyway.
    if optimized_query.startswith("```sql"):
        optimized_query = optimized_query.replace("```sql", "").replace("```", "").strip()
    return {"optimized_sql": optimized_query}


graph.add_node("Query Optimization", query_optimization_agent)


def execution_agent(state: SQLExecutionState) -> SQLExecutionState:
    query = state.get("optimized_sql", "").strip()
    if not query:
        return {"execution_result": {"error": "No SQL query to execute."}}
    try:
        conn = sqlite3.connect("complex_test_db.sqlite", timeout=20)
        cursor = conn.cursor()
        cursor.execute(query)
        result = cursor.fetchall()
        cursor.close()
        conn.close()
        if not result:
            return {"execution_result": {"error": "Query executed successfully but returned no results."}}
        return {"execution_result": result}
    except sqlite3.Error as e:
        return {"execution_result": {"error": str(e)}}


graph.add_node("SQL Execution", execution_agent)


# Wire the agents into a linear pipeline: understanding -> validation -> optimization -> execution.
graph.add_edge(START, "Query Understanding")
graph.add_edge("Query Understanding", "Query Validation")
graph.add_edge("Query Validation", "Query Optimization")
graph.add_edge("Query Optimization", "SQL Execution")

compiled_pipeline = graph.compile()


def run_multi_agent_query(natural_language_query):
    result = compiled_pipeline.invoke({"sql_query": natural_language_query})
    return json.dumps(result.get("execution_result", {}), indent=2)
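

# Illustrative usage sketch (assumes OPENAI_API_KEY is set and the database above
# has been created): the compiled LangGraph pipeline can also be invoked directly,
# without the Gradio UI, to inspect the intermediate state. The example question
# below is arbitrary.
#
# state = compiled_pipeline.invoke(
#     {"sql_query": "Find the top 5 customers by total spend in 2024"}
# )
# print(state.get("optimized_sql"))      # SQL produced by the optimization agent
# print(state.get("execution_result"))   # rows (or an error dict) from SQLite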


schema_description = """
**Database Schema:**

- **customers**: customer_id, name, email, phone, address, created_at
- **products**: product_id, name, category, price, stock_quantity
- **orders**: order_id, customer_id, order_date, total_amount, status
- **order_items**: order_item_id, order_id, product_id, quantity, subtotal
- **payments**: payment_id, order_id, payment_date, amount, payment_method, status
- **shipment**: shipment_id, order_id, shipment_date, delivery_date, carrier, tracking_number, status
"""

iface = gr.Interface(
    fn=run_multi_agent_query,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Enter your natural language query here (e.g., Which products sold the most in 2024?)."
    ),
    outputs="text",
    title="Multi-Agent SQL Generator",
    description=(
        "Enter a natural language query to generate and execute a SQL query, e.g. "
        "Find the email of the top 5 customers who spent the most in 2024.\n\n"
        + schema_description
    )
)


if __name__ == "__main__":
    iface.launch()