# vindruid
# clean code
# 3ab62c4 unverified
import os
import gradio as gr
import time
import uuid
from typing import List, TypedDict, Annotated, Optional
from gradio.themes.base import Base
import pandas as pd
import altair as alt
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import add_messages
from langgraph.graph import START, END, StateGraph
# Module-wide DataFrame shared by the tools, the data-exploration helpers and the
# Gradio callbacks; load_example_dataset() and handle_upload() rebind it via `global df`.
df: pd.DataFrame = pd.DataFrame()
# --- Tool ---
@tool
def describe_schema() -> str:
    """
    Describe the dataframe schema (column names and data types) so you have context on how to process it.
    Do this before generating any plot if you are not sure about the columns;
    you can skip it if you already know the columns and their data types.
    Knowing the schema helps you write better plot instructions.
    """
    # NOTE: because of @tool, the docstring above is the description the LLM sees.
    # Reads the module-level `df`, so it always reflects the currently loaded dataset.
    return str(df.dtypes)
@tool
def generate_plot_code(plot_instruction: str) -> dict:
    """
    Given a plot_instruction (a description of the plot, not Python code),
    generate Python code that:
    1. Performs aggregation/transformation on `df` (stored in `df_agg`)
    2. Generates an Altair plot from `df_agg` (stored in `chart`)
    Args:
    plot_instruction (str): A description of the plot to generate, e.g. "Bar chart of total revenue by region".
    Returns:
    dict: A dictionary containing:
    - `plot_instruction`: The original plot instruction.
    - `code`: The generated Python code as a string.
    - `interpretation`: A short written analysis of the aggregated data behind the plot.
    """
    # The @tool docstring above is the tool description shown to the LLM, so it must match
    # what this function actually returns.
    prompt_generate_plot_code = """
You are a Python assistant. A pandas DataFrame `df` is available.
Its schema (column name and dtype) is:
{schema}
Your task:
1. Perform any necessary data processing or aggregation based on this request: "{plot_instruction}"
- Store the final df_agg in a variable called `df_agg`.
- When grouping data, always use `.reset_index()` after aggregation so the group keys remain columns in the df_agg.
2. Create an Altair plot from `df_agg`
- Only use the Altair library.
- Assign the chart to a variable named `chart`.
- Do NOT include explanations, comments, or markdown (like ```python).
- Use the existing DataFrame `df` directly.
- Just return executable Python code.
Rules:
- Do NOT create fake/sample data.
- Use only the real `df`.
- must create variable `df_agg` for the aggregated DataFrame.
- must create variable `chart` for the Altair chart.
- always show title and tooltip in the chart.
- No print statements or explanation — just code.
- Be flexible interpreting column names:
- If the plot_instruction uses a partial or common term (e.g. "customer"), find the best matching column(s) in schema (like "customer_name").
- Normalize and expand synonyms or abbreviations to match columns.
- If multiple columns match, pick the most relevant one.
Example result:
import altair as alt
df_agg = df.groupby('region')['sales'].sum().reset_index().sort_values('sales', ascending=False)
chart = alt.Chart(df_agg).mark_bar().encode(
x='region:N',
y='sales:Q',
color=alt.Color('region:N', scale=alt.Scale(scheme='tableau10')),
tooltip=['region', 'sales']
).properties(
title='Top Sales per Region'
).transform_calculate(
text='datum.sales'
).mark_bar(
cornerRadiusTopLeft=3, cornerRadiusTopRight=3
)
"""
    prompt_generate_plot_code = prompt_generate_plot_code.format(
        plot_instruction=plot_instruction,
        # The prompt asks the model to match columns "in schema", so the schema must be in it.
        schema=str(df.dtypes),
    )
    try:
        response = llm_plot.invoke([HumanMessage(content=prompt_generate_plot_code)])
        code = response.content.strip()
        # Remove markdown fences if present (the first line is the opening fence);
        # guard the closing-fence check so a lone fence line cannot raise IndexError.
        if code.startswith("```"):
            lines = code.split("\n")[1:]
            if lines and lines[-1].startswith("```"):
                lines = lines[:-1]
            code = "\n".join(lines).strip()
        # Executes the generated code to get df_agg and asks llm_analysis for insights.
        interpretation = assistant_analysis(code, plot_instruction)
    except Exception as e:
        raise RuntimeError(f"Failed to generate plot: {e}") from e
    return {
        "plot_instruction": plot_instruction,
        "code": code,
        "interpretation": interpretation,
    }
