import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from functools import lru_cache
import json
import mysql.connector
from mysql.connector import Error
import os
import sys
from datetime import datetime
import time
import logging
import threading

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Enable GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Database configuration.
# Every connection setting can be overridden with an environment variable
# (MYSQL_HOST, MYSQL_DATABASE, MYSQL_USER, MYSQL_PASSWORD, MYSQL_PORT) so that
# credentials do not need to live in source control.
# SECURITY(review): the fallback password below is committed in plain text;
# rotate it and drop the fallback once deployments set MYSQL_PASSWORD.
# 'pool_size' / 'pool_reset_session' make mysql.connector.connect() hand out
# pooled connections.
DB_CONFIG = {
    'host': os.environ.get('MYSQL_HOST', 'sql12.freemysqlhosting.net'),
    'database': os.environ.get('MYSQL_DATABASE', 'sql12740625'),
    'user': os.environ.get('MYSQL_USER', 'sql12740625'),
    'password': os.environ.get('MYSQL_PASSWORD', 'QGG9kdrE4g'),
    'port': int(os.environ.get('MYSQL_PORT', '3306')),
    'pool_size': 5,
    'pool_reset_session': True,
}

# Global variables for model and tokenizer (populated by initialize_model()).
GLOBAL_MODEL = None
GLOBAL_TOKENIZER = None
# Last outcome of test_db_connection(): True when the DB answered.
db_connection_status = False

