File size: 7,237 Bytes
5b26f53
 
 
 
 
 
 
 
 
 
 
ec96023
5b26f53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec96023
5b26f53
efb8ba7
5b26f53
 
5bea3fb
 
5b26f53
5bea3fb
 
5b26f53
5bea3fb
efb8ba7
 
 
 
5b26f53
 
5bea3fb
5b26f53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bea3fb
5b26f53
 
 
 
 
 
5bea3fb
 
 
 
5b26f53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bea3fb
 
5b26f53
 
 
 
 
 
 
5bea3fb
5b26f53
5bea3fb
5b26f53
5bea3fb
 
5b26f53
5bea3fb
7df1f40
5b26f53
5bea3fb
c796379
5b26f53
7df1f40
5bea3fb
5b26f53
5bea3fb
5b26f53
 
 
 
 
 
7df1f40
5b26f53
ec96023
5bea3fb
5b26f53
 
 
 
 
 
 
 
 
411b037
 
5b26f53
411b037
5bea3fb
411b037
5b26f53
411b037
5bea3fb
411b037
5b26f53
5bea3fb
5b26f53
 
c796379
5bea3fb
5b26f53
c796379
 
5bea3fb
5b26f53
c796379
5bea3fb
 
c796379
5b26f53
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Import necessary libraries and modules
import os  # To access environment variables
from typing import Optional  # For optional parameter/return annotations
from urllib.parse import quote  # To percent-encode credentials embedded in the DB URI

from dotenv import load_dotenv   # For loading environment variables from .env
from langchain_core.messages import AIMessage, HumanMessage  # Message handling
from langchain_core.prompts import ChatPromptTemplate  # Prompt templates for generating responses
from langchain_core.runnables import RunnablePassthrough  # To chain operations
from langchain_community.utilities import SQLDatabase  # SQL database utility for LangChain
from langchain_core.output_parsers import StrOutputParser  # To parse outputs as strings
# OpenAI model for chat (if used)
from langchain_groq import ChatGroq  # Groq model for chat (currently used)
import streamlit as st  # Streamlit for building the web app

# Load environment variables from the .env file (like API keys, database credentials)
load_dotenv()

# Groq chat model used both to generate SQL and to phrase the final answer.
# Read from GROQ_MODEL so the ID can be changed without editing code.
# NOTE(review): Groq retires model IDs over time; confirm this default is still served.
GROQ_MODEL = os.getenv("GROQ_MODEL", "mixtral-8x7b-32768")

# Language tags that may label an opening Markdown code fence around generated SQL.
_SQL_FENCE_TAGS = frozenset({"", "sql", "mysql", "mariadb"})


def init_database(
    user: Optional[str] = None,
    password: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[str] = None,
    database: Optional[str] = None,
) -> Optional[SQLDatabase]:
    """Connect to a MySQL database and wrap it in a LangChain ``SQLDatabase``.

    Any argument left as ``None`` falls back to its environment variable
    (``DB_USER``, ``DB_PASSWORD``, ``DB_HOST``, ``DB_PORT``, ``DB_NAME``) and
    then to a built-in default, so ``init_database()`` with no arguments
    behaves as it always has.

    Returns:
        The connected ``SQLDatabase``, or ``None`` if connecting failed; in
        that case the error is also shown in the Streamlit UI.
    """
    try:
        user = os.getenv("DB_USER", "root") if user is None else user
        password = os.getenv("DB_PASSWORD", "admin") if password is None else password
        host = os.getenv("DB_HOST", "localhost") if host is None else host
        port = os.getenv("DB_PORT", "3306") if port is None else port
        database = os.getenv("DB_NAME", "Chinook") if database is None else database

        # Percent-encode the credentials so characters such as "@", ":" or "/"
        # in a password cannot corrupt the structure of the URI.
        db_uri = (
            f"mysql+mysqlconnector://{quote(user, safe='')}:{quote(password, safe='')}"
            f"@{host}:{port}/{database}"
        )
        return SQLDatabase.from_uri(db_uri)
    except Exception as e:
        # Broad catch at the UI boundary: report the failure instead of crashing the run.
        st.error(f"Failed to connect to database: {e}")
        return None


def _clean_sql(text: str) -> str:
    """Return *text* without surrounding whitespace or an enclosing Markdown code fence.

    LLMs often wrap SQL in ```sql ... ``` despite being told not to, and the
    fence characters would make the database reject the statement.
    """
    sql = text.strip()
    if len(sql) >= 6 and sql.startswith("```") and sql.endswith("```"):
        sql = sql[3:-3]
        first_line, newline, rest = sql.partition("\n")
        # Drop an opening-fence language tag (e.g. "sql") sitting on its own line.
        if newline and first_line.strip().lower() in _SQL_FENCE_TAGS:
            sql = rest
    return sql.strip()


