Spaces:
Sleeping
Sleeping
File size: 7,237 Bytes
5b26f53 ec96023 5b26f53 ec96023 5b26f53 efb8ba7 5b26f53 5bea3fb 5b26f53 5bea3fb 5b26f53 5bea3fb efb8ba7 5b26f53 5bea3fb 5b26f53 5bea3fb 5b26f53 5bea3fb 5b26f53 5bea3fb 5b26f53 5bea3fb 5b26f53 5bea3fb 5b26f53 5bea3fb 5b26f53 5bea3fb 7df1f40 5b26f53 5bea3fb c796379 5b26f53 7df1f40 5bea3fb 5b26f53 5bea3fb 5b26f53 7df1f40 5b26f53 ec96023 5bea3fb 5b26f53 411b037 5b26f53 411b037 5bea3fb 411b037 5b26f53 411b037 5bea3fb 411b037 5b26f53 5bea3fb 5b26f53 c796379 5bea3fb 5b26f53 c796379 5bea3fb 5b26f53 c796379 5bea3fb c796379 5b26f53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# Import necessary libraries and modules
import os # To access environment variables
import re  # To unwrap Markdown code fences around generated SQL
from typing import Optional  # For optional parameter/return annotations
from urllib.parse import quote  # To percent-encode credentials in the database URI

from dotenv import load_dotenv # For loading environment variables from .env
from langchain_core.messages import AIMessage, HumanMessage # Message handling
from langchain_core.prompts import ChatPromptTemplate # Prompt templates for generating responses
from langchain_core.runnables import RunnablePassthrough # To chain operations
from langchain_community.utilities import SQLDatabase # SQL database utility for LangChain
from langchain_core.output_parsers import StrOutputParser # To parse outputs as strings
# OpenAI model for chat (if used)
from langchain_groq import ChatGroq # Groq model for chat (currently used)
import streamlit as st # Streamlit for building the web app
# Load environment variables from the .env file (like API keys, database credentials)
load_dotenv()

# Groq chat model used both to write SQL and to phrase the final answer.
# Overridable via GROQ_MODEL so a retired model can be swapped without a code
# change; the default is the model this app was originally written against.
GROQ_MODEL = os.getenv("GROQ_MODEL", "mixtral-8x7b-32768")

# Matches an opening Markdown fence (optionally tagged, e.g. ```sql) followed by
# a newline, and captures everything up to the closing fence.
_SQL_FENCE_RE = re.compile(r"```[A-Za-z]*[^\S\n]*\n(.*?)```", re.DOTALL)


def _build_db_uri(user: str, password: str, host: str, port: str, database: str) -> str:
    """Return the SQLAlchemy MySQL URI, percent-encoding user and password.

    Encoding is required because characters such as ``@``, ``:`` or ``/`` in a
    password would otherwise be parsed as URI delimiters.
    """
    return (
        f"mysql+mysqlconnector://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}:{port}/{database}"
    )


# Function to initialize a connection to a MySQL database
def init_database(
    user: Optional[str] = None,
    password: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[str] = None,
    database: Optional[str] = None,
) -> Optional[SQLDatabase]:
    """Connect to a MySQL database and wrap it in a LangChain ``SQLDatabase``.

    Each argument left as ``None`` falls back to its environment variable
    (``DB_USER``, ``DB_PASSWORD``, ``DB_HOST``, ``DB_PORT``, ``DB_NAME``) and
    then to the built-in default, so calling with no arguments behaves as before.

    Returns:
        The connected ``SQLDatabase``, or ``None`` if connecting failed; the
        error is also shown in the Streamlit UI.
    """
    user = user if user is not None else os.getenv("DB_USER", "root")
    password = password if password is not None else os.getenv("DB_PASSWORD", "admin")
    host = host if host is not None else os.getenv("DB_HOST", "localhost")
    port = port if port is not None else os.getenv("DB_PORT", "3306")
    database = database if database is not None else os.getenv("DB_NAME", "Chinook")
    try:
        return SQLDatabase.from_uri(_build_db_uri(user, password, host, port, database))
    except Exception as e:  # UI boundary: report any driver/URI error to the user
        st.error(f"Failed to connect to database: {e}")
        return None


def _extract_sql(text: str) -> str:
    """Return the SQL inside the first Markdown code fence in *text*, else *text*, stripped."""
    match = _SQL_FENCE_RE.search(text)
    return (match.group(1) if match else text).strip()