@tool
def enhance_plot_code(previous_code: str, plot_instruction: str) -> dict:
    """
    Given the previous plot code and a plot_instruction (a description of the change, not Python code),
    enhance the Python code for the graph so that it:
    1. Performs aggregation/transformation on `df` (stored in `df_agg`)
    2. Generates an Altair plot from `df_agg` (stored in `chart`)
    3. Enhances the previous code based on the new plot_instruction
    Args:
    previous_code (str): The Python code that generated the previous plot.
    plot_instruction (str): A description of the plot to generate, e.g. "Bar chart of total revenue by region".
    Returns:
    dict: A dictionary containing:
    - `plot_instruction`: The original plot instruction.
    - `code`: The generated Python code as a string.
    - `interpretation`: Always a blank placeholder for this tool.
    By running this tool, you can assume the plot is already shown to the user, so do not say you cannot display the plot.
    """
    # The @tool docstring above is the tool description shown to the LLM.
    prompt_enhance_plot_code = """
You are a Python assistant. A pandas DataFrame `df` is available.
Its schema (column name and dtype) is:
{schema}
You know the previous code that already generated a plot,
"{previous_code}"
Your task:
Enhance previous code based on this request:
"{plot_instruction}"
Rules:
- Do NOT create fake/sample data.
- Use only the real `df`.
- must create variable `df_agg` for the aggregated DataFrame.
- must create variable `chart` for the Altair chart.
- always show title and tooltip in the chart.
- No print statements or explanation — just code.
- Be flexible interpreting column names:
- If the plot_instruction uses a partial or common term (e.g. "customer"), find the best matching column(s) in schema (like "customer_name").
- Normalize and expand synonyms or abbreviations to match columns.
- If multiple columns match, pick the most relevant one.
Example result:
import altair as alt
df_agg = df.groupby('region')['sales'].sum().reset_index().sort_values('sales', ascending=False)
chart = alt.Chart(df_agg).mark_bar().encode(
x='region:N',
y='sales:Q',
color=alt.Color('region:N', scale=alt.Scale(scheme='tableau10')),
tooltip=['region', 'sales']
).properties(
title='Top Sales per Region'
).transform_calculate(
text='datum.sales'
).mark_bar(
cornerRadiusTopLeft=3, cornerRadiusTopRight=3
)
"""
    prompt_enhance_plot_code = prompt_enhance_plot_code.format(
        previous_code=previous_code,
        plot_instruction=plot_instruction,
        # The prompt asks the model to match columns "in schema", so the schema must be in it.
        schema=str(df.dtypes),
    )
    try:
        response = llm_plot.invoke([HumanMessage(content=prompt_enhance_plot_code)])
        code = response.content.strip()
        # Remove markdown fences if present (the first line is the opening fence);
        # guard the closing-fence check so a lone fence line cannot raise IndexError.
        if code.startswith("```"):
            lines = code.split("\n")[1:]
            if lines and lines[-1].startswith("```"):
                lines = lines[:-1]
            code = "\n".join(lines).strip()
    except Exception as e:
        raise RuntimeError(f"Failed to generate plot: {e}") from e
    return {
        "plot_instruction": plot_instruction,
        "code": code,
        "interpretation": " ",
    }
def generate_plot_from_code(code: str):
    """Execute generated plotting code against the global ``df`` and return its ``chart``.

    SECURITY(review): ``code`` is LLM-generated and is run with ``exec`` without any
    sandboxing, so it can do anything the server process can. Only acceptable when the
    deployment and its users are trusted.

    Raises:
        ValueError: if the executed code does not define ``chart``.
    """
    # Use a single dict as globals. With separate globals/locals dicts, exec behaves like a
    # class body: names the snippet assigns (df_agg, helper variables) are invisible inside
    # its lambdas, comprehensions and functions, which raises NameError.
    # `pd` is provided too, because generated code often uses it without importing it.
    namespace = {"df": df, "alt": alt, "pd": pd}
    exec(code, namespace)
    if "chart" not in namespace:
        raise ValueError("No valid `chart` was generated.")
    return namespace["chart"]
