Ravi theja K commited on
Commit
749e634
·
verified ·
1 Parent(s): d10034d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from snowflake.snowpark import Session
4
+ from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
5
+ from langchain_community.utilities import SQLDatabase
6
+ from langchain_openai import OpenAI
7
+ from langchain.chains import create_sql_query_chain
8
+
9
+ @st.cache_resource(show_spinner="Connecting...")
10
+ def getSession():
11
+ pars = SnowflakeLoginOptions("test_conn")
12
+ pars["database"] = "SNOWFLAKE_SAMPLE_DATA"
13
+ pars["schema"] = "TPCH_SF1"
14
+ session = Session.builder.configs(pars).create()
15
+
16
+ url = (f"snowflake://{pars['user']}:{pars['password']}@{pars['account']}"
17
+ + f"/{pars['database']}/{pars['schema']}"
18
+ + f"?warehouse={pars['warehouse']}&role={pars['role']}")
19
+ db = SQLDatabase.from_uri(url)
20
+
21
+ openai_key = os.environ["OPENAI_API_KEY"]
22
+ llm = OpenAI(openai_api_key=openai_key)
23
+ chain = create_sql_query_chain(llm, db)
24
+ return session, db, chain
25
+
26
+
27
+ st.title("SQL Query Generator")
28
+ st.write("Returns and runs queries from questions in natural language.")
29
+
30
+ session, db, chain = getSession()
31
+
32
+ question = st.sidebar.text_area("Ask a question:",
33
+ value="Show me the total number of entries in the first table")
34
+ sql = chain.invoke({"question": question}).rstrip(';')
35
+
36
+ tabQuery, tabData, tabLog = st.tabs(["Query", "Data", "Log"])
37
+ tabQuery.code(sql, language="sql")
38
+ tabData.dataframe(session.sql(sql))
39
+ tabLog.code(db.table_info, language="sql")