import os
import sqlite3
import re
import utilities as us

def utils_extract_db_schema_as_string(
    db_id, base_path, model: str | None = None, normalize=False, sql: str | None = None, get_insert_into: bool = False, prompt: str | None = None
):
    """
    Extracts the schema of an SQLite database into a single string.

    :param db_id: Database identifier. Not used by this function; kept so
        existing callers that pass it positionally keep working.
    :param base_path: Path of the SQLite database file to open.
    :param model: Optional model name. When given, all rows are emitted (if
        ``get_insert_into``) and each table's entries are cropped by
        ``us.crop_entries_per_token``; when None, at most 3 rows per table.
    :param normalize: Whether to normalize each schema entry (see
        ``_combine_schema_entries``).
    :param sql: Optional SQL query; only tables whose name occurs in it are kept.
    :param get_insert_into: Whether to append INSERT INTO statements with table data.
    :param prompt: Optional prompt forwarded to the token-cropping helper.
    :return: Schema of the database as a single newline-joined string.
    """
    # NOTE: sqlite3.connect creates an empty database file if base_path does not exist.
    connection = sqlite3.connect(base_path)
    try:
        cursor = connection.cursor()
        schema_entries = _get_schema_entries(cursor, sql, get_insert_into, model, prompt)
    finally:
        # Always release the database handle, even if extraction raises.
        connection.close()

    # Combine all schema definitions into a single string.
    return _combine_schema_entries(schema_entries, normalize)



def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | None = None,  prompt : str | None = None):
    """
    Retrieves schema entries and optionally data entries from the SQLite database.

    :param cursor: SQLite cursor object.
    :param sql: Optional SQL query to filter specific tables.
    :param get_insert_into: Boolean flag to include INSERT INTO statements.
    :return: List of schema and optionally data entries.
    """
    entries = []

    if sql:
        # Extract table names from the provided SQL query
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()]
    else:
        # Retrieve all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [tbl[0] for tbl in cursor.fetchall()]

    for table in tables:
        entries_per_table = []
        # Retrieve the CREATE TABLE statement for each table
        cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
        create_table_stmt = cursor.fetchone()
        if create_table_stmt:
            stmt = create_table_stmt[0].strip()
            if not stmt.endswith(';'):
                stmt += ';'
            entries_per_table.append(stmt)

        if get_insert_into:
            # Retrieve all data from the table
            cursor.execute(f"SELECT * FROM {table};")
            rows = cursor.fetchall()
            column_names = [description[0] for description in cursor.description]

            # Generate INSERT INTO statements for each row
            if model==None : 
                max_len=3
            else: 
                max_len = len(rows)

            for row in rows[:max_len]:
                values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
                insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
                entries_per_table.append(insert_stmt)
    
        if model != None : entries_per_table = us.crop_entries_per_token(entries_per_table, model, prompt)
        entries.extend(entries_per_table)

    return entries


def _combine_schema_entries(schema_entries, normalize):
    """
    Combines schema entries into a single string.

    :param schema_entries: List of schema entries.
    :param normalize: Whether to normalize the schema string.
    :return: Combined schema string.
    """
    if not normalize:
        return "\n".join(entry for entry in schema_entries)

    return "\n".join(
        re.sub(
            r"\s*\)",
            ")",
            re.sub(
                r"\(\s*",
                "(",
                re.sub(
                    r"(`\w+`)\s+\(",
                    r"\1(",
                    re.sub(
                        r"^\s*([^\s(]+)",
                        r"`\1`",
                        re.sub(
                            r"\s+",
                            " ",
                            entry.replace("CREATE TABLE", "").replace("\t", " "),
                        ).strip(),
                    ),
                ),
            ),
        )
        for entry in schema_entries
    )


def create_db_temp(schema_sql: str) -> sqlite3.Connection:
    """
    Creates a temporary SQLite database in memory by executing the provided SQL schema.

    Args:
        schema_sql (str): The SQL code containing CREATE TABLE and INSERT INTO.

    Returns:
        sqlite3.Connection: Connection object to the temporary database. The
        caller is responsible for closing it.

    Raises:
        sqlite3.Error: If the script fails to execute. On this and any other
        failure, the connection is closed before the exception propagates.
    """
    conn = sqlite3.connect(':memory:')
    try:
        conn.executescript(schema_sql)
        conn.commit()
    except BaseException:
        # Never leak the connection, whatever the failure (SQL error, bad argument, interrupt).
        conn.close()
        raise

    return conn