def generate_df_agg_from_code(code: str):
    """Execute generated plotting code against the global ``df`` and return its ``df_agg``.

    SECURITY(review): ``code`` is LLM-generated and is run with ``exec`` without any
    sandboxing, so it can do anything the server process can.

    Raises:
        ValueError: if the executed code does not define ``df_agg``.
    """
    # One dict as globals, so names the snippet defines are visible inside its own
    # lambdas and comprehensions. With separate globals/locals dicts they are not.
    namespace = {"df": df, "alt": alt, "pd": pd}
    exec(code, namespace)
    # Validate the variable that is actually returned. Checking `chart` here instead
    # would let a missing df_agg surface as a bare KeyError.
    if "df_agg" not in namespace:
        raise ValueError("No valid `df_agg` was generated.")
    return namespace["df_agg"]
# Tools exposed to the tool-calling LLM (bound in the LLM setup below).
tools = [describe_schema, generate_plot_code, enhance_plot_code]
# --- LLM Setup ---
def _new_gemini_chat(model):
    """Create a Gemini chat model with the settings shared by every LLM in this app."""
    return ChatGoogleGenerativeAI(
        model=model,
        temperature=0.5,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )


# NOTE(review): Google has been retiring the Gemini 1.5 family; confirm that
# "gemini-1.5-flash" is still served before deploying.
# Tool-calling model that drives the LangGraph assistant node.
llm = _new_gemini_chat("gemini-1.5-flash").bind_tools(tools)
# Plain model that writes the textual analysis of df_agg.
llm_analysis = _new_gemini_chat("gemini-1.5-flash")
# Plain model that writes the Altair plotting code.
llm_plot = _new_gemini_chat("gemini-2.0-flash")
# --- LangGraph State Setup ---
class AgentState(TypedDict):
    """Graph state shared by the assistant node and the tool nodes."""

    # Conversation history; the add_messages reducer merges node updates into it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Tool-call dicts (read via 'name', 'args' and 'id') requested by the last assistant
    # turn that have not been executed yet.
    assigned_tools: Optional[List[dict]]
    # Schema of the DataFrame as text (str(df.dtypes)); assumes only one table.
    table_schema: Optional[str]
    # Result dicts returned by the plot tools, in the order they were produced.
    plots: List[dict]
# System prompt for the tool-calling assistant. It lists every tool the LLM is bound to,
# including enhance_plot_code for follow-up changes to an existing plot.
sys_msg = SystemMessage(content="""
You are a helpful assistant named Terloka Bro who creates plots.
You can run tools such as `describe_schema` to understand the dataframe schema,
`generate_plot_code` to generate Python code that creates a plot using the Altair library,
and `enhance_plot_code` to modify the code of a plot you already generated.
Please run `describe_schema` first and then `generate_plot_code` to create a plot; do not call those two functions at the same time.
No need to say that the chart cannot be displayed, because displaying it is already handled by the application.
You already have access to a DataFrame called `df`.
""")
#--- Assistant Functions ---
def assistant(state: AgentState) -> AgentState:
    """Run the tool-calling LLM over the conversation and record the tool calls it requests.

    A synthetic ``describe_schema`` call and its result are inserted ahead of the stored
    conversation, so the model always sees the columns of the currently loaded ``df``.

    Returns a partial state update:
    - the new AI message, which the ``add_messages`` reducer appends to the history;
    - the tool calls to execute next;
    - the freshly computed schema;
    - the plots collected so far.
    """
    # describe_schema takes no arguments, so it is invoked with an empty args dict.
    schema_output = describe_schema.invoke({})
    # The fake AI tool call and its ToolMessage share one id, so the pair is well formed.
    schema_call_id = str(uuid.uuid4())
    schema_context = [
        HumanMessage(content="show your scheme"),
        AIMessage(
            content="",
            tool_calls=[{"name": "describe_schema", "args": {}, "id": schema_call_id}],
        ),
        ToolMessage(
            content=schema_output,
            name="describe_schema",
            id=str(uuid.uuid4()),
            tool_call_id=schema_call_id,
        ),
    ]
    res = llm.invoke([sys_msg, *schema_context, *state["messages"]])
    assigned_tools = list(res.tool_calls) if isinstance(res, AIMessage) else []
    return {
        # Only the new message: add_messages appends it without mutating the input state.
        "messages": [res],
        "assigned_tools": assigned_tools,
        "table_schema": schema_output,
        "plots": state.get("plots", []),
    }
