File size: 4,876 Bytes
aff05a7
 
 
afa581c
32b6873
aff05a7
afa581c
aff05a7
 
 
 
 
 
 
 
 
 
 
 
 
 
afa581c
aff05a7
 
 
 
 
 
 
2321bd0
afa581c
aff05a7
2321bd0
aff05a7
 
 
2321bd0
 
aff05a7
2321bd0
 
aff05a7
2321bd0
aff05a7
 
 
2321bd0
 
 
 
 
af2b1fd
2321bd0
 
 
 
af2b1fd
 
 
 
2321bd0
 
 
 
 
 
 
 
afa581c
 
 
 
 
 
2321bd0
 
af2b1fd
afa581c
af2b1fd
 
aff05a7
2321bd0
aff05a7
 
 
 
 
 
 
 
 
 
 
2321bd0
aff05a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2321bd0
aff05a7
 
 
 
 
 
 
af2b1fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import sqlite3
import re
import utilities as us

def utils_extract_db_schema_as_string(
    db_id,
    base_path,
    model: str | None = None,
    normalize=False,
    sql: str | None = None,
    get_insert_into: bool = False,
    prompt: str | None = None,
):
    """
    Extract the schema (and optionally table rows) of an SQLite database as one string.

    :param db_id: Database identifier. Not used by this function; kept so existing
        callers keep working.
    :param base_path: Filesystem path of the SQLite database file to open.
    :param model: Optional model name. When given, every table's entries are passed
        through ``utilities.crop_entries_per_token`` (presumably a token-budget crop —
        see that helper) and all rows are considered; when ``None``, at most a small
        sample of rows per table is emitted.
    :param normalize: Whether to compact the schema text (strip ``CREATE TABLE``,
        backtick the leading name, collapse whitespace).
    :param sql: Optional SQL query; when given, only tables mentioned in it are included.
    :param get_insert_into: Whether to append ``INSERT INTO`` statements for table rows.
    :param prompt: Optional prompt text forwarded to the token-cropping helper.
    :return: Schema of the database as a single newline-separated string.
    """
    # NOTE(review): sqlite3.connect creates an empty database file if base_path does
    # not exist, in which case this returns "". Consider a read-only URI
    # ("file:...?mode=ro") if that should be an error instead.
    connection = sqlite3.connect(base_path)
    try:
        # Get the schema entries based on the provided SQL query.
        schema_entries = _get_schema_entries(
            connection.cursor(), sql, get_insert_into, model, prompt
        )
    finally:
        # Always release the database handle, even if extraction fails.
        connection.close()

    # Combine all schema definitions into a single string.
    return _combine_schema_entries(schema_entries, normalize)



def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | None = None,  prompt : str | None = None):
    """
    Retrieves schema entries and optionally data entries from the SQLite database.

    :param cursor: SQLite cursor object.
    :param sql: Optional SQL query to filter specific tables.
    :param get_insert_into: Boolean flag to include INSERT INTO statements.
    :return: List of schema and optionally data entries.
    """
    entries = []

    if sql:
        # Extract table names from the provided SQL query
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()]
    else:
        # Retrieve all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [tbl[0] for tbl in cursor.fetchall()]

    for table in tables:
        entries_per_table = []
        # Retrieve the CREATE TABLE statement for each table
        cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
        create_table_stmt = cursor.fetchone()
        if create_table_stmt:
            stmt = create_table_stmt[0].strip()
            if not stmt.endswith(';'):
                stmt += ';'
            entries_per_table.append(stmt)

        if get_insert_into:
            # Retrieve all data from the table
            cursor.execute(f"SELECT * FROM {table};")
            rows = cursor.fetchall()
            column_names = [description[0] for description in cursor.description]

            # Generate INSERT INTO statements for each row
            if model==None : 
                max_len=3
            else: 
                max_len = len(rows)

            for row in rows[:max_len]:
                values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
                insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
                entries_per_table.append(insert_stmt)
    
        if model != None : entries_per_table = us.crop_entries_per_token(entries_per_table, model, prompt)
        entries.extend(entries_per_table)

    return entries


def _combine_schema_entries(schema_entries, normalize):
    """
    Combines schema entries into a single string.

    :param schema_entries: List of schema entries.
    :param normalize: Whether to normalize the schema string.
    :return: Combined schema string.
    """
    if not normalize:
        return "\n".join(entry for entry in schema_entries)

    return "\n".join(
        re.sub(
            r"\s*\)",
            ")",
            re.sub(
                r"\(\s*",
                "(",
                re.sub(
                    r"(`\w+`)\s+\(",
                    r"\1(",
                    re.sub(
                        r"^\s*([^\s(]+)",
                        r"`\1`",
                        re.sub(
                            r"\s+",
                            " ",
                            entry.replace("CREATE TABLE", "").replace("\t", " "),
                        ).strip(),
                    ),
                ),
            ),
        )
        for entry in schema_entries
    )


def create_db_temp(schema_sql: str) -> sqlite3.Connection:
    """
    Create an in-memory SQLite database by executing the provided SQL script.

    Args:
        schema_sql (str): SQL script, e.g. CREATE TABLE and INSERT INTO statements.

    Returns:
        sqlite3.Connection: Open connection to the in-memory database. The caller
        owns it and should close it when done.

    Raises:
        sqlite3.Error: If the script fails to execute (the connection is closed first).
    """
    conn = sqlite3.connect(':memory:')
    try:
        conn.executescript(schema_sql)
        conn.commit()
    except BaseException:
        # Close on any failure (not only sqlite3.Error) so the handle never leaks,
        # then propagate the original exception unchanged.
        conn.close()
        raise

    return conn