Suresh Beekhani commited on
Commit
5b26f53
·
1 Parent(s): 1cf406a

Your commit message

Browse files
Files changed (2) hide show
  1. requirements.txt +7 -8
  2. src/app.py +120 -87
requirements.txt CHANGED
@@ -1,8 +1,7 @@
1
- streamlit==1.31.1
2
- langchain==0.1.8
3
- langchain-community==0.0.21
4
- langchain-core==0.1.24
5
- langchain-openai==0.0.6
6
- mysql-connector-python==8.3.0
7
- groq==0.4.2
8
- langchain-groq==0.0.1
 
1
+ streamlit
2
+ langchain
3
+ langchain-community
4
+ langchain-core
5
+ mysql-connector-python
6
+ groq
7
+ langchain-groq
 
src/app.py CHANGED
@@ -1,139 +1,172 @@
1
- from dotenv import load_dotenv
2
- from langchain_core.messages import AIMessage, HumanMessage
3
- from langchain_core.prompts import ChatPromptTemplate
4
- from langchain_core.runnables import RunnablePassthrough
5
- from langchain_community.utilities import SQLDatabase
6
- from langchain_core.output_parsers import StrOutputParser
7
- from langchain_openai import ChatOpenAI
8
- from langchain_groq import ChatGroq
9
- import streamlit as st
 
 
10
 
11
- def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
12
- db_uri = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}"
13
- return SQLDatabase.from_uri(db_uri)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
 
15
  def get_sql_chain(db):
16
- template = """
 
17
  You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
18
  Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
19
-
20
  <SCHEMA>{schema}</SCHEMA>
21
-
22
  Conversation History: {chat_history}
23
-
24
- Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.
25
-
26
- For example:
27
- Question: which 3 artists have the most tracks?
28
- SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
29
- Question: Name 10 artists
30
- SQL Query: SELECT Name FROM Artist LIMIT 10;
31
-
32
- Your turn:
33
 
34
  Question: {question}
35
  SQL Query:
36
  """
37
 
38
- prompt = ChatPromptTemplate.from_template(template)
39
-
40
- # llm = ChatOpenAI(model="gpt-4-0125-preview")
41
- llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
42
-
43
- def get_schema(_):
44
- return db.get_table_info()
45
-
46
- return (
47
- RunnablePassthrough.assign(schema=get_schema)
48
- | prompt
49
- | llm
50
- | StrOutputParser()
51
- )
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def get_response(user_query: str, db: SQLDatabase, chat_history: list):
54
- sql_chain = get_sql_chain(db)
55
-
56
- template = """
57
- You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
58
- Based on the table schema below, question, sql query, and sql response, write a natural language response.
 
59
  <SCHEMA>{schema}</SCHEMA>
60
-
61
  Conversation History: {chat_history}
62
  SQL Query: <SQL>{query}</SQL>
63
  User question: {question}
64
- SQL Response: {response}"""
65
-
66
- prompt = ChatPromptTemplate.from_template(template)
67
-
68
- # llm = ChatOpenAI(model="gpt-4-0125-preview")
69
- llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
70
-
71
- chain = (
72
- RunnablePassthrough.assign(query=sql_chain).assign(
73
- schema=lambda _: db.get_table_info(),
74
- response=lambda vars: db.run(vars["query"]),
 
 
 
 
 
 
 
75
  )
76
- | prompt
77
- | llm
78
- | StrOutputParser()
79
- )
80
-
81
- return chain.invoke({
82
- "question": user_query,
83
- "chat_history": chat_history,
84
- })
85
 
86
-
 
 
 
 
 
 
87
  if "chat_history" not in st.session_state:
 
88
  st.session_state.chat_history = [
89
- AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
90
  ]
91
 
92
- load_dotenv()
93
-
94
  st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")
95
 
 
96
  st.title("Chat with MySQL")
97
 
 
98
  with st.sidebar:
99
  st.subheader("Settings")
100
- st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
101
 
102
- st.text_input("Host", value="localhost", key="Host")
103
- st.text_input("Port", value="3306", key="Port")
104
- st.text_input("User", value="root", key="User")
105
- st.text_input("Password", type="password", value="admin", key="Password")
106
- st.text_input("Database", value="Chinook", key="Database")
 