# System prompt for the analysis model used by assistant_analysis().
sys_msg_analysis = SystemMessage(content="""
You are given an aggregated `df_agg` dataframe and an `instruction`. You are required to analyze the findings based on the given data.
""")
def assistant_analysis(plot_code, plot_instruction):
    """Ask ``llm_analysis`` for written insights about the data behind a generated plot.

    Args:
        plot_code: Generated plotting code; it is executed to obtain ``df_agg``.
        plot_instruction: The user's plot request, used as the analysis requirement.

    Returns:
        The analysis text returned by the model (``res.content``).

    Raises:
        ValueError: (from generate_df_agg_from_code) when the code defines no ``df_agg``.
    """
    df_agg_temp = generate_df_agg_from_code(plot_code)
    # Generated code sometimes leaves df_agg as a Series (e.g. groupby().sum() without
    # reset_index()). Series.to_dict() has no `orient` argument, so convert it to a frame
    # and keep the group keys as a column.
    if isinstance(df_agg_temp, pd.Series):
        df_agg_temp = df_agg_temp.reset_index()
    df_agg_result = df_agg_temp.to_dict(orient='list')
    prompt_analysis = f"""
You are given aggregation data result:
```
{df_agg_result}
```
By given analysis requirement :
```
{plot_instruction}
```
The expected output:
- Only provide insights and findings based on the instruction and result
- Do NOT suggest plot code
- Do NOT explain the technical details of the chart
"""
    res = llm_analysis.invoke([sys_msg_analysis, HumanMessage(content=prompt_analysis)])
    return res.content
def clean_runned_tools(state: "AgentState", tool_name: str) -> "AgentState":
    """Drop the first pending call to ``tool_name`` from ``state["assigned_tools"]``.

    When the pending list is non-empty it is replaced by a new list: a copy without that
    one call, or an unchanged copy if nothing matches. An empty or ``None`` list is left
    as is. The state dict is mutated in place and also returned.
    """
    pending = state["assigned_tools"]
    if not pending:
        return state
    remaining = list(pending)
    match_at = next(
        (pos for pos, call in enumerate(pending) if call.get('name') == tool_name),
        None,
    )
    if match_at is not None:
        del remaining[match_at]
    state["assigned_tools"] = remaining
    return state
def do_describe_chema(state: AgentState) -> AgentState:
    """Execute the first pending ``describe_schema`` call and answer it with a ToolMessage.

    Stores the schema text in ``state["table_schema"]`` and appends the ToolMessage to
    ``state["messages"]``. The executed call is then removed from
    ``state["assigned_tools"]``. ``state`` is mutated in place and returned.
    """
    pending = state["assigned_tools"] or []
    call = next((c for c in pending if c.get('name') == "describe_schema"), None)
    if call is not None:
        schema_text = describe_schema.invoke(call['args'])
        state["table_schema"] = schema_text
        state["messages"].append(
            ToolMessage(
                content=str(schema_text),
                id=str(uuid.uuid4()),
                name=call['name'],
                tool_call_id=call['id'],  # ties the answer to the LLM's tool call
            )
        )
    # Remove the executed call so the router does not pick it again.
    return clean_runned_tools(state, "describe_schema")
