File size: 2,966 Bytes
7df1f40
ec96023
c796379
 
 
 
efb8ba7
c796379
ec96023
 
 
 
 
efb8ba7
 
 
 
c796379
efb8ba7
 
 
c796379
efb8ba7
 
 
 
 
 
c796379
efb8ba7
 
 
 
 
 
c796379
 
efb8ba7
 
 
c796379
 
 
 
 
 
7df1f40
 
 
c796379
 
 
 
 
 
 
 
7df1f40
 
 
 
 
 
 
 
 
 
ec96023
 
c796379
ec96023
 
 
 
 
 
c796379
 
ec96023
7df1f40
 
c796379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from urllib.parse import quote

import streamlit as st
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

def initialize_database(host, port, username, password, database):
    """Create a LangChain ``SQLDatabase`` for a MySQL server.

    The username and password are percent-encoded before being placed in
    the SQLAlchemy connection URI. Otherwise, credentials containing
    reserved URI characters (``@``, ``:``, ``/``, ``#``, ``%`` ...) would be
    parsed as URI delimiters and corrupt the connection string.

    Args:
        host: Hostname or IP address of the MySQL server.
        port: Port the server listens on (string or int).
        username: MySQL user name.
        password: Password for ``username``.
        database: Name of the database to connect to.

    Returns:
        The object returned by ``SQLDatabase.from_uri`` for a
        ``mysql+mysqlconnector`` URI. Connection errors raised by
        ``from_uri`` propagate to the caller.
    """
    # safe="" so that "/" is encoded too. quote() (not quote_plus) keeps
    # spaces as %20, which SQLAlchemy's URL parser decodes back correctly.
    user = quote(str(username), safe="")
    secret = quote(str(password), safe="")
    db_uri = f"mysql+mysqlconnector://{user}:{secret}@{host}:{port}/{database}"
    return SQLDatabase.from_uri(db_uri)

def get_sql_chain(db):
    """Build a LangChain runnable that drafts a SQL query for a question.

    When invoked, the chain:

    1. adds the database's table info under the ``schema`` key;
    2. renders the prompt;
    3. sends it to ``ChatOpenAI``, stopping at a newline followed by
       ``SQL Result:``;
    4. returns the model output as a plain string.

    Args:
        db: Object providing ``get_table_info()``; here a ``SQLDatabase``.

    Returns:
        A runnable to be invoked with a dict containing ``"question"``.
    """
    prompt_text = """
    Based on the table schema below, write a SQL query that would answer the user's question.
    {schema}

    Question: {question}
    SQL Query:
    """
    sql_prompt = ChatPromptTemplate.from_template(prompt_text)
    model = ChatOpenAI()

    # The schema is fetched lazily, each time the chain runs.
    attach_schema = RunnablePassthrough.assign(
        schema=lambda _inputs: db.get_table_info()
    )
    stopped_model = model.bind(stop="\nSQL Result:")
    return attach_schema | sql_prompt | stopped_model | StrOutputParser()
  

def get_response(user_query, chat_history, db):
    """Produce the assistant's reply to ``user_query``.

    The reply is whatever the chain built by ``get_sql_chain(db)`` returns
    when invoked with the question.

    Args:
        user_query: The question typed by the user.
        chat_history: Prior conversation messages. Not used yet: the chain
            is invoked with the question only.
        db: Connected ``SQLDatabase``, or ``None`` if the user has not
            connected yet.

    Returns:
        The chain's string output, or a hint asking the user to connect
        first when ``db`` is ``None``.
    """
    # Without this guard the chain fails deep inside its schema lookup with
    # an AttributeError on None, which the UI shows as a raw traceback.
    if db is None:
        return "Please connect to a database using the sidebar before asking a question."

    sql_chain = get_sql_chain(db)
    return sql_chain.invoke({"question": user_query})
  
load_dotenv()

st.set_page_config(initial_sidebar_state="expanded", page_title="Chat with a MySQL Database", page_icon=":speech_balloon:")

# Streamlit re-runs this whole script on every interaction, so anything that
# must survive between runs is kept in st.session_state.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="")
    ]

if "db" not in st.session_state:
    st.session_state.db = None

with st.sidebar:
    st.title("Chat with a MySQL Database")
    st.write("This is a simple chat application allows you to chat with a MySQL database.")

    st.text_input("Host", key="host")
    st.text_input("Port", key="port")
    st.text_input("Username", key="username")
    # Mask the password instead of echoing it in plain text.
    st.text_input("Password", type="password", key="password")
    st.text_input("Database", key="database")

    if st.button("Connect"):
        with st.spinner("Connecting to the database..."):
            try:
                st.session_state.db = initialize_database(
                    username=st.session_state.username,
                    password=st.session_state.password,
                    host=st.session_state.host,
                    port=st.session_state.port,
                    database=st.session_state.database,
                )
            except Exception as exc:  # UI boundary: report instead of crashing the app
                st.error(f"Could not connect to the database: {exc}")
            else:
                st.success("Connected to the database!")

# Replay the stored conversation. Without this, only the latest exchange is
# visible, because each rerun starts from an empty page.
for message in st.session_state.chat_history:
    if not message.content:
        continue  # skip the empty placeholder greeting
    role = "AI" if isinstance(message, AIMessage) else "Human"
    with st.chat_message(role):
        st.markdown(message.content)

user_query = st.chat_input("Type a message...")

if user_query:
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        if st.session_state.db is None:
            # Querying without a connection would crash inside the SQL chain.
            response = "Please connect to a database using the sidebar first."
        else:
            response = get_response(
                user_query,
                st.session_state.chat_history,
                st.session_state.db,
            )
        print(f"Response generated: {response}")
        st.markdown(response)

    st.session_state.chat_history.append(AIMessage(content=response))