@st.cache_resource
def _load_model_and_tokenizer(model_name):
    """Load the tokenizer and eval-mode model for *model_name* once per process.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the already
    loaded objects instead of reading the model from disk again.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # Use float32 for CPU
    ).to(device)
    # Set model to evaluation mode (disables dropout etc.).
    model.eval()
    return model, tokenizer


def initialize_model():
    """Initialize model and tokenizer globally.

    Populates ``GLOBAL_MODEL`` and ``GLOBAL_TOKENIZER`` from a process-wide
    cache, so only the first call actually loads the weights; later calls
    (e.g. from every Streamlit rerun of ``main``) return almost instantly.
    """
    global GLOBAL_MODEL, GLOBAL_TOKENIZER
    logging.info("Initializing model and tokenizer...")
    st.write("Initializing model and tokenizer...")
    start_time = time.time()

    model_name_sql = "premai-io/prem-1B-SQL"
    GLOBAL_MODEL, GLOBAL_TOKENIZER = _load_model_and_tokenizer(model_name_sql)

    logging.info(f"Model initialization took {time.time() - start_time:.2f} seconds")

def test_db_connection():
    """Check that the MySQL server described by ``DB_CONFIG`` is reachable.

    Opens a connection (10 s connect timeout), reads the server version and
    the current database name, and records the outcome in the module-level
    ``db_connection_status`` flag.

    Returns:
        tuple[bool, str]: ``(True, details)`` when the connection works,
        otherwise ``(False, reason)``.
    """
    global db_connection_status
    connection = None
    try:
        logging.info("Testing database connection...")
        connection = mysql.connector.connect(
            **DB_CONFIG,
            connect_timeout=10
        )
        if not connection.is_connected():
            # Keep the status flag in sync on this path too.
            db_connection_status = False
            return False, "Unable to establish database connection"

        db_info = connection.get_server_info()
        cursor = connection.cursor()
        try:
            cursor.execute("SELECT DATABASE();")
            db_name = cursor.fetchone()[0]
        finally:
            cursor.close()

        db_connection_status = True
        logging.info(f"Successfully connected to MySQL Server version {db_info} - Database: {db_name}")
        return True, f"Successfully connected to MySQL Server version {db_info}\nDatabase: {db_name}"
    except Error as e:
        db_connection_status = False
        logging.error(f"Error connecting to MySQL database: {e}")
        return False, f"Error connecting to MySQL database: {e}"
    finally:
        # Release the connection on every path, including failures that
        # happen after connect() succeeded, so pooled slots are not leaked.
        if connection is not None:
            try:
                if connection.is_connected():
                    connection.close()
            except Error as close_error:
                logging.warning(f"Error closing MySQL connection: {close_error}")

def get_db_connection():
    """Return a MySQL connection configured from ``DB_CONFIG``.

    ``DB_CONFIG`` carries ``pool_size``, so mysql.connector serves this
    connection from a connection pool rather than a standalone socket.
    """
    pool_settings = DB_CONFIG
    logging.info("Getting database connection from pool...")
    connection = mysql.connector.connect(**pool_settings)
    return connection

def execute_query(query, *, read_only=True):
    """Execute a SQL statement and fetch every resulting row.

    The query text normally comes from the language model, which is driven by
    free-form user input. By default only statements whose first keyword is
    SELECT, SHOW, DESCRIBE or DESC are sent to the database. Anything else
    (DROP, DELETE, UPDATE, error text, ...) is rejected up front. This is a
    coarse guard, not a replacement for a read-only database account.

    Args:
        query: SQL text to execute.
        read_only: Keyword-only. Pass False to skip the keyword check
            (trusted callers only).

    Returns:
        A list of row dicts (column name -> value) on success, or a string
        starting with ``"Error executing query:"`` on failure.
    """
    import re

    logging.info(f"Executing query: {query}")

    if read_only:
        match = re.match(r"\s*([A-Za-z]+)", query or "")
        first_keyword = match.group(1).upper() if match else ""
        if first_keyword not in {"SELECT", "SHOW", "DESCRIBE", "DESC"}:
            message = (
                "Error executing query: only read-only statements "
                "(SELECT, SHOW, DESCRIBE) are allowed"
            )
            logging.error(message)
            return message

    connection = None
    cursor = None  # pre-bound so the cleanup below never hits an unbound name
    try:
        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)
        cursor.execute(query)
        results = cursor.fetchall()
        logging.info(f"Query executed successfully, retrieved {len(results)} records.")
        return results
    except Error as e:
        logging.error(f"Error executing query: {e}")
        return f"Error executing query: {e}"
    finally:
        if connection is not None and connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()
            logging.info("Database connection closed.")

def generate_sql(natural_language_query):
    """Generate a SQL query for *natural_language_query* using the loaded model.

    The prompt embeds the ``sales`` table schema and the question, and asks
    the model to continue after ``### SQL Query:``. Requires
    ``initialize_model()`` to have populated ``GLOBAL_MODEL`` and
    ``GLOBAL_TOKENIZER``.

    Args:
        natural_language_query: The user's question in plain English.

    Returns:
        The generated SQL text, or a string starting with
        ``"Error generating SQL query:"`` if anything fails (for example
        when the model has not been initialized). The SQL text is only what
        the model wrote after the prompt, cut at any further ``###`` section.
    """
    logging.info(f"Generating SQL for query: {natural_language_query}")
    try:
        start_time = time.time()

        schema_info = """
        CREATE TABLE sales (
          pizza_id DECIMAL(8,2) PRIMARY KEY,
          order_id DECIMAL(8,2),
          pizza_name_id VARCHAR(14),
          quantity DECIMAL(4,2),
          order_date DATE,
          order_time VARCHAR(8),
          unit_price DECIMAL(5,2),
          total_price DECIMAL(5,2),
          pizza_size VARCHAR(3),
          pizza_category VARCHAR(7),
          pizza_ingredients VARCHAR(97),
          pizza_name VARCHAR(42)
        );
        """

        prompt = f"""### Task: Generate a SQL query to answer the following question.
        ### Database Schema:
        {schema_info}
        ### Question: {natural_language_query}
        ### SQL Query:"""

        # A single prompt never needs padding; truncation caps the input size.
        inputs = GLOBAL_TOKENIZER(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            return_attention_mask=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = GLOBAL_MODEL.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # Budget counts only *new* tokens. max_length would also count
                # the schema-heavy prompt and could leave little or no room
                # for the answer.
                max_new_tokens=256,
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
                num_return_sequences=1,
                pad_token_id=GLOBAL_TOKENIZER.eos_token_id,
            )

        # generate() returns prompt + continuation for this causal LM; decode
        # only the continuation so no prompt text can leak into the SQL.
        new_tokens = outputs[0][prompt_length:]
        generated_text = GLOBAL_TOKENIZER.decode(new_tokens, skip_special_tokens=True)
        # The model may keep writing further "### ..." sections; keep the first.
        sql_query = generated_text.split("###", 1)[0].strip()

        logging.info(f"SQL generation took {time.time() - start_time:.2f} seconds")
        return sql_query

    except Exception as e:
        logging.error(f"Error generating SQL query: {str(e)}")
        return f"Error generating SQL query: {str(e)}"