107
 
 
108
  if st.button("Connect"):
109
  with st.spinner("Connecting to database..."):
110
- db = init_database(
111
- st.session_state["User"],
112
- st.session_state["Password"],
113
- st.session_state["Host"],
114
- st.session_state["Port"],
115
- st.session_state["Database"]
116
- )
117
- st.session_state.db = db
118
- st.success("Connected to database!")
119
-
120
  for message in st.session_state.chat_history:
121
  if isinstance(message, AIMessage):
 
122
  with st.chat_message("AI"):
123
  st.markdown(message.content)
124
  elif isinstance(message, HumanMessage):
 
125
  with st.chat_message("Human"):
126
  st.markdown(message.content)
127
 
 
128
  user_query = st.chat_input("Type a message...")
129
- if user_query is not None and user_query.strip() != "":
 
130
  st.session_state.chat_history.append(HumanMessage(content=user_query))
131
 
 
132
  with st.chat_message("Human"):
133
  st.markdown(user_query)
134
 
 
135
  with st.chat_message("AI"):
136
  response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
137
  st.markdown(response)
138
 
139
- st.session_state.chat_history.append(AIMessage(content=response))
 
 
1
+ # Import necessary libraries and modules
2
+ from dotenv import load_dotenv # For loading environment variables from .env
3
+ from langchain_core.messages import AIMessage, HumanMessage # Message handling
4
+ from langchain_core.prompts import ChatPromptTemplate # Prompt templates for generating responses
5
+ from langchain_core.runnables import RunnablePassthrough # To chain operations
6
+ from langchain_community.utilities import SQLDatabase # SQL database utility for LangChain
7
+ from langchain_core.output_parsers import StrOutputParser # To parse outputs as strings
8
+ # OpenAI model for chat (if used)
9
+ from langchain_groq import ChatGroq # Groq model for chat (currently used)
10
+ import streamlit as st # Streamlit for building the web app
11
+ import os # To access environment variables
12
 
13
+ # Load environment variables from the .env file (like API keys, database credentials)
14
+ load_dotenv()
15
+
16
+ # Function to initialize a connection to a MySQL database
17
+ def init_database() -> SQLDatabase:
18
+ try:
19
+ # Load credentials from environment variables for better security
20
+ user = os.getenv("DB_USER", "root")
21
+ password = os.getenv("DB_PASSWORD", "admin")
22
+ host = os.getenv("DB_HOST", "localhost")
23
+ port = os.getenv("DB_PORT", "3306")
24
+ database = os.getenv("DB_NAME", "Chinook")
25
+
26
+ # Construct the database URI
27
+ db_uri = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}"
28
+
29
+ # Initialize and return the SQLDatabase instance
30
+ return SQLDatabase.from_uri(db_uri)
31
+ except Exception as e:
32
+ st.error(f"Failed to connect to database: {e}")
33
+ return None
34
 
35
+ # Function to create a chain that generates SQL queries from user input and conversation history
36
  def get_sql_chain(db):
37
+ # SQL prompt template
38
+ template = """
39
  You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
40
  Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
41
+
42
  <SCHEMA>{schema}</SCHEMA>
 
43
  Conversation History: {chat_history}
44
+ Write only the SQL query and nothing else.
 
 
 
 
 
 
 
 
 
45
 
46
  Question: {question}
47
  SQL Query:
48
  """
49
 
50
+ # Create a prompt from the above template
51
+ prompt = ChatPromptTemplate.from_template(template)
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # Initialize Groq model for generating SQL queries (can switch to OpenAI if needed)
54
+ llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
55
+
56
+ # Helper function to get schema info from the database
57
+ def get_schema(_):
58
+ return db.get_table_info()
59
+
60
+ # Chain of operations:
61
+ # 1. Assign schema information from the database
62
+ # 2. Use the AI model to generate a SQL query
63
+ # 3. Parse the result into a string
64
+ return (
65
+ RunnablePassthrough.assign(schema=get_schema) # Get schema info from the database
66
+ | prompt # Generate SQL query from the prompt template
67
+ | llm # Use Groq model to process the prompt and return a SQL query
68
+ | StrOutputParser() # Parse the result as a string
69
+ )
70
+
71
+ # Function to generate a response in natural language based on the SQL query result
72
  def get_response(user_query: str, db: SQLDatabase, chat_history: list):
