# main.py
import re
import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage
from utils.llm_logic import generate_llm_response
from utils.sql_utils import (
extract_sql_command,
load_defaultdb_schema_text,
load_defaultdb_queries,
load_data,
)
from utils.handle_sql_commands import execute_sql_duckdb
# Streamlit requires set_page_config to be the very first st.* call in a script.
st.set_page_config(
    page_title="Text-to-SQL Agent",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Session-independent defaults: every fresh run starts on the bundled demo
# database with the Gemini backend selected.
default_dfs = load_data()
selected_df = default_dfs
default_db_questions = {}
use_default_schema = True
llm_option = "gemini"
# CSS injected once per run: styles st.page_link as a gradient "pill" button
# with a hover lift effect. Payload is unchanged; it is only hoisted into a
# named constant for readability.
_PAGE_LINK_CSS = """
<style>
/* Base styles for both themes */
.stPageLink {
    background-image: linear-gradient(to right, #007BFF, #6610F2); /* Gradient background */
    color: white !important; /* Ensure text is readable on the gradient */
    padding: 12px 20px !important; /* Slightly larger padding */
    border-radius: 8px !important; /* More rounded corners */
    border: none !important; /* Remove default border */
    text-decoration: none !important;
    font-weight: 500 !important; /* Slightly lighter font weight */
    transition: transform 0.2s ease-in-out, box-shadow 0.2s ease-in-out; /* Smooth transitions */
    box-shadow: 0 2px 5px rgba(0, 0, 0, 0.15); /* Subtle shadow for depth */
    display: inline-flex;
    align-items: center;
    justify-content: center;
}
.stPageLink:hover {
    transform: scale(1.03); /* Slight scale up on hover */
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); /* Increased shadow on hover */
}
.stPageLink span { /* Style the label text */
    margin-left: 5px; /* Space between icon and text */
}
/* Dark theme adjustments (optional, if needed for better contrast) */
/* Consider using Streamlit's theme variables if possible for a more robust solution */
/* For simplicity, this example uses fixed colors that should work reasonably well */
/* [data-theme="dark"] .stPageLink {
}
[data-theme="dark"] .stPageLink:hover {
} */
</style>
"""
st.markdown(_PAGE_LINK_CSS, unsafe_allow_html=True)
with st.popover("Click here to see Database Schema", use_container_width=True):
    # Markdown schema of the user's uploaded files; the upload page stores it
    # in session state, so it stays the literal False until an upload happens.
    uploaded_df_schema = st.session_state.get("uploaded_df_schema", False)
    # Single identity check used consistently below (the original mixed
    # `== False` and `is False`). The value is either the False sentinel or a
    # schema string, so identity is the precise test.
    no_uploads = uploaded_df_schema is False
    choice = st.segmented_control(
        "Choose",
        ["Default DB", "Uploaded Files"],
        label_visibility="collapsed",
        disabled=no_uploads,
        default="Default DB" if no_uploads else "Uploaded Files",
    )
    # Guidance shown on both uploaded-schema tabs (was duplicated verbatim).
    cross_check_info = "You can copy this schema, and give it to any state of the art LLM models like (Gemini /ChatGPT /Claude etc) to cross check your answers.\n You can run the queries directly here, by using ***Manual Query Executer*** in the sidebar and download your results 😊"
    if no_uploads:
        # No uploads yet: advertise the upload page and fall back to the
        # default database schema.
        st.markdown(
            """> You can also upload your own files, to get your schemas. You can then use those schemas to cross-check our answers with ChatGpt/Gemini/Claude (Preferred if the Question is very Complex). You can run the queries directly with our Manual SQL Executer😊.
- Ask Questions
- Run Queries: automatic + manual
- Download Results """
        )
        st.page_link(
            page="pages/3 📂File Upload for SQL.py",
            label="Upload your own CSV or Excel files",
            icon="📜",
        )
        schema = load_defaultdb_schema_text()
        st.markdown(schema, unsafe_allow_html=True)
    elif choice == "Default DB":
        schema = load_defaultdb_schema_text()
        st.markdown(schema, unsafe_allow_html=True)
    else:
        # Uploaded-file schema: rendered view plus a copyable markdown view.
        pretty_schema, markdown = st.tabs(["Schema", "Copy Schema in Markdown"])
        with pretty_schema:
            st.info(cross_check_info, icon="ℹ️")
            st.markdown(uploaded_df_schema, unsafe_allow_html=True)
        with markdown:
            st.info(cross_check_info, icon="ℹ️")
            # Fenced block so the schema can be copied as plain text.
            st.markdown(f"```\n{uploaded_df_schema}\n```")
# Page header: title on the left, execution-engine tagline on the right.
header_col, tagline_col = st.columns([2, 1], vertical_alignment="bottom")
header_col.header("Natural Language to SQL Query Agent🤖")
tagline_col.caption("> ***Execute on the Go!*** 🚀 In-Built DuckDB Execution Engine")
st.caption(
    "This is a Qwen2.5-Coder-3B model fine-tuned for SQL queries integrated with langchain for Agentic Workflow. To see the Fine-Tuning code - [click here](https://www.kaggle.com/code/debopamchowdhury/qwen-2-5coder-3b-instruct-finetuning)."
)
col1, col2, col3 = st.columns([1.5, 2, 1], vertical_alignment="top")
with col1:
    # "uploaded_files" only becomes selectable once the upload page has put a
    # non-empty dict of dataframes into session state.
    disabled_selection = True
    if (
        "uploaded_dataframes" in st.session_state
    ) and st.session_state.uploaded_dataframes:
        disabled_selection = False
    options = ["default_db", "uploaded_files"]
    selected = st.segmented_control(
        "Choose",
        options,
        selection_mode="single",
        disabled=disabled_selection,
        label_visibility="collapsed",
        default="default_db" if disabled_selection else "uploaded_files",
    )
    if not disabled_selection:
        if selected == "uploaded_files":
            selected_df = st.session_state.uploaded_dataframes
            use_default_schema = False
        else:
            selected_df = default_dfs
            use_default_schema = True
    # FIX: identity check, not `==`. Comparing two dicts of DataFrames with
    # `==` invokes elementwise DataFrame comparison when keys overlap and
    # raises "The truth value of a DataFrame is ambiguous". `is` expresses the
    # intended "still pointing at the default DB" test and never raises.
    if selected_df is default_dfs:
        with st.popover("Default Database Queries 📚 - Trial"):
            default_db_questions = load_defaultdb_queries()
            st.markdown(default_db_questions)
with col2:
    # Map the human-readable radio choice onto the internal backend key used
    # by generate_llm_response.
    model_choice = st.radio(
        "Choose LLM Model",
        ["Gemini-2.0-Flash-Exp", "FineTuned Qwen2.5-Coder-3B for SQL"],
        captions=[
            "Used via API",
            "Run Locally on this Server. Extremely Slow because of Free vCPUs, [Download & Run on your Computer via Ollama](https://ollama.com/debopam/Text-to-SQL__Qwen2.5-Coder-3B-FineTuned)",
        ],
        label_visibility="collapsed",
    )
    llm_option = "gemini" if model_choice == "Gemini-2.0-Flash-Exp" else "qwen"
with col3:
    # Reset button: wipe the stored conversation, zero the turn counter, and
    # rerun the script from the top so the UI reflects the fresh state.
    if st.button("Start New Conversation", type="primary"):
        st.session_state.conversation_turns = 0
        st.session_state.chat_history = []
        st.rerun()
# Lazily initialise per-session state on the first run of this script;
# setdefault leaves existing values untouched on reruns.
st.session_state.setdefault("chat_history", [])
st.session_state.setdefault("conversation_turns", 0)

# Hard cap on user questions per conversation.
MAX_TURNS = 5
# Replay the stored conversation so it survives Streamlit's rerun-per-
# interaction model. Assistant messages may carry the DataFrame produced by
# their SQL query in additional_kwargs["response_df"].
for message in st.session_state.chat_history:
    with st.chat_message(message.type):
        st.markdown(message.content)
        df = None
        if isinstance(message, AIMessage):
            df = message.additional_kwargs.get("response_df")
        if df is not None and not df.empty:
            with st.expander("View SQL-Query Execution Result"):
                st.dataframe(df)
                st.info(f"Rows x Columns: {df.shape[0]} x {df.shape[1]}")
                st.subheader("Data Description:")
                st.markdown(df.describe().T.to_markdown())
                st.subheader("Data Types:")
                st.write(df.dtypes)
# Main chat turn: accept a new question only while the session's turn budget
# remains; otherwise show a notice and a disabled input box.
if st.session_state.conversation_turns < MAX_TURNS:
    if prompt := st.chat_input("Ask me a SQL query question"):
        # Add user message to chat history in session state
        st.session_state.chat_history.append(HumanMessage(content=prompt))
        # Display user message in chat
        with st.chat_message("user"):
            st.markdown(prompt)
        duckdb_result = None  # DataFrame result of the extracted SQL, if any
        # Get assistant response with streaming
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            spinner_text = ""
            if llm_option == "gemini":
                spinner_text = (
                    "Using Gemini-2.0-Flash-Exp to run your query. Please wait...😊"
                )
            else:
                # Local Qwen inference on free Spaces vCPUs is very slow.
                spinner_text = "I know it is taking a lot of time. To run the model I'm using `Free` small vCPUs provided by `HuggingFace Spaces` for deployment. Thank you so much for your patience😊"
            with st.spinner(
                spinner_text,
            ):
                # The generator yields the whole response accumulated so far
                # on each iteration (not a delta).
                for response_so_far in generate_llm_response(
                    prompt, llm_option, use_default_schema
                ):
                    # Remove <sql> and </sql> tags for streaming display
                    streaming_response = response_so_far.replace("<sql>", "").replace(
                        "</sql>", ""
                    )
                    # Remove duplicate ```sql tags with or without space for streaming display
                    streaming_response = re.sub(
                        r"```sql\s*```sql", "```sql", streaming_response
                    )
                    # "▌" acts as a typing cursor while streaming.
                    message_placeholder.markdown(streaming_response + "▌")
                    full_response = response_so_far
            # Remove <sql> and </sql> tags from the full response
            full_response = full_response.replace("<sql>", "").replace("</sql>", "")
            # Remove duplicate ```sql tags with or without space from the full response
            full_response = re.sub(r"```sql\s*```sql", "```sql", full_response)
            # Remove trailing duplicate ``` tags from the full response
            # (normalises any trailing run of whitespace/backticks to one fence).
            full_response = re.sub(r"[\s\n]*`+$", "```", full_response)
            message_placeholder.markdown(full_response)
            # st.text(extract_sql_command(full_response))
            sql_command = extract_sql_command(full_response)
            # dataframe_html = None
            if sql_command:
                # st.text("Extracted SQL Command:")
                # st.code(sql_command, language="sql")
                # Execute against whichever dict of dataframes is selected
                # (bundled demo DB or the user's uploads).
                duckdb_result = execute_sql_duckdb(sql_command, selected_df)
                if duckdb_result is not None:
                    st.text("Query Execution Result:")
                    with st.expander("View Result"):
                        # st.dataframe(duckdb_result)
                        st.dataframe(duckdb_result)
                        st.info(
                            f"Rows x Columns: {duckdb_result.shape[0]} x {duckdb_result.shape[1]}"
                        )
                        st.subheader("Data Description:")
                        st.markdown(duckdb_result.describe().T.to_markdown())
                        st.subheader("Data Types:")
                        st.write(duckdb_result.dtypes)
                    # renderer = StreamlitRenderer(
                    #     duckdb_result,
                    #     spec_io_mode="rw",
                    #     default_tab="data",
                    #     appearance="dark",
                    #     kernel_computation=True,
                    # )
                    # renderer.explorer(default_tab="data")
            else:
                # st.warning("No SQL command found in the response.")
                pass
        # Add assistant response to chat history in session state
        st.session_state.chat_history.append(
            AIMessage(
                content=full_response,
                additional_kwargs={"response_df": duckdb_result},
            )
        )
        # Increment the conversation turn counter
        st.session_state.conversation_turns += 1
else:
    st.warning(
        "Maximum number of questions reached. Please click 'Start New Conversation' to continue."
    )
    st.chat_input(
        "Ask me a SQL query question", disabled=True
    )  # Disable the input field
# Footer credit, using Streamlit's object notation for the sidebar
# (equivalent to the `with st.sidebar:` block form).
st.sidebar.caption("Made with ❤️ by @Debopam_Chowdhury")