def get_sql_chain(db):
    """Build a runnable that turns a question plus chat history into a SQL query.

    The runnable expects a dict with ``question`` and ``chat_history`` keys.
    If the dict also carries a non-empty ``schema`` string it is reused;
    otherwise the schema is read from *db*. The output is the bare SQL text
    with any Markdown code fence removed.
    """
    # SQL prompt template
    template = """
    You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
    Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.

    <SCHEMA>{schema}</SCHEMA>
    Conversation History: {chat_history}
    Write only the SQL query and nothing else.
    
    Question: {question}
    SQL Query:
    """
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model=GROQ_MODEL, temperature=0)

    def get_schema(inputs):
        # Reuse a caller-supplied schema to avoid a second reflection round trip.
        return inputs.get("schema") or db.get_table_info()

    return (
        RunnablePassthrough.assign(schema=get_schema)  # Ensure schema info is present
        | prompt  # Fill the SQL-generation prompt
        | llm  # Ask the model for a SQL query
        | StrOutputParser()  # Extract the text of the reply
        | _clean_sql  # Strip code fences so the query can be executed
    )


def get_response(user_query: str, db: SQLDatabase, chat_history: list):
    """Answer *user_query* in natural language by generating and running SQL.

    Args:
        user_query: The user's question.
        db: Database whose schema is shown to the model and which the
            generated query is run against.
        chat_history: Prior conversation messages, included in both prompts.

    Returns:
        The model's natural-language answer as a string.

    Exceptions raised by the LLM client or by ``db.run`` (e.g. for invalid
    generated SQL) propagate to the caller.
    """
    sql_chain = get_sql_chain(db)

    # Prompt template for natural language response based on SQL query and result
    template = """
    You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
    <SCHEMA>{schema}</SCHEMA>
    Conversation History: {chat_history}
    SQL Query: <SQL>{query}</SQL>
    User question: {question}
    SQL Response: {response}
    """
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model=GROQ_MODEL, temperature=0)

    # SECURITY NOTE(review): the generated SQL is executed verbatim with the
    # connected account's privileges, and the question is untrusted input.
    # Connect with a read-only database user so a hostile or mistaken query
    # cannot modify data.
    chain = (
        RunnablePassthrough.assign(query=sql_chain).assign(
            response=lambda inputs: db.run(inputs["query"]),  # Run the generated SQL
        )
        | prompt  # Fill the answer prompt
        | llm  # Ask the model to phrase the answer
        | StrOutputParser()  # Extract the text of the reply
    )

    # The schema is fetched once here and shared by both prompts via the input dict.
    return chain.invoke({
        "question": user_query,
        "chat_history": chat_history,
        "schema": db.get_table_info(),
    })


# Streamlit expects set_page_config to be the first Streamlit command in the script.
st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")

# Initialize chat history with a welcome message from the AI on first load.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
    ]

st.title("Chat with MySQL")

# Sidebar for database connection settings
with st.sidebar:
    st.subheader("Settings")
    st.write("Connect to your database and start chatting.")

    host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
    port = st.text_input("Port", value=os.getenv("DB_PORT", "3306"))
    user = st.text_input("User", value=os.getenv("DB_USER", "root"))
    # Deliberately not pre-filled: a widget's value is sent to the browser, so
    # pre-filling DB_PASSWORD would expose the server's secret to every visitor.
    password = st.text_input(
        "Password",
        type="password",
        help="Leave blank to use the server-configured password.",
    )
    database = st.text_input("Database", value=os.getenv("DB_NAME", "Chinook"))

    if st.button("Connect"):
        with st.spinner("Connecting to database..."):
            # Connect with the values entered above; a blank field becomes None
            # so init_database falls back to the environment variable / default.
            db = init_database(
                user=user or None,
                password=password or None,
                host=host or None,
                port=port or None,
                database=database or None,
            )
            if db is not None:
                st.session_state.db = db
                st.success("Connected to the database!")
            else:
                st.error("Connection failed. Please check your settings.")

# Display chat history
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)

# Input field for user's message
user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
    # History before this turn: the question is passed separately, so also
    # including it in the history would repeat it in the prompts.
    prior_history = list(st.session_state.chat_history)
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        if "db" not in st.session_state:
            # Without a connection there is nothing to query.
            response = "Please connect to a database using the sidebar first."
            st.warning(response)
        else:
            try:
                response = get_response(user_query, st.session_state.db, prior_history)
                st.markdown(response)
            except Exception as e:
                # Invalid generated SQL or an LLM/API failure must not crash the app.
                response = f"Sorry, I couldn't answer that: {e}"
                st.error(response)

    st.session_state.chat_history.append(AIMessage(content=response))