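"""Evaluate Gemini-generated SQL on the Pagila test set.

Reads natural-language queries from a CSV, asks Gemini for a SQL
translation, validates each result with EXPLAIN against a Postgres
database, and saves progress incrementally to results/query_results.json.
"""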
import sys
import os
import json
import time
from pathlib import Path

import pandas as pd
import psycopg2
from dotenv import load_dotenv
from tqdm import tqdm

sys.path.append(str(Path(__file__).parent.parent))
from main import get_gemini_response, sql_prompt


def validate_sql_query(query, conn):
    """Validate if the SQL query is syntactically correct"""
    try:
        with conn.cursor() as cursor:
            # Reset any aborted transaction
            conn.rollback()
            # Now try to validate the query
            cursor.execute("EXPLAIN " + query)
            conn.commit()  # Commit the successful EXPLAIN
            return True, None
    except psycopg2.Error as e:
        # Rollback on error
        conn.rollback()
        return False, str(e)
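
# A minimal usage sketch (the query string below is illustrative; assumes
# an open psycopg2 connection to the target database):
#   ok, err = validate_sql_query("SELECT title FROM film LIMIT 1", conn)
#   if not ok:
#       print(f"Rejected by EXPLAIN: {err}")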


def handle_api_error(error):
    """Handle different types of API errors"""
    if "429" in str(error):
        return "API quota exceeded", 30  # Wait 30 seconds
    return str(error), 0
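
# Note: only a 429 (rate-limit/quota) error triggers a backoff; every other
# error is returned as-is with a zero wait time, so the retry loop below
# re-calls the API immediately.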


def run_query_tests():
    load_dotenv()
    # Database connection
    conn = psycopg2.connect(
        dbname=os.getenv('DB_NAME'),
        user=os.getenv('DB_USER'),
        password=os.getenv('DB_PASSWORD'),
        host=os.getenv('DB_HOST'),
        port=os.getenv('DB_PORT', '5432')
    )
    # Read the test dataset with encoding specification
    csv_path = Path(__file__).parent / 'Pagila Evals Dataset(Sheet1).csv'
    try:
        test_data = pd.read_csv(csv_path, encoding='cp1252')
    except UnicodeDecodeError:
        # Fallback to latin-1 if cp1252 fails
        test_data = pd.read_csv(csv_path, encoding='latin-1')
    # Clean up any special (curly) quotes in the queries by normalizing
    # them to straight double quotes
    test_data['Natural Language Query'] = (
        test_data['Natural Language Query']
        .str.replace('\u201c', '"')  # left curly quote
        .str.replace('\u201d', '"')  # right curly quote
    )
    results_dir = Path(__file__).parent / 'results'
    results_dir.mkdir(exist_ok=True)
    # Load existing results if any
    output_file = results_dir / 'query_results.json'
    if output_file.exists():
        with open(output_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
    else:
        results = {}
    # Process queries with progress bar
    for _, row in tqdm(test_data.iterrows(), total=len(test_data), desc="Processing queries"):
        query_num = str(row['Query Number'])
        # Skip if already processed successfully
        if query_num in results and results[query_num]['sql_query'] and results[query_num]['is_valid']:
            continue
        nl_query = row['Natural Language Query']
        difficulty = row['Difficulty']
        max_retries = 3
        retry_count = 0
        while retry_count < max_retries:
            try:
                sql_query = get_gemini_response(nl_query, sql_prompt)
                sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
                is_valid, error_msg = validate_sql_query(sql_query, conn)
                results[query_num] = {
                    'natural_language_query': nl_query,
                    'sql_query': sql_query,
                    'difficulty': difficulty,
                    'is_valid': is_valid,
                    'error': error_msg
                }
                # Save progress after each successful query
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)
                break  # Success, exit retry loop
            except Exception as e:
                error_msg, wait_time = handle_api_error(e)
                retry_count += 1
                if wait_time > 0:
                    print(f"\nAPI quota exceeded. Waiting {wait_time} seconds...")
                    time.sleep(wait_time)
                if retry_count == max_retries:
                    results[query_num] = {
                        'natural_language_query': nl_query,
                        'sql_query': None,
                        'difficulty': difficulty,
                        'is_valid': False,
                        'error': error_msg
                    }
                    # Save progress even for failed queries
                    with open(output_file, 'w', encoding='utf-8') as f:
                        json.dump(results, f, indent=2, ensure_ascii=False)
    conn.close()
    print(f"\nResults saved to {output_file}")


if __name__ == "__main__":
    run_query_tests()
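
# A sketch of the .env file this script expects (values are placeholders;
# only DB_PORT has a default, falling back to 5432):
#   DB_NAME=pagila
#   DB_USER=postgres
#   DB_PASSWORD=your_password
#   DB_HOST=localhost
#   DB_PORT=5432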