def do_generate_plot_code(state: AgentState) -> AgentState:
    """Execute the first pending ``generate_plot_code`` call.

    The tool's result dict is appended to ``state["plots"]`` (created if missing). Its
    ``code`` is posted back as a ToolMessage answering that call, and the executed call
    is removed from the pending list. ``state`` is mutated in place and returned.
    """
    pending = state["assigned_tools"] or []
    call = next((c for c in pending if c.get('name') == "generate_plot_code"), None)
    if call is not None:
        result = generate_plot_code.invoke(call['args'])
        state.setdefault("plots", []).append(result)
        # Only the code goes into the ToolMessage; respond() executes it to draw the chart.
        state["messages"].append(
            ToolMessage(
                content=str(result['code']),
                id=str(uuid.uuid4()),
                name=call['name'],
                tool_call_id=call['id'],
            )
        )
    # Remove the executed call so the router does not pick it again.
    return clean_runned_tools(state, "generate_plot_code")
def do_enhance_plot_code(state: AgentState) -> AgentState:
    """Execute the first pending ``enhance_plot_code`` call.

    The tool's result dict is appended to ``state["plots"]`` (created if missing). Its
    ``code`` is posted back as a ToolMessage answering that call, and the executed call
    is removed from the pending list. ``state`` is mutated in place and returned.
    """
    pending = state["assigned_tools"] or []
    call = next((c for c in pending if c.get('name') == "enhance_plot_code"), None)
    if call is not None:
        result = enhance_plot_code.invoke(call['args'])
        state.setdefault("plots", []).append(result)
        # Only the code goes into the ToolMessage; respond() executes it to draw the chart.
        state["messages"].append(
            ToolMessage(
                content=str(result['code']),
                id=str(uuid.uuid4()),
                name=call['name'],
                tool_call_id=call['id'],
            )
        )
    # Remove the executed call so the router does not pick it again.
    return clean_runned_tools(state, "enhance_plot_code")
def route_to_tool(state: "AgentState") -> str:
    """Return the route key for the assistant's next step.

    The key is the name of the first pending call to a known tool, or
    ``"no_tool_required"`` when there is none.
    """
    known_tools = {"describe_schema", "generate_plot_code", "enhance_plot_code"}
    for call in state["assigned_tools"] or []:
        tool_name = call.get('name')
        if tool_name in known_tools:
            return tool_name
    return "no_tool_required"
def route_from_tool(state: "AgentState") -> str:
    """Choose the node to run after a tool node has finished.

    If a call to any known tool is still pending, route to it (the first pending call
    wins). This way parallel calls such as ``describe_schema`` + ``enhance_plot_code`` are
    all answered before the assistant runs again, and no tool call is left without a
    ToolMessage. Otherwise control goes back to the assistant.
    """
    routable = ("describe_schema", "generate_plot_code", "enhance_plot_code")
    for tool_call in state["assigned_tools"] or []:
        name = tool_call.get("name")
        if name in routable:
            return name
    return "assistant"


def build_graph():
    """Build and compile the assistant/tool LangGraph with a fresh in-memory checkpointer.

    Flow:
    - START -> Assistant.
    - Assistant routes to the node of the first pending tool call, or to END.
    - Each tool node routes to the next pending tool call, or back to Assistant.
    """
    builder = StateGraph(AgentState)
    builder.add_node("Assistant", assistant)
    builder.add_node("Describe Schema", do_describe_chema)
    builder.add_node("Generate Plot", do_generate_plot_code)
    builder.add_node("Enhance Plot", do_enhance_plot_code)
    tool_nodes = {
        "describe_schema": "Describe Schema",
        "generate_plot_code": "Generate Plot",
        "enhance_plot_code": "Enhance Plot",
    }
    edges_to_tool = {**tool_nodes, "no_tool_required": END}
    # Every tool node must be reachable from every other, so pending calls are never dropped.
    edges_from_tool = {**tool_nodes, "assistant": "Assistant"}
    builder.add_edge(START, "Assistant")
    builder.add_conditional_edges("Assistant", route_to_tool, edges_to_tool)
    for node_name in tool_nodes.values():
        builder.add_conditional_edges(node_name, route_from_tool, edges_from_tool)
    # No separate unconditional Assistant -> END edge: route_to_tool's
    # "no_tool_required" branch already ends the run.
    memory = InMemorySaver()
    return builder.compile(checkpointer=memory)
