alejandro
refactor: abstract out creation of mysql chain
efb8ba7
raw
history blame
2.97 kB
import streamlit as st
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from dotenv import load_dotenv
def initialize_database(host, port, username, password, database):
    """Build a MySQL connection URI and open it as a LangChain ``SQLDatabase``.

    Args:
        host: Host name or address placed in the URI.
        port: Port placed after the host in the URI.
        username: Login name; percent-encoded before being placed in the URI.
        password: Login password; percent-encoded before being placed in the URI.
        database: Database (schema) name placed in the URI path.

    Returns:
        Whatever ``SQLDatabase.from_uri`` returns for the assembled
        ``mysql+mysqlconnector://`` URI.
    """
    # Imported locally so this function brings its own dependency into scope.
    from urllib.parse import quote_plus

    # Credentials may legitimately contain URI delimiters such as "@", ":",
    # "/" or "#". Left raw they would be parsed as part of the URL structure
    # (e.g. a password "p@ss" would shift the host), so encode them first.
    safe_username = quote_plus(str(username))
    safe_password = quote_plus(str(password))

    db_uri = (
        f"mysql+mysqlconnector://{safe_username}:{safe_password}"
        f"@{host}:{port}/{database}"
    )
    return SQLDatabase.from_uri(db_uri)
def get_sql_chain(db):
    """Compose the runnable that turns a user question into SQL text.

    The returned chain expects a mapping with a ``"question"`` key. It adds a
    ``schema`` entry taken from ``db.get_table_info()``, renders the prompt,
    sends it to ``ChatOpenAI`` (bound with the stop sequence
    ``"\\nSQL Result:"``) and converts the model output to a plain string.

    Args:
        db: Object exposing ``get_table_info()``; called each time the chain
            runs, not when the chain is built.

    Returns:
        The piped runnable ``assign(schema) | prompt | llm | StrOutputParser``.
    """

    def get_schema(_inputs):
        # The chain input is ignored; the schema always comes from ``db``.
        return db.get_table_info()

    prompt_text = """
Based on the table schema below, write a SQL query that would answer the user's question.
{schema}
Question: {question}
SQL Query:
"""

    # Build each stage under its own name, then pipe them together in order.
    add_schema = RunnablePassthrough.assign(schema=get_schema)
    sql_prompt = ChatPromptTemplate.from_template(prompt_text)
    bounded_llm = ChatOpenAI().bind(stop="\nSQL Result:")
    to_text = StrOutputParser()

    return add_schema | sql_prompt | bounded_llm | to_text
def get_response(user_query, chat_history, db):
    """Run the SQL-generation chain for a single user question.

    Args:
        user_query: The question text, passed to the chain as ``"question"``.
        chat_history: Accepted so callers can pass the conversation so far;
            it is not used by the chain at present.
        db: Database object handed to ``get_sql_chain``. Must not be ``None``.

    Returns:
        The result of invoking the chain built by ``get_sql_chain(db)``.

    Raises:
        ValueError: If ``db`` is ``None`` (no database has been connected).
    """
    # Fail early with a clear message instead of an AttributeError raised
    # from inside the chain when it tries to call ``None.get_table_info()``.
    if db is None:
        raise ValueError(
            "No database connection: connect to a database before asking a question."
        )

    sql_chain = get_sql_chain(db)
    return sql_chain.invoke({"question": user_query})
# Load variables from a local .env file into the environment before any
# client is created (presumably where the OpenAI API key lives - confirm).
load_dotenv()

st.set_page_config(
    initial_sidebar_state="expanded",
    page_title="Chat with a MySQL Database",
    page_icon=":speech_balloon:",
)

# Streamlit re-runs this script on every interaction; session_state is what
# survives between runs, so seed it only when the keys are missing.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="")
    ]

if "db" not in st.session_state:
    st.session_state.db = None

with st.sidebar:
    st.title("Chat with a MySQL Database")
    st.write("This is a simple chat application allows you to chat with a MySQL database.")

    # Each input stores its value in st.session_state under the given key.
    st.text_input("Host", key="name")
    st.text_input("Port", key="port")
    st.text_input("Username", key="username")
    # Mask the password so it is not shown in clear text on screen.
    st.text_input("Password", key="password", type="password")
    st.text_input("Database", key="database")

    if st.button("Connect"):
        with st.spinner("Connecting to the database..."):
            try:
                st.session_state.db = initialize_database(
                    username=st.session_state.username,
                    password=st.session_state.password,
                    host=st.session_state.name,
                    port=st.session_state.port,
                    database=st.session_state.database,
                )
            except Exception as exc:  # UI boundary: report instead of crashing the run
                # Drop any earlier connection so the status below reflects
                # the most recent attempt.
                st.session_state.db = None
                st.error(f"Could not connect to the database: {exc}")

    if st.session_state.db is not None:
        st.success("Connected to the database!")

# Replay the stored conversation; without this, earlier turns disappear on
# every rerun. Empty messages (such as the initial placeholder) are skipped.
for message in st.session_state.chat_history:
    if not message.content:
        continue
    role = "AI" if isinstance(message, AIMessage) else "Human"
    with st.chat_message(role):
        st.markdown(message.content)

user_query = st.chat_input("Type a message...")
if user_query:
    if st.session_state.db is None:
        # get_response needs a live database object; ask the user to connect
        # rather than failing inside the chain.
        st.warning("Please connect to a database from the sidebar before asking a question.")
    else:
        st.session_state.chat_history.append(HumanMessage(content=user_query))

        with st.chat_message("Human"):
            st.markdown(user_query)

        with st.chat_message("AI"):
            response = get_response(
                user_query,
                st.session_state.chat_history,
                st.session_state.db,
            )
            print(f"Response generated: {response}")
            st.markdown(response)

        st.session_state.chat_history.append(AIMessage(content=response))