# Import necessary libraries and modules
from dotenv import load_dotenv  # For loading environment variables from .env
from langchain_core.messages import AIMessage, HumanMessage  # Message handling
from langchain_core.prompts import ChatPromptTemplate  # Prompt templates for generating responses
from langchain_core.runnables import RunnablePassthrough  # To chain operations
from langchain_community.utilities import SQLDatabase  # SQL database utility for LangChain
from langchain_core.output_parsers import StrOutputParser  # To parse outputs as strings
# from langchain_openai import ChatOpenAI  # OpenAI chat model (optional alternative)
from langchain_groq import ChatGroq  # Groq chat model (currently used)
import streamlit as st  # Streamlit for building the web app
import os  # To access environment variables

# Load environment variables from the .env file (like API keys and database credentials)
load_dotenv()


# Function to initialize a connection to a MySQL database
def init_database(user=None, password=None, host=None, port=None, database=None) -> SQLDatabase:
    try:
        # Fall back to environment variables when a value is not supplied,
        # so credentials can stay out of the source code
        user = user or os.getenv("DB_USER", "root")
        password = password or os.getenv("DB_PASSWORD", "admin")
        host = host or os.getenv("DB_HOST", "localhost")
        port = port or os.getenv("DB_PORT", "3306")
        database = database or os.getenv("DB_NAME", "Chinook")
        # Construct the database URI
        db_uri = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}"
        # Initialize and return the SQLDatabase instance
        return SQLDatabase.from_uri(db_uri)
    except Exception as e:
        st.error(f"Failed to connect to database: {e}")
        return None


# Function to create a chain that generates SQL queries from user input and conversation history
def get_sql_chain(db):
    # SQL prompt template
    template = """
    You are a data analyst at a company. You are interacting with a user who is asking you
    questions about the company's database. Based on the table schema below, write a SQL query
    that would answer the user's question. Take the conversation history into account.

    {schema}

    Conversation History: {chat_history}

    Write only the SQL query and nothing else.

    Question: {question}
    SQL Query:
    """
    # Create a prompt from the above template
    prompt = ChatPromptTemplate.from_template(template)
    # Initialize the Groq model for generating SQL queries (can switch to OpenAI if needed)
    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    # Helper function to get schema info from the database
    def get_schema(_):
        return db.get_table_info()

    # Chain of operations:
    # 1. Assign schema information from the database
    # 2. Use the AI model to generate a SQL query
    # 3. Parse the result into a string
    return (
        RunnablePassthrough.assign(schema=get_schema)  # Get schema info from the database
        | prompt  # Fill in the prompt template
        | llm  # Use the Groq model to generate a SQL query
        | StrOutputParser()  # Parse the result as a string
    )


# Function to generate a natural language response based on the SQL query result
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
    # Chain that generates the SQL query
    sql_chain = get_sql_chain(db)

    # Prompt template for the natural language response based on the SQL query and its result
    template = """
    You are a data analyst at a company. Based on the table schema, SQL query, and response below,
    write a natural language response.

    {schema}

    Conversation History: {chat_history}
    SQL Query: {query}
    User question: {question}
    SQL Response: {response}
    """
    # Create the natural language response prompt
    prompt = ChatPromptTemplate.from_template(template)
    # Initialize the Groq model (alternative: OpenAI)
    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    # Build a chain: generate the SQL query, run it against the database, then phrase the answer
    chain = (
        RunnablePassthrough.assign(query=sql_chain).assign(
            schema=lambda _: db.get_table_info(),  # Get schema info
            response=lambda vars: db.run(vars["query"]),  # Run the generated SQL query on the database
        )
        | prompt  # Fill in the response prompt
        | llm  # Process the prompt with the Groq model
        | StrOutputParser()  # Parse the final result as a string
    )

    # Execute the chain and return the response
    return chain.invoke({
        "question": user_query,
        "chat_history": chat_history,
    })


# Initialize the Streamlit session
if "chat_history" not in st.session_state:
    # Initialize the chat history with a welcome message from the AI
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
    ]

# Set up the Streamlit web page configuration
st.set_page_config(page_title="Chat with MySQL", page_icon=":speech_balloon:")

# Streamlit app title
st.title("Chat with MySQL")

# Sidebar for database connection settings
with st.sidebar:
    st.subheader("Settings")
    st.write("Connect to your database and start chatting.")

    # Database connection input fields (defaults come from environment variables)
    host = st.text_input("Host", value=os.getenv("DB_HOST", "localhost"))
    port = st.text_input("Port", value=os.getenv("DB_PORT", "3306"))
    user = st.text_input("User", value=os.getenv("DB_USER", "root"))
    password = st.text_input("Password", type="password", value=os.getenv("DB_PASSWORD", "admin"))
    database = st.text_input("Database", value=os.getenv("DB_NAME", "Chinook"))

    # Button to connect to the database
    if st.button("Connect"):
        with st.spinner("Connecting to database..."):
            # Initialize the connection with the values entered above and store it in session state
            db = init_database(user, password, host, port, database)
            if db:
                st.session_state.db = db
                st.success("Connected to the database!")
            else:
                st.error("Connection failed. Please check your settings.")

# Display the chat history
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        # Display an AI message
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        # Display a human message
        with st.chat_message("Human"):
            st.markdown(message.content)

# Input field for the user's message
user_query = st.chat_input("Type a message...")
if user_query and user_query.strip():
    # Add the user's query to the chat history
    st.session_state.chat_history.append(HumanMessage(content=user_query))

    # Display the user's message in the chat
    with st.chat_message("Human"):
        st.markdown(user_query)

    # Generate and display the AI's response to the query
    with st.chat_message("AI"):
        if "db" in st.session_state:
            response = get_response(user_query, st.session_state.db, st.session_state.chat_history)
        else:
            # Guard against querying before a database connection has been established
            response = "Please connect to a database first using the sidebar settings."
        st.markdown(response)

    # Add the AI's response to the chat history
    st.session_state.chat_history.append(AIMessage(content=response))
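
# A minimal sketch of the .env file that load_dotenv() above would pick up. The DB_* names and
# defaults mirror the os.getenv() calls in this script; GROQ_API_KEY is the environment variable
# that langchain_groq's ChatGroq reads for authentication. The values shown are illustrative
# placeholders, not real credentials.
#
#   DB_USER=root
#   DB_PASSWORD=admin
#   DB_HOST=localhost
#   DB_PORT=3306
#   DB_NAME=Chinook
#   GROQ_API_KEY=<your-groq-api-key>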