# Module-level graph and run config. refresh_graph() replaces both (it is wired to
# demo.load and called on new uploads).
react_graph = build_graph()
config = {"configurable": dict(thread_id=123, session=100)}
# --- Data Exploration Functions ---
def to_snake_case(name):
    """Normalise a column label: lower-case it and turn spaces and hyphens into underscores.

    Non-string labels (e.g. integer headers from a spreadsheet) are converted with ``str``
    first, instead of raising ``AttributeError``.

    >>> to_snake_case("Ship-Mode Date")
    'ship_mode_date'
    """
    return str(name).lower().replace(' ', '_').replace('-', '_')
def get_info_df(df):
    """Summarise each column of *df*: its name, non-null value count and dtype (as text)."""
    names = df.columns
    non_null_counts = df.notnull().sum().values
    dtype_names = df.dtypes.astype(str).values
    return pd.DataFrame(
        {"column": names, "non_null_count": non_null_counts, "dtype": dtype_names}
    )
def summarize_nulls(df):
    """List the columns of *df* that contain nulls.

    The result has three columns: the column name, its null count, and the percentage
    of rows that are null.
    """
    totals = df.isnull().sum().reset_index()
    totals.columns = ['column', 'null_count']
    totals['percent'] = (totals['null_count'] / len(df)) * 100
    has_nulls = totals['null_count'] > 0
    return totals.loc[has_nulls]
def summarize_duplicates(df):
    """Return a one-row frame with the duplicated-row count, total rows and duplicate percentage.

    ``percent_duplicated`` is 0.0 for a frame with no rows, instead of dividing by zero
    (e.g. a header-only CSV upload).
    """
    duplicated_rows = int(df.duplicated().sum())  # computed once, reused below
    total_rows = len(df)
    percent = 100 * duplicated_rows / total_rows if total_rows else 0.0
    return pd.DataFrame({
        "duplicated_rows": [duplicated_rows],
        "total_rows": [total_rows],
        "percent_duplicated": [percent],
    })
def load_example_dataset(name):
    """Load a bundled example dataset ("iris", "titanic" or "superstore") into the global ``df``.

    Column names are normalised with ``to_snake_case``, the same rule used for uploads.
    The chat graph is then reset so the conversation starts fresh for the new data.

    Returns:
        The 13 values for the wired UI outputs:
        - visibility updates for the main tabs, warning, dataset buttons, upload widget,
          instruction box and example box;
        - then describe(), column info, head(), the null summary and the duplicate summary.

    Raises:
        gr.Error: if the name is unknown or the dataset cannot be loaded or summarised.
    """
    global df
    readers = {
        "iris": lambda: pd.read_csv("https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv"),
        "titanic": lambda: pd.read_csv("https://raw.githubusercontent.com/datasciencedojo/datasets/refs/heads/master/titanic.csv"),
        "superstore": lambda: pd.read_excel("https://public.tableau.com/app/sample-data/sample_-_superstore.xls"),
    }
    try:
        if name not in readers:
            raise ValueError("Unknown dataset name.")
        data = readers[name]()
        data.columns = [to_snake_case(col) for col in data.columns]
        summaries = (
            data.describe().reset_index(),
            get_info_df(data),
            data.head(),
            summarize_nulls(data),
            summarize_duplicates(data),
        )
    except Exception as e:
        raise gr.Error(f"Failed to load dataset: {e}") from e
    # Publish the dataset only once it loaded and was normalised successfully.
    df = data
    # New dataset -> new graph and thread, so old messages about other columns are dropped.
    refresh_graph()
    return (
        gr.update(visible=True),  # Show main tabs
        gr.update(visible=False),  # Hide warning
        gr.update(visible=False),  # Hide iris button
        gr.update(visible=False),  # Hide titanic button
        gr.update(visible=False),  # Hide superstore button
        gr.update(visible=False),  # Hide upload button
        gr.update(visible=False),  # Hide instruction box
        gr.update(visible=False),  # Hide example box
        *summaries,
    )
