Spaces:
Runtime error
Runtime error
# main.py | |
import re | |
import streamlit as st | |
from langchain_core.messages import HumanMessage, AIMessage | |
from utils.llm_logic import generate_llm_response | |
from utils.sql_utils import ( | |
extract_sql_command, | |
load_defaultdb_schema_text, | |
load_defaultdb_queries, | |
load_data, | |
) | |
from utils.handle_sql_commands import execute_sql_duckdb | |
# Must be the first Streamlit command in the script: configures the browser
# tab title/icon and the overall page layout.
st.set_page_config(
    page_title="Text-to-SQL Agent",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Script-level defaults, re-evaluated on every Streamlit rerun.
default_db_questions = {}  # sample questions for the bundled demo DB (filled later)
default_dfs = load_data()  # dataframes backing the default database
selected_df = default_dfs  # dataset the DuckDB executor will query
use_default_schema = True  # whether the default schema is fed to the LLM
llm_option = "gemini"  # model id passed to generate_llm_response
# Inject custom CSS for st.page_link widgets (gradient button look with a
# hover scale/shadow effect). Requires unsafe_allow_html to pass raw <style>.
st.markdown(
    """
    <style>
    /* Base styles for both themes */
    .stPageLink {
        background-image: linear-gradient(to right, #007BFF, #6610F2); /* Gradient background */
        color: white !important; /* Ensure text is readable on the gradient */
        padding: 12px 20px !important; /* Slightly larger padding */
        border-radius: 8px !important; /* More rounded corners */
        border: none !important; /* Remove default border */
        text-decoration: none !important;
        font-weight: 500 !important; /* Slightly lighter font weight */
        transition: transform 0.2s ease-in-out, box-shadow 0.2s ease-in-out; /* Smooth transitions */
        box-shadow: 0 2px 5px rgba(0, 0, 0, 0.15); /* Subtle shadow for depth */
        display: inline-flex;
        align-items: center;
        justify-content: center;
    }
    .stPageLink:hover {
        transform: scale(1.03); /* Slight scale up on hover */
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); /* Increased shadow on hover */
    }
    .stPageLink span { /* Style the label text */
        margin-left: 5px; /* Space between icon and text */
    }
    /* Dark theme adjustments (optional, if needed for better contrast) */
    /* Consider using Streamlit's theme variables if possible for a more robust solution */
    /* For simplicity, this example uses fixed colors that should work reasonably well */
    /* [data-theme="dark"] .stPageLink {
    }
    [data-theme="dark"] .stPageLink:hover {
    } */
    </style>
    """,
    unsafe_allow_html=True,
)
with st.popover("Click here to see Database Schema", use_container_width=True):
    # Schema text for user-uploaded files; the upload page stores it in
    # session state, so it is the sentinel False on a fresh session.
    uploaded_df_schema = st.session_state.get("uploaded_df_schema", False)
    # One identity-based boolean instead of the original mix of
    # `== False` and `is False` checks (the stored value is str-or-False).
    has_uploaded_schema = uploaded_df_schema is not False
    choice = st.segmented_control(
        "Choose",
        ["Default DB", "Uploaded Files"],
        label_visibility="collapsed",
        disabled=not has_uploaded_schema,
        default="Uploaded Files" if has_uploaded_schema else "Default DB",
    )
    if not has_uploaded_schema:
        # Nothing uploaded yet: advertise the upload page, then fall back to
        # showing the default database schema.
        st.markdown(
            """> You can also upload your own files, to get your schemas. You can then use those schemas to cross-check our answers with ChatGpt/Gemini/Claude (Preferred if the Question is very Complex). You can run the queries directly with our Manual SQL Executer😊.
- Ask Questions
- Run Queries: automatic + manual
- Download Results """
        )
        st.page_link(
            page="pages/3 📂File Upload for SQL.py",
            label="Upload your own CSV or Excel files",
            icon="📜",
        )
        schema = load_defaultdb_schema_text()
        st.markdown(schema, unsafe_allow_html=True)
    elif choice == "Default DB":
        schema = load_defaultdb_schema_text()
        st.markdown(schema, unsafe_allow_html=True)
    else:
        # Uploaded schema: pretty view plus a copyable fenced-markdown view.
        pretty_schema, markdown = st.tabs(["Schema", "Copy Schema in Markdown"])
        with pretty_schema:
            st.info(
                "You can copy this schema, and give it to any state of the art LLM models like (Gemini /ChatGPT /Claude etc) to cross check your answers.\n You can run the queries directly here, by using ***Manual Query Executer*** in the sidebar and download your results 😊",
                icon="ℹ️",
            )
            st.markdown(uploaded_df_schema, unsafe_allow_html=True)
        with markdown:
            st.info(
                "You can copy this schema, and give it to any state of the art LLM models like (Gemini /ChatGPT /Claude etc) to cross check your answers.\n You can run the queries directly here, by using ***Manual Query Executer*** in the sidebar and download your results 😊",
                icon="ℹ️",
            )
            st.markdown(f"```\n{uploaded_df_schema}\n```")
# Page header: title on the left, execution-engine tagline on the right.
col1, col2 = st.columns([2, 1], vertical_alignment="bottom")
with col1:
    st.header("Natural Language to SQL Query Agent🤖")
with col2:
    st.caption("> ***Execute on the Go!*** 🚀 In-Built DuckDB Execution Engine")
st.caption(
    "This is a Qwen2.5-Coder-3B model fine-tuned for SQL queries integrated with langchain for Agentic Workflow. To see the Fine-Tuning code - [click here](https://www.kaggle.com/code/debopamchowdhury/qwen-2-5coder-3b-instruct-finetuning)."
)
# Controls row: dataset picker | model picker | reset button.
col1, col2, col3 = st.columns([1.5, 2, 1], vertical_alignment="top")
with col1:
    # The "uploaded_files" option only becomes selectable once the upload
    # page has stored dataframes in session state.
    disabled_selection = True
    if (
        "uploaded_dataframes" in st.session_state
    ) and st.session_state.uploaded_dataframes:
        disabled_selection = False
    options = ["default_db", "uploaded_files"]
    selected = st.segmented_control(
        "Choose",
        options,
        selection_mode="single",
        disabled=disabled_selection,
        label_visibility="collapsed",
        default="default_db" if disabled_selection else "uploaded_files",
    )
    if not disabled_selection:
        if selected == "uploaded_files":
            selected_df = st.session_state.uploaded_dataframes
            use_default_schema = False
        else:
            # Covers both an explicit "default_db" pick and a deselection
            # (segmented_control may return None).
            selected_df = default_dfs
            use_default_schema = True
    # Identity check, not `==`: comparing dicts of DataFrames with `==` can
    # raise ValueError ("truth value of a DataFrame is ambiguous") when the
    # dicts share keys but hold different DataFrame objects.
    if selected_df is default_dfs:
        with st.popover("Default Database Queries 📚 - Trial"):
            default_db_questions = load_defaultdb_queries()
            st.markdown(default_db_questions)
with col2:
    # Model picker: Gemini via API, or the locally-hosted fine-tuned Qwen.
    llm_option_radio = st.radio(
        "Choose LLM Model",
        ["Gemini-2.0-Flash-Exp", "FineTuned Qwen2.5-Coder-3B for SQL"],
        captions=[
            "Used via API",
            "Run Locally on this Server. Extremely Slow because of Free vCPUs, [Download & Run on your Computer via Ollama](https://ollama.com/debopam/Text-to-SQL__Qwen2.5-Coder-3B-FineTuned)",
        ],
        label_visibility="collapsed",
    )
    # Map the human-readable label onto the short model identifier used by
    # generate_llm_response.
    llm_option = (
        "gemini" if llm_option_radio == "Gemini-2.0-Flash-Exp" else "qwen"
    )
with col3:
    # Button to refresh the conversation: wipe history, reset the turn
    # counter, and force an immediate rerun so the UI redraws empty.
    if st.button("Start New Conversation", type="primary"):
        st.session_state.chat_history = []
        st.session_state.conversation_turns = 0
        st.rerun()
# Seed per-session state on first load: the message history and a counter
# capping how many questions one conversation may use.
for _state_key, _state_default in (("chat_history", []), ("conversation_turns", 0)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Hard cap on questions per conversation.
MAX_TURNS = 5
# Display existing chat messages.
# Streamlit redraws from scratch on every rerun, so the whole conversation
# is replayed from session state each time.
for message in st.session_state.chat_history:
    with st.chat_message(message.type):
        st.markdown(message.content)
        # Assistant turns may carry the query-result DataFrame in
        # additional_kwargs["response_df"]; re-render it when present
        # and non-empty.
        if (
            isinstance(message, AIMessage)
            and "response_df" in message.additional_kwargs
            and message.additional_kwargs["response_df"] is not None
            and not message.additional_kwargs["response_df"].empty
        ):
            with st.expander("View SQL-Query Execution Result"):
                df = message.additional_kwargs["response_df"]
                # download_csv = convert_df(df)
                # st.download_button(
                #     label="Download data as CSV",
                #     data=download_csv,
                #     file_name="query_results.csv",
                #     mime="text/csv",
                # )
                # renderer = StreamlitRenderer(
                #     df,
                #     spec_io_mode="rw",
                #     default_tab="data",
                #     appearance="dark",
                #     kernel_computation=True,
                # )
                # renderer.explorer(default_tab="data")
                st.dataframe(df)
                st.info(f"Rows x Columns: {df.shape[0]} x {df.shape[1]}")
                st.subheader("Data Description:")
                st.markdown(df.describe().T.to_markdown())
                st.subheader("Data Types:")
                st.write(df.dtypes)
# Accept a new question only while the per-conversation turn budget remains.
if st.session_state.conversation_turns < MAX_TURNS:
    if prompt := st.chat_input("Ask me a SQL query question"):
        # Record and echo the user's message.
        st.session_state.chat_history.append(HumanMessage(content=prompt))
        with st.chat_message("user"):
            st.markdown(prompt)
        duckdb_result = None
        # Stream the assistant response token-by-token into a placeholder.
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            # Conditional expression instead of placeholder-then-reassign.
            spinner_text = (
                "Using Gemini-2.0-Flash-Exp to run your query. Please wait...😊"
                if llm_option == "gemini"
                else "I know it is taking a lot of time. To run the model I'm using `Free` small vCPUs provided by `HuggingFace Spaces` for deployment. Thank you so much for your patience😊"
            )
            with st.spinner(
                spinner_text,
            ):
                for response_so_far in generate_llm_response(
                    prompt, llm_option, use_default_schema
                ):
                    # Remove <sql> and </sql> tags for streaming display
                    streaming_response = response_so_far.replace("<sql>", "").replace(
                        "</sql>", ""
                    )
                    # Remove duplicate ```sql tags with or without space for streaming display
                    streaming_response = re.sub(
                        r"```sql\s*```sql", "```sql", streaming_response
                    )
                    # Trailing ▌ acts as a typing cursor while streaming.
                    message_placeholder.markdown(streaming_response + "▌")
                    full_response = response_so_far
            # Apply the same cleanup to the final accumulated response.
            full_response = full_response.replace("<sql>", "").replace("</sql>", "")
            full_response = re.sub(r"```sql\s*```sql", "```sql", full_response)
            # Normalise a ragged trailing fence (e.g. "``" or "````") to ```;
            # \s already matches \n, so the original [\s\n] was redundant.
            full_response = re.sub(r"\s*`+$", "```", full_response)
            message_placeholder.markdown(full_response)
            # If the reply contains a SQL command, run it against the active
            # dataset with DuckDB and show the result inline. No warning is
            # shown when extraction fails (deliberate best-effort).
            sql_command = extract_sql_command(full_response)
            if sql_command:
                duckdb_result = execute_sql_duckdb(sql_command, selected_df)
                if duckdb_result is not None:
                    st.text("Query Execution Result:")
                    with st.expander("View Result"):
                        st.dataframe(duckdb_result)
                        st.info(
                            f"Rows x Columns: {duckdb_result.shape[0]} x {duckdb_result.shape[1]}"
                        )
                        st.subheader("Data Description:")
                        st.markdown(duckdb_result.describe().T.to_markdown())
                        st.subheader("Data Types:")
                        st.write(duckdb_result.dtypes)
        # Persist the assistant turn (with any result dataframe) so the
        # replay loop can re-render it, and count the turn.
        st.session_state.chat_history.append(
            AIMessage(
                content=full_response,
                additional_kwargs={"response_df": duckdb_result},
            )
        )
        st.session_state.conversation_turns += 1
else:
    st.warning(
        "Maximum number of questions reached. Please click 'Start New Conversation' to continue."
    )
    st.chat_input(
        "Ask me a SQL query question", disabled=True
    )  # Disable the input field
# Author credit pinned in the sidebar.
with st.sidebar:
    st.caption("Made with ❤️ by @Debopam_Chowdhury")