73
+ # Generate the SQL query using the chain
74
+ sql_chain = get_sql_chain(db)
75
+
76
+ # Prompt template for natural language response based on SQL query and result
77
+ template = """
78
+ You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
79
  <SCHEMA>{schema}</SCHEMA>
 
80
  Conversation History: {chat_history}
81
  SQL Query: <SQL>{query}</SQL>
82
  User question: {question}
83
+ SQL Response: {response}
84
+ """
85
+
86
+ # Create a natural language response prompt
87
+ prompt = ChatPromptTemplate.from_template(template)
88
+
89
+ # Initialize Groq model (alternative: OpenAI)
90
+ llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
91
+
92
+ # Build a chain: generate SQL query, run it on the database, generate a natural language response
93
+ chain = (
94
+ RunnablePassthrough.assign(query=sql_chain).assign(
95
+ schema=lambda _: db.get_table_info(), # Get schema info
96
+ response=lambda vars: db.run(vars["query"]), # Run SQL query on the database
97
+ )
98
+ | prompt # Use prompt to generate a natural language response
99
+ | llm # Process prompt with Groq model
100
+ | StrOutputParser() # Parse the final result as a string
101
  )
 
 
 
 
 
 
 
 
 
102
 
103
+ # Execute the chain and return the response
104
+ return chain.invoke({
105
+ "question": user_query,
106
+ "chat_history": chat_history,
107
+ })
108
+
109
+ # Initialize the Streamlit session
110
  if "chat_history" not in st.session_state:
111
+ # Initialize chat history with a welcome message from AI
112
  st.session_state.chat_history = [
113
+ AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
114
  ]
115
 
116
+ # Set up the Streamlit web page configuration
 
117
  st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")
118
 
119
+ # Streamlit app title
120
  st.title("Chat with MySQL")
121
 
122
+ # Sidebar for database connection settings
123
  with st.sidebar:
124
  st.subheader("Settings")
125
+ st.write("Connect to your database and start chatting.")
126
 
127
+ # Database connection input fields
128
+ host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
129
+ port = st.text_input("Port", value=os.getenv("DB_PORT", "3306"))
130
+ user = st.text_input("User", value=os.getenv("DB_USER", "root"))
131
+ password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "admin"))
132
+ database = st.text_input("Database", value=os.getenv("DB_NAME", "Chinook"))
133
 
134
+ # Button to connect to the database
135
  if st.button("Connect"):
136
  with st.spinner("Connecting to database..."):
137
+ # Initialize the database connection and store in session state
138
+ db = init_database()
139
+ if db:
140
+ st.session_state.db = db
141
+ st.success("Connected to the database!")
142
+ else:
143
+ st.error("Connection failed. Please check your settings.")
144
+
145
+ # Display chat history
 
146
  for message in st.session_state.chat_history:
147
  if isinstance(message, AIMessage):
148
+ # Display AI message
149
  with st.chat_message("AI"):
150
  st.markdown(message.content)
151
  elif isinstance(message, HumanMessage):
152
+ # Display human message
153
  with st.chat_message("Human"):
154
  st.markdown(message.content)
155
 
156
+ # Input field for user's message
157
  user_query = st.chat_input("Type a message...")
158
+ if user_query and user_query.strip():
159
+ # Add user's query to the chat history
160
  st.session_state.chat_history.append(HumanMessage(content=user_query))
161
 
162
+ # Display user's message in the chat
163
  with st.chat_message("Human"):
164
  st.markdown(user_query)
165
 
166
+ # Generate and display AI's response based on the query
167
  with st.chat_message("AI"):
168
  response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
169
  st.markdown(response)
170
 
171
+ # Add AI's response to the chat history
172
+ st.session_state.chat_history.append(AIMessage(content=response))