def handle_upload(file):
    """Load an uploaded CSV/Excel file into the global ``df`` and update the UI.

    Args:
        file: Value of the ``gr.File`` component. Either a path or an object whose
            ``.name`` is the path; ``None`` when there is no file.

    Returns:
        The 13 values for the outputs wired to ``upload_btn.change``: eight visibility
        updates, then describe(), column info, head(), the null summary and the
        duplicate summary.

    Raises:
        gr.Error: if the file cannot be parsed.
    """
    global df
    path = None if file is None else getattr(file, "name", file)
    if not path:
        # No file (e.g. the upload was cleared): go back to the initial screen. Gradio
        # needs exactly one value per output component (13 here), not 7.
        return (
            gr.update(visible=False),  # Hide main tabs
            gr.update(visible=True),  # Show warning
            gr.update(visible=True),  # Show iris button
            gr.update(visible=True),  # Show titanic button
            gr.update(visible=True),  # Show superstore button
            gr.update(visible=True),  # Show upload button
            gr.update(visible=True),  # Show instruction box
            gr.update(visible=True),  # Show example box
            pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(),
        )
    try:
        # Case-insensitive so "DATA.CSV" is not sent to read_excel.
        is_csv = str(path).lower().endswith(".csv")
        data = pd.read_csv(path) if is_csv else pd.read_excel(path)
    except Exception as e:
        raise gr.Error(f"Failed to read the file: {e}") from e
    data.columns = [to_snake_case(col) for col in data.columns]
    df = data
    null_summary = summarize_nulls(df)
    dup_summary = summarize_duplicates(df)
    # Rebuild the graph and reset the config so the chat starts fresh for this dataset.
    refresh_graph()
    return (
        gr.update(visible=True),  # Show main tabs
        gr.update(visible=False),  # Hide warning
        gr.update(visible=False),  # Hide iris button
        gr.update(visible=False),  # Hide titanic button
        gr.update(visible=False),  # Hide superstore button
        gr.update(visible=True),  # Keep upload button visible (allows another upload)
        gr.update(visible=False),  # Hide instruction box
        gr.update(visible=False),  # Hide example box
        df.describe().reset_index(),
        get_info_df(df),
        df.head(),
        null_summary,
        dup_summary,
    )
def refresh_graph():
    """Replace the global graph and run config so the next chat starts a new conversation.

    The new graph gets its own in-memory checkpointer. The config gets fresh random
    thread and session ids.
    """
    global react_graph, config
    react_graph = build_graph()
    fresh_thread = str(uuid.uuid4())
    fresh_session = str(uuid.uuid4())
    config = {"configurable": {"thread_id": fresh_thread, "session": fresh_session}}
def respond(message, chat_history):
    """Send *message* through the graph and rebuild the whole chat history for the Chatbot.

    ``chat_history`` is ignored; the history is rebuilt from the messages stored in the
    graph's checkpointed thread. Every plot ToolMessage is re-rendered by executing its
    code against the current ``df``.

    Returns:
        ``("", history)``: clears the textbox and replaces the Chatbot content.
    """
    res = react_graph.invoke({"messages": [HumanMessage(content=message)]}, config=config)
    plots = res.get("plots") or []
    history = []
    # The tool nodes append one plots entry and one ToolMessage together, so walking the
    # plot ToolMessages in order lines them up with res["plots"]. Each chart then gets its
    # own interpretation instead of the latest one.
    plot_idx = 0
    for msg in res["messages"]:
        msg.pretty_print()  # debug trace in the server console
        if isinstance(msg, HumanMessage):
            history.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            # Tool-call-only turns have empty content; skip them to avoid blank bubbles.
            if msg.content:
                history.append({"role": "assistant", "content": msg.content})
        elif isinstance(msg, ToolMessage) and msg.name in ("generate_plot_code", "enhance_plot_code"):
            plot_info = plots[plot_idx] if plot_idx < len(plots) else {}
            plot_idx += 1
            try:
                chart = generate_plot_from_code(msg.content)
            except Exception as e:
                # One broken snippet must not make the whole conversation unrenderable.
                history.append({"role": "assistant", "content": f"Failed to render the chart: {e}"})
            else:
                history.append({"role": "assistant", "content": gr.Plot(chart)})
            if msg.name == "generate_plot_code":
                interpretation = plot_info.get("interpretation")
                if interpretation:
                    history.append({"role": "assistant", "content": interpretation})
    time.sleep(1)
    return "", history