def _format_chat_history(chat_history: list) -> str:
    """Render chat messages as ``AI: ...`` / ``Human: ...`` lines for a prompt."""
    lines = []
    for message in chat_history:
        role = "AI" if isinstance(message, AIMessage) else "Human"
        lines.append(f"{role}: {message.content}")
    return "\n".join(lines)


# Function to create a chain that generates SQL queries from user input and conversation history
def get_sql_chain(db):
    """Build a runnable that turns a question plus chat history into a SQL query.

    The runnable takes a mapping with ``question`` and ``chat_history`` keys,
    adds the database schema under ``schema``, asks the Groq model for a query
    and returns it as a string with any Markdown code fence removed.
    """
    template = """
You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
Write only the SQL query and nothing else.
Question: {question}
SQL Query:
"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model=GROQ_MODEL, temperature=0)

    def get_schema(_):
        return db.get_table_info()

    return (
        RunnablePassthrough.assign(schema=get_schema)  # Add schema info from the database
        | prompt  # Fill the SQL-generation prompt
        | llm  # Ask the Groq model for a query
        | StrOutputParser()  # Take the text content of the reply
        # Models often wrap the query in ```sql fences despite the instruction;
        # db.run() would reject those, so unwrap them before execution.
        | _extract_sql
    )


# Function to generate a response in natural language based on the SQL query result
def get_response(user_query: str, db: SQLDatabase, chat_history: list) -> str:
    """Answer *user_query* in natural language by generating and running SQL.

    Args:
        user_query: The user's question.
        db: Connected database to query.
        chat_history: Prior ``AIMessage``/``HumanMessage`` objects; they are
            rendered as ``Role: text`` lines before being put into the prompts.

    Returns:
        The model's natural-language answer.
    """
    sql_chain = get_sql_chain(db)
    template = """
You are a data analyst at a company. Based on the table schema, SQL query, and response, write a natural language response.
<SCHEMA>{schema}</SCHEMA>
Conversation History: {chat_history}
SQL Query: <SQL>{query}</SQL>
User question: {question}
SQL Response: {response}
"""
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatGroq(model=GROQ_MODEL, temperature=0)
    chain = (
        RunnablePassthrough.assign(query=sql_chain).assign(
            schema=lambda _: db.get_table_info(),
            # NOTE(security): this executes LLM-generated SQL verbatim. Connect
            # with a read-only database user so a prompt cannot modify or drop data.
            response=lambda inputs: db.run(inputs["query"]),
        )
        | prompt  # Fill the answer-phrasing prompt
        | llm  # Ask the Groq model for the final answer
        | StrOutputParser()  # Return the answer text
    )
    return chain.invoke({
        "question": user_query,
        "chat_history": _format_chat_history(chat_history),
    })


# Streamlit requires set_page_config to be the first Streamlit call in a script run.
st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")

# Initialize chat history with a welcome message from the assistant
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
    ]

st.title("Chat with MySQL")

# Sidebar for database connection settings
with st.sidebar:
    st.subheader("Settings")
    st.write("Connect to your database and start chatting.")
    host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
    port = st.text_input("Port", value=os.getenv("DB_PORT", "3306"))
    user = st.text_input("User", value=os.getenv("DB_USER", "root"))
    password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "admin"))
    database = st.text_input("Database", value=os.getenv("DB_NAME", "Chinook"))
    if st.button("Connect"):
        with st.spinner("Connecting to database..."):
            # Connect with the values typed into the sidebar, not only the env defaults.
            db = init_database(user=user, password=password, host=host, port=port, database=database)
            if db is not None:
                st.session_state.db = db
                st.success("Connected to the database!")
            else:
                st.error("Connection failed. Please check your settings.")

# Display chat history
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)

# Input field for user's message
user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
    db = st.session_state.get("db")
    if db is None:
        # Without this guard, st.session_state.db raises AttributeError.
        st.warning("Please connect to a database using the sidebar before asking questions.")
    else:
        # Snapshot the history before adding this turn so the prompt's
        # conversation history does not repeat the current question.
        prior_history = list(st.session_state.chat_history)
        st.session_state.chat_history.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.markdown(user_query)
        with st.chat_message("AI"):
            try:
                response = get_response(user_query, db, prior_history)
            except Exception as e:  # UI boundary: bad SQL or LLM/API failure
                st.error(f"Could not answer that question: {e}")
            else:
                st.markdown(response)
                st.session_state.chat_history.append(AIMessage(content=response))
|