def format_result(query_result, max_rows=5):
    """Render query results as plain text for display.

    Args:
        query_result: The value returned by ``execute_query``. This is either
            a list of row dicts or a message string (normally an error).
        max_rows: Maximum number of rows to spell out when there are several
            results (default 5).

    Returns:
        Any string input unchanged (an empty one becomes
        ``"No results found."``), ``"No results found."`` for empty or None
        results, ``"column: value"`` lines for a single row, or a numbered
        listing of the first *max_rows* rows otherwise.
    """
    if isinstance(query_result, str) and "Error" in query_result:
        logging.warning(f"Query result contains an error: {query_result}")
        return query_result

    if not query_result:
        logging.info("No results found.")
        return "No results found."

    # Any other string is a message, not rows; it has no .items() to format.
    if isinstance(query_result, str):
        return query_result

    if len(query_result) == 1:
        return "\n".join(f"{k}: {v}" for k, v in query_result[0].items())

    lines = [f"Found {len(query_result)} results:\n"]
    for i, row in enumerate(query_result[:max_rows], 1):
        lines.append(f"Result {i}:")
        lines.extend(f"{k}: {v}" for k, v in row.items())
        lines.append("")

    if len(query_result) > max_rows:
        lines.append(f"(Showing first {max_rows} of {len(query_result)} results)")

    return "\n".join(lines)

def check_live_connection(interval_seconds=10):
    """Probe the database forever, refreshing ``db_connection_status``.

    Meant to run on a daemon thread. An unexpected exception from a single
    probe is logged instead of silently ending the thread, so monitoring
    continues.

    Args:
        interval_seconds: Seconds to wait between probes (default 10).
    """
    while True:
        try:
            test_db_connection()
        except Exception:
            logging.exception("Unexpected error while checking database connection")
        time.sleep(interval_seconds)

# Name of the background health-check thread; used to detect an existing one.
_HEALTH_CHECK_THREAD_NAME = "db-live-connection-check"


def _start_health_check_once():
    """Start the background DB health-check thread unless one is already running.

    Looks the thread up by name in ``threading.enumerate()`` (process-wide),
    so repeated calls from script reruns do not pile up polling threads.
    Two calls racing at the very same moment could still both start one.
    """
    already_running = any(
        thread.name == _HEALTH_CHECK_THREAD_NAME for thread in threading.enumerate()
    )
    if not already_running:
        threading.Thread(
            target=check_live_connection,
            name=_HEALTH_CHECK_THREAD_NAME,
            daemon=True,
        ).start()


def main():
    """Render the Streamlit UI: DB status, question input, generated SQL and results."""
    st.title("Natural Language to SQL Query")
    st.write("Ask questions about pizza sales data in plain English.")

    # Background monitoring (logging only); started at most once per process.
    _start_health_check_once()

    # Probe synchronously so the banner reflects this run. Reading the global
    # right after spawning a thread would race it and usually show "not live".
    is_live, _ = test_db_connection()
    if is_live:
        st.success("Database connection is live.")
    else:
        st.error("Database connection is not live.")

    # Initialize model
    initialize_model()

    # Input field for natural language query
    natural_language_query = st.text_input("Enter your question", placeholder="e.g., What were the total sales for each pizza category?")

    if not st.button("Generate and Execute Query"):
        return

    if not natural_language_query:
        logging.warning("User did not enter a query.")
        st.write("Please enter a query.")
        return

    sql_query = generate_sql(natural_language_query)
    if sql_query.startswith("Error generating SQL query"):
        # Don't send the error text to the database as if it were SQL.
        st.error(sql_query)
        return

    st.write("Generated SQL Query:")
    st.code(sql_query, language="sql")

    query_result = execute_query(sql_query)
    formatted_result = format_result(query_result)

    st.write("Query Result:")
    # Rows can hold Decimal/date values (the sales table has DECIMAL and DATE
    # columns), which json cannot serialize natively; render them as strings.
    st.code(json.dumps(query_result, indent=2, default=str))

    st.write("Human-Readable Response:")
    st.text(formatted_result)

# Script entry point (e.g. launched via ``streamlit run``); merely importing
# this module does not render the UI.
if __name__ == "__main__":
    main()