Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import snowflake.connector
|
3 |
+
import replicate
|
4 |
+
import re
|
5 |
+
import pandas as pd
|
6 |
+
import streamlit as st
|
7 |
+
|
8 |
+
# Snowflake connection parameters
|
9 |
+
ACCOUNT = "EZ97576.ap-southeast-1"
|
10 |
+
USER = "sureshsnowflake"
|
11 |
+
PASSWORD = "Slavia@123"
|
12 |
+
WAREHOUSE = "COMPUTE_WH"
|
13 |
+
DATABASE = "SNOWFLAKE_SAMPLE_DATA"
|
14 |
+
SCHEMA = "TPCDS_SF100TCL"
|
15 |
+
|
16 |
+
# Replicate API key
|
17 |
+
os.environ['REPLICATE_API_TOKEN'] = 'r8_E7Rn49bbi2O33bztSMYLKyqvWmo68mZ1Tg8M0'
|
18 |
+
|
19 |
+
|
20 |
+
def interact_with_replicate(prompt):
|
21 |
+
response = ""
|
22 |
+
for event in replicate.stream(
|
23 |
+
"snowflake/snowflake-arctic-instruct",
|
24 |
+
input={
|
25 |
+
"prompt": prompt,
|
26 |
+
"max_new_tokens": 250
|
27 |
+
},
|
28 |
+
):
|
29 |
+
response += str(event)
|
30 |
+
|
31 |
+
# Use regular expressions to extract SQL statements
|
32 |
+
sql_statements = re.findall(
|
33 |
+
r"(SELECT.*?;|INSERT.*?;|UPDATE.*?;|DELETE.*?;|CREATE.*?;|ALTER.*?;|DROP.*?;)", response, re.DOTALL | re.IGNORECASE)
|
34 |
+
|
35 |
+
# Join the SQL statements into a single string
|
36 |
+
return "\n".join(sql_statements)
|
37 |
+
|
38 |
+
|
39 |
+
def get_snowflake_connection():
|
40 |
+
return snowflake.connector.connect(
|
41 |
+
user=USER,
|
42 |
+
password=PASSWORD,
|
43 |
+
account=ACCOUNT,
|
44 |
+
warehouse=WAREHOUSE,
|
45 |
+
database=DATABASE,
|
46 |
+
schema=SCHEMA
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
def fetch_ddl_for_all_tables():
|
51 |
+
conn = get_snowflake_connection()
|
52 |
+
# Add headers
|
53 |
+
ddl_data = ["table_name, column_name, data_type, is_nullable\n"]
|
54 |
+
|
55 |
+
try:
|
56 |
+
cur = conn.cursor()
|
57 |
+
table_names_query = f"""
|
58 |
+
SELECT TABLE_NAME
|
59 |
+
FROM {DATABASE}.INFORMATION_SCHEMA.TABLES
|
60 |
+
WHERE TABLE_SCHEMA = '{SCHEMA}'
|
61 |
+
"""
|
62 |
+
cur.execute(table_names_query)
|
63 |
+
table_names = cur.fetchall()
|
64 |
+
|
65 |
+
for table_name in table_names:
|
66 |
+
table_name = table_name[0]
|
67 |
+
ddl_query = f"""
|
68 |
+
SELECT table_name, column_name, data_type, is_nullable
|
69 |
+
FROM {DATABASE}.INFORMATION_SCHEMA.COLUMNS
|
70 |
+
WHERE table_name = '{table_name}' AND table_schema = '{SCHEMA}'
|
71 |
+
"""
|
72 |
+
cur.execute(ddl_query)
|
73 |
+
ddl_result = cur.fetchall()
|
74 |
+
if ddl_result:
|
75 |
+
ddl = "\n".join(
|
76 |
+
[f"{row[0]}, {row[1]}, {row[2]}, {row[3]}" for row in ddl_result])
|
77 |
+
else:
|
78 |
+
ddl = f"No DDL found for table {table_name}"
|
79 |
+
ddl_data.append(f"-- DDL for table {table_name} --\n{ddl}\n\n")
|
80 |
+
|
81 |
+
with open('sample_ddl.txt', 'w') as file:
|
82 |
+
file.writelines(ddl_data)
|
83 |
+
|
84 |
+
st.success("DDLs written to sample_ddl.txt")
|
85 |
+
finally:
|
86 |
+
cur.close()
|
87 |
+
conn.close()
|
88 |
+
|
89 |
+
|
90 |
+
def generate_sql_query(sample_message):
|
91 |
+
with open('sample_ddl.txt', 'r') as file:
|
92 |
+
ddl_commands = file.read()
|
93 |
+
|
94 |
+
instruction = "Read the all provided ddl statements and work on the statement. if any logical questions asked,find relattion between tables based on primarykey and foreign key difinition use snowflake supported functions to provide snowflake sql construct. validate functions used on coulmn data types,amiguity in joins before present also add as many description columns as possible. Respond only with the SQL query without any explanations or contextual details. if the ask is to find or provide then consider as a request for writing sql construct"
|
95 |
+
combined_input = ddl_commands + sample_message + " " + instruction
|
96 |
+
|
97 |
+
response = interact_with_replicate(combined_input)
|
98 |
+
|
99 |
+
with open('generated_query.sql', 'w') as file:
|
100 |
+
file.write(response)
|
101 |
+
|
102 |
+
st.success("Response from Replicate has been written to 'generated_query.sql'")
|
103 |
+
|
104 |
+
|
105 |
+
def execute_generated_sql():
|
106 |
+
with open('generated_query.sql', 'r') as file:
|
107 |
+
generated_sql = file.read()
|
108 |
+
|
109 |
+
# Print the generated SQL for inspection
|
110 |
+
st.text_area("Generated SQL query", generated_sql, height=200)
|
111 |
+
|
112 |
+
conn = get_snowflake_connection()
|
113 |
+
|
114 |
+
try:
|
115 |
+
cur = conn.cursor()
|
116 |
+
cur.execute(generated_sql)
|
117 |
+
|
118 |
+
if cur.description is not None:
|
119 |
+
result = cur.fetchall()
|
120 |
+
columns = [desc[0] for desc in cur.description]
|
121 |
+
df = pd.DataFrame(result, columns=columns)
|
122 |
+
st.write("Result from executed SQL query:")
|
123 |
+
st.dataframe(df)
|
124 |
+
else:
|
125 |
+
st.write(
|
126 |
+
"The executed SQL did not return any results or is not a SELECT query.")
|
127 |
+
finally:
|
128 |
+
cur.close()
|
129 |
+
conn.close()
|
130 |
+
|
131 |
+
|
132 |
+
def main():
|
133 |
+
st.title("Snowflake and Replicate Integration")
|
134 |
+
|
135 |
+
st.header("Generate SQL Query and Execute")
|
136 |
+
sample_message = st.text_area(
|
137 |
+
"Enter your message for generating SQL query", height=100)
|
138 |
+
if st.button("Generate SQL"):
|
139 |
+
generate_sql_query(sample_message)
|
140 |
+
st.success("SQL Query generated successfully. You can now execute it.")
|
141 |
+
|
142 |
+
if st.button("Execute SQL"):
|
143 |
+
execute_generated_sql()
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
# Fetch DDLs for all tables automatically before starting the app
|
148 |
+
fetch_ddl_for_all_tables()
|
149 |
+
main()
|