"""Run the Pagila natural-language-to-SQL eval set against Gemini.

For each test case: generate a SQL query, validate it with EXPLAIN against a
live PostgreSQL database, and persist results to JSON after every query so an
interrupted run can resume.
"""
import json
import os
import sys
import time
from pathlib import Path

import pandas as pd
import psycopg2
from dotenv import load_dotenv
from tqdm import tqdm

sys.path.append(str(Path(__file__).parent.parent))
from main import get_gemini_response, sql_prompt
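# main.py (one directory up) provides get_gemini_response(question, prompt)
# and the shared sql_prompt used for text-to-SQL generation.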

def validate_sql_query(query, conn):
    """Validate if the SQL query is syntactically correct"""
    try:
        with conn.cursor() as cursor:
            # A previously failed statement leaves the transaction aborted;
            # psycopg2 rejects further commands until we roll back
            conn.rollback()
            
            # EXPLAIN plans the query without running it, so invalid SQL
            # fails here with a descriptive error
            cursor.execute("EXPLAIN " + query)
            conn.commit()  # End the transaction cleanly
            return True, None
    except psycopg2.Error as e:
        # Rollback on error
        conn.rollback()
        return False, str(e)
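
# Example: validate_sql_query("SELECT * FROM film LIMIT 1", conn) returns
# (True, None) against Pagila, while malformed SQL such as "SELEC 1" returns
# (False, "<syntax error message>").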

def handle_api_error(error):
    """Map an API error to a (message, seconds_to_wait) pair for the retry loop."""
    if "429" in str(error):  # HTTP 429: rate limit / quota exhausted
        return "API quota exceeded", 30  # Back off before retrying
    return str(error), 0  # Anything else: report it and retry immediately
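
# If the fixed 30 s backoff proves too coarse, an exponential variant is a
# small change (hypothetical helper, sketched here only as a comment):
#   def backoff_seconds(attempt, base=5, cap=120):
#       return min(cap, base * 2 ** attempt)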

def run_query_tests():
    load_dotenv()
    
    # Database connection; expects DB_NAME, DB_USER, DB_PASSWORD and DB_HOST
    # in the environment (or .env), with DB_PORT defaulting to 5432
    conn = psycopg2.connect(
        dbname=os.getenv('DB_NAME'),
        user=os.getenv('DB_USER'),
        password=os.getenv('DB_PASSWORD'),
        host=os.getenv('DB_HOST'),
        port=os.getenv('DB_PORT', '5432')
    )
    
    # Read the test dataset; the spreadsheet export is not UTF-8
    csv_path = Path(__file__).parent / 'Pagila Evals Dataset(Sheet1).csv'
    try:
        test_data = pd.read_csv(csv_path, encoding='cp1252')
    except UnicodeDecodeError:
        # Fall back to latin-1, which accepts any byte sequence
        test_data = pd.read_csv(csv_path, encoding='latin-1')
    
    # Normalize curly "smart" quotes from the spreadsheet to straight quotes
    test_data['Natural Language Query'] = (
        test_data['Natural Language Query']
        .str.replace('\u201c', '"', regex=False)  # left double quote
        .str.replace('\u201d', '"', regex=False)  # right double quote
    )
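    # (Curly single quotes, \u2018 and \u2019, could be normalized the same
    # way if they appear in the sheet; only double quotes are handled here.)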
    
    results_dir = Path(__file__).parent / 'results'
    results_dir.mkdir(exist_ok=True)
    
    # Load existing results, keyed by query number, so interrupted runs resume
    output_file = results_dir / 'query_results.json'
    if output_file.exists():
        with open(output_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
    else:
        results = {}
    
    # Process queries with progress bar
    for _, row in tqdm(test_data.iterrows(), total=len(test_data), desc="Processing queries"):
        query_num = str(row['Query Number'])
        
        # Skip queries that already produced valid SQL; earlier failures
        # are retried on this run
        if query_num in results and results[query_num]['sql_query'] and results[query_num]['is_valid']:
            continue
            
        nl_query = row['Natural Language Query']
        difficulty = row['Difficulty']
        
        max_retries = 3
        retry_count = 0
        
        while retry_count < max_retries:
            try:
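                # Ask Gemini for SQL, then strip the markdown ```sql fences
                # the model usually wraps around its answer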
                sql_query = get_gemini_response(nl_query, sql_prompt)
                sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
                
                is_valid, error_msg = validate_sql_query(sql_query, conn)
                
                results[query_num] = {
                    'natural_language_query': nl_query,
                    'sql_query': sql_query,
                    'difficulty': difficulty,
                    'is_valid': is_valid,
                    'error': error_msg
                }
                
                # Save progress after each generated query, valid or not
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)
                
                break  # Success, exit retry loop
                
            except Exception as e:
                error_msg, wait_time = handle_api_error(e)
                retry_count += 1
                
                if wait_time > 0:
                    print(f"\nAPI quota exceeded. Waiting {wait_time} seconds...")
                    time.sleep(wait_time)
                
                if retry_count == max_retries:
                    results[query_num] = {
                        'natural_language_query': nl_query,
                        'sql_query': None,
                        'difficulty': difficulty,
                        'is_valid': False,
                        'error': error_msg
                    }
                    
                    # Save progress even for failed queries
                    with open(output_file, 'w', encoding='utf-8') as f:
                        json.dump(results, f, indent=2, ensure_ascii=False)
    
    conn.close()
    print(f"\nResults saved to {output_file}")

if __name__ == "__main__":
    run_query_tests()