import json
import os
import sys
import time
from pathlib import Path

import pandas as pd
import psycopg2
from dotenv import load_dotenv
from tqdm import tqdm

# Make the project root importable so we can reuse the prompt and model helpers.
sys.path.append(str(Path(__file__).parent.parent))
from main import get_gemini_response, sql_prompt


def validate_sql_query(query, conn):
    """Validate that the SQL query is syntactically correct by running EXPLAIN."""
    try:
        with conn.cursor() as cursor:
            # Reset any aborted transaction before validating.
            conn.rollback()
            cursor.execute("EXPLAIN " + query)
            conn.commit()  # Commit the successful EXPLAIN
        return True, None
    except psycopg2.Error as e:
        # Roll back so the connection stays usable for the next query.
        conn.rollback()
        return False, str(e)


def handle_api_error(error):
    """Map an API error to a message and a back-off delay in seconds."""
    if "429" in str(error):
        return "API quota exceeded", 30  # Wait 30 seconds before retrying
    return str(error), 0


def run_query_tests():
    load_dotenv()

    # Database connection
    conn = psycopg2.connect(
        dbname=os.getenv('DB_NAME'),
        user=os.getenv('DB_USER'),
        password=os.getenv('DB_PASSWORD'),
        host=os.getenv('DB_HOST'),
        port=os.getenv('DB_PORT', '5432'),
    )

    # Read the test dataset, falling back to latin-1 if cp1252 fails.
    csv_path = Path(__file__).parent / 'Pagila Evals Dataset(Sheet1).csv'
    try:
        test_data = pd.read_csv(csv_path, encoding='cp1252')
    except UnicodeDecodeError:
        test_data = pd.read_csv(csv_path, encoding='latin-1')

    # Replace curly quotes in the queries with straight quotes.
    test_data['Natural Language Query'] = (
        test_data['Natural Language Query']
        .str.replace('\u201c', '"')
        .str.replace('\u201d', '"')
    )

    results_dir = Path(__file__).parent / 'results'
    results_dir.mkdir(exist_ok=True)

    # Load existing results, if any, so reruns resume where they left off.
    output_file = results_dir / 'query_results.json'
    if output_file.exists():
        with open(output_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
    else:
        results = {}

    try:
        # Process queries with a progress bar.
        for _, row in tqdm(test_data.iterrows(), total=len(test_data), desc="Processing queries"):
            query_num = str(row['Query Number'])

            # Skip queries that already produced a valid SQL result.
            if query_num in results and results[query_num]['sql_query'] and results[query_num]['is_valid']:
                continue

            nl_query = row['Natural Language Query']
            difficulty = row['Difficulty']

            max_retries = 3
            retry_count = 0
            while retry_count < max_retries:
                try:
                    sql_query = get_gemini_response(nl_query, sql_prompt)
                    # Strip markdown code fences from the model output.
                    sql_query = sql_query.replace('```sql', '').replace('```', '').strip()

                    is_valid, error_msg = validate_sql_query(sql_query, conn)
                    results[query_num] = {
                        'natural_language_query': nl_query,
                        'sql_query': sql_query,
                        'difficulty': difficulty,
                        'is_valid': is_valid,
                        'error': error_msg,
                    }

                    # Save progress after each processed query.
                    with open(output_file, 'w', encoding='utf-8') as f:
                        json.dump(results, f, indent=2, ensure_ascii=False)
                    break  # Success, exit retry loop
                except Exception as e:
                    error_msg, wait_time = handle_api_error(e)
                    retry_count += 1
                    if wait_time > 0:
                        print(f"\nAPI quota exceeded. Waiting {wait_time} seconds...")
                        time.sleep(wait_time)

                    if retry_count == max_retries:
                        results[query_num] = {
                            'natural_language_query': nl_query,
                            'sql_query': None,
                            'difficulty': difficulty,
                            'is_valid': False,
                            'error': error_msg,
                        }
                        # Save progress even for failed queries.
                        with open(output_file, 'w', encoding='utf-8') as f:
                            json.dump(results, f, indent=2, ensure_ascii=False)
    finally:
        # Close the connection even if processing raised.
        conn.close()

    print(f"\nResults saved to {output_file}")


if __name__ == "__main__":
    run_query_tests()