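"""Batch-evaluate Gemini-generated SQL for the Pagila test set.

For each natural-language query in the CSV dataset, ask the model for a SQL
translation, validate it against PostgreSQL with EXPLAIN, and save incremental
results to results/query_results.json so interrupted runs can resume.
Database credentials are read from the environment (DB_NAME, DB_USER,
DB_PASSWORD, DB_HOST, optional DB_PORT) via python-dotenv.
"""
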
import sys
import os
import pandas as pd
import json
from pathlib import Path
import psycopg2
from dotenv import load_dotenv
import time
from tqdm import tqdm

# Make the project root importable so the script can reuse main.py's helpers
sys.path.append(str(Path(__file__).parent.parent))
from main import get_gemini_response, sql_prompt

def validate_sql_query(query, conn):
    """Validate if the SQL query is syntactically correct"""
    try:
        with conn.cursor() as cursor:
            # Reset any aborted transaction
            conn.rollback()
            
            # Now try to validate the query
            cursor.execute("EXPLAIN " + query)
            conn.commit()  # Commit the successful EXPLAIN
            return True, None
    except psycopg2.Error as e:
        # Rollback on error
        conn.rollback()
        return False, str(e)

def handle_api_error(error):
    """Map an API error to (message, seconds to wait before retrying)."""
    if "429" in str(error):
        return "API quota exceeded", 30  # Rate-limited: back off for 30 seconds
    return str(error), 0

def run_query_tests():
    load_dotenv()
    
    # Database connection
    conn = psycopg2.connect(
        dbname=os.getenv('DB_NAME'),
        user=os.getenv('DB_USER'),
        password=os.getenv('DB_PASSWORD'),
        host=os.getenv('DB_HOST'),
        port=os.getenv('DB_PORT', '5432')
    )
    
    # Read the test dataset with encoding specification
    csv_path = Path(__file__).parent / 'Pagila Evals Dataset(Sheet1).csv'
    try:
        test_data = pd.read_csv(csv_path, encoding='cp1252')
    except UnicodeDecodeError:
        # Fallback to latin-1 if cp1252 fails
        test_data = pd.read_csv(csv_path, encoding='latin-1')
    
    # Normalize curly quotes (from spreadsheet exports) to straight ASCII quotes
    test_data['Natural Language Query'] = test_data['Natural Language Query'].str.replace('\u201c', '"').str.replace('\u201d', '"')
    
    results_dir = Path(__file__).parent / 'results'
    results_dir.mkdir(exist_ok=True)
    
    # Load existing results if any
    output_file = results_dir / 'query_results.json'
    if output_file.exists():
        with open(output_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
    else:
        results = {}
    
    # Process queries with progress bar
    for _, row in tqdm(test_data.iterrows(), total=len(test_data), desc="Processing queries"):
        query_num = str(row['Query Number'])
        
        # Skip if already processed successfully
        if query_num in results and results[query_num]['sql_query'] and results[query_num]['is_valid']:
            continue
            
        nl_query = row['Natural Language Query']
        difficulty = row['Difficulty']
        
        max_retries = 3
        retry_count = 0
        
        while retry_count < max_retries:
            try:
                sql_query = get_gemini_response(nl_query, sql_prompt)
                sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
                
                is_valid, error_msg = validate_sql_query(sql_query, conn)
                
                results[query_num] = {
                    'natural_language_query': nl_query,
                    'sql_query': sql_query,
                    'difficulty': difficulty,
                    'is_valid': is_valid,
                    'error': error_msg
                }
                
                # Save progress after each successful query
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)
                
                break  # Success, exit retry loop
                
            except Exception as e:
                error_msg, wait_time = handle_api_error(e)
                retry_count += 1
                
                if wait_time > 0:
                    print(f"\nAPI quota exceeded. Waiting {wait_time} seconds...")
                    time.sleep(wait_time)
                
                if retry_count == max_retries:
                    results[query_num] = {
                        'natural_language_query': nl_query,
                        'sql_query': None,
                        'difficulty': difficulty,
                        'is_valid': False,
                        'error': error_msg
                    }
                    
                    # Save progress even for failed queries
                    with open(output_file, 'w', encoding='utf-8') as f:
                        json.dump(results, f, indent=2, ensure_ascii=False)
    
    conn.close()
    print(f"\nResults saved to {output_file}")

if __name__ == "__main__":
    run_query_tests()