# --- Gradio UI ---
my_theme = gr.Theme.from_hub("NoCrypt/miku")
with gr.Blocks(theme=my_theme) as demo:
    # Start every page load with a fresh graph and conversation thread.
    demo.load(refresh_graph, inputs=None, outputs=None)
    gr.HTML("""
<style>
body, .container, h1, h2, h3, p, span {
font-family: "IBM Plex Sans";
}
#instruction blockquote {
margin: 12px auto 0 auto;
padding: 12px 16px;
border-radius: 6px;
font-size: 14px;
max-width: 7000px;
}
#chatbot_hint {
margin: 12px auto 0 auto;
padding: 12px 16px;
border-radius: 6px;
font-size: 14px;
max-width: 7000px;
}
@keyframes fadeInTitle {
0% {
opacity: 0;
transform: translateY(-10px);
}
100% {
opacity: 1;
transform: translateY(0);
}
}
.container {
padding: 24px;
border-radius: 16px;
box-shadow: 0 2px 30px rgba(42, 86, 198, 0.12);
text-align: center;
transition: box-shadow 0.3s ease;
margin-bottom: 12px;
}
.subtitle {
font-size: 16px;
margin-top: -6px;
}
</style>
<div class="container">
<h1>
<span style="font-size: 30px;">🎯</span>
<span class="title-gradient">AI Chat to Visual</span>
</h1>
<p class="subtitle">Your gateway to smarter decisions through travel data.</p>
</div>
""")
    instruction_box = gr.Markdown(
        "> Upload a file to get started. Supported formats: `.csv`, `.xls`, `.xlsx`",
        elem_id="instruction"
    )
    warning_box = gr.Markdown("⚠️ **You can't proceed without uploading your files first**", visible=True)
    upload_btn = gr.File(file_types=[".csv", ".xls", ".xlsx"], label="📁 Upload File")
    example_box = gr.Markdown("### Or use an example dataset:")
    with gr.Row():
        iris_btn = gr.Button("🌸 Load Iris")
        titanic_btn = gr.Button("🚢 Load Titanic")
        superstore_btn = gr.Button("🏪 Load Superstore")
    with gr.Tabs(visible=False) as main_tabs:
        with gr.Tab("🤖 ChatBot for Viz"):
            gr.Markdown(
                "👉 Want to understand your data first? Go to the Data Exploration tab first!",
                elem_id="chatbot_hint"
            )
            chatbot = gr.Chatbot(type="messages", label="Data Chatbot", elem_id="chatbot")
            msg = gr.Textbox(label="", elem_id="chat_input", container=False, placeholder="Ask me anything about your data...")
            msg.submit(respond, [msg, chatbot], [msg, chatbot])
        with gr.Tab("📊 Data Exploration"):
            with gr.Column():
                with gr.Accordion("🧮 Data Description", open=True):
                    describe_output = gr.DataFrame()
                with gr.Accordion("📋 Data Info", open=True):
                    info_output = gr.DataFrame()
                with gr.Accordion("👁️ Preview Data", open=False):
                    head_output = gr.DataFrame()
                with gr.Accordion("🧼 Null Detection", open=False):
                    null_output = gr.DataFrame()
                with gr.Accordion("📎 Duplicate Check", open=False):
                    dup_output = gr.DataFrame()
    gr.Markdown("---")
    gr.Markdown("🛠️ Built with ❤️ by **Terloka Bros**", elem_id="footer")
    # Shared by every dataset loader. The order must match the 13-tuple they return.
    dataset_outputs = [
        main_tabs, warning_box, iris_btn, titanic_btn, superstore_btn, upload_btn, instruction_box, example_box,
        describe_output, info_output,
        head_output, null_output,
        dup_output,
    ]
    upload_btn.change(fn=handle_upload, inputs=upload_btn, outputs=dataset_outputs)
    iris_btn.click(fn=lambda: load_example_dataset("iris"), outputs=dataset_outputs)
    titanic_btn.click(fn=lambda: load_example_dataset("titanic"), outputs=dataset_outputs)
    superstore_btn.click(fn=lambda: load_example_dataset("superstore"), outputs=dataset_outputs)

# Launch only when run as a script (e.g. `python app.py`), not when imported.
if __name__ == "__main__":
    demo.launch()