import os import gradio as gr import time import uuid from typing import List, TypedDict, Annotated, Optional from gradio.themes.base import Base import pandas as pd import altair as alt from langchain_google_genai import ChatGoogleGenerativeAI from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ToolMessage from langchain_core.tools import tool from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph.message import add_messages from langgraph.graph import START, END, StateGraph # Global df for sharing between functions df = pd.DataFrame() # --- Tool --- @tool def describe_schema() -> str: """ Describe the dataframe schema so you will have context how to processing it. Do this before generating any plot if you are not sure about the columns, can skip this if you already know about the columns and data types. By knowing the schema, you can better understand how to instruct the plot creation. """ return str(df.dtypes) @tool def generate_plot_code(plot_instruction: str) -> dict: """ Given a plot_instruction not the direct Python code, generate Python code that: 1. Performs aggregation/transformation on `df` (store in `df_agg`) 2. Generates a Altair plot from `df_agg` (store in `fig`) Args: plot_instruction (str): A description of the plot to generate, e.g. "Bar chart of total revenue by region". Returns: dict: A dictionary containing: - `plot_instruction`: The original plot instruction. - `code`: The generated Python code as a string. - `chart`: The Altair chart object. - `df_agg`: The aggregated DataFrame used for the plot. """ promt_generate_plot_code = """ You are a Python assistant. A pandas DataFrame `df` is available. Your task: 1. Perform any necessary data processing or aggregation based on this request: "{plot_instruction}" - Store the final df_agg in a variable called `df_agg`. - When grouping data, always use `.reset_index()` after aggregation so the group keys remain columns in the df_agg. 2. Create a Altair plot from `df_agg` - Only use the Altair library. - Assign the chart to a variable named `chart`. - Do NOT include explanations, comments, or markdown (like ```python). - Use the existing DataFrame `df` directly. - Just return executable Python code. Rules: - Do NOT create fake/sample data. - Use only the real `df`. - must create variable `df_agg` for the aggregated DataFrame. - must create variable `chart` for the Altair chart. - always show title and tooltip in the chart. - No print statements or explanation โ just code. - Be flexible interpreting column names: - If the plot_instruction uses a partial or common term (e.g. "customer"), find the best matching column(s) in schema (like "customer_name"). - Normalize and expand synonyms or abbreviations to match columns. - If multiple columns match, pick the most relevant one. Example result: import altair as alt df_agg = df.groupby('region')['sales'].sum().reset_index().sort_values('sales', ascending=False) chart = alt.Chart(df_agg).mark_bar().encode( x='region:N', y='sales:Q', color=alt.Color('region:N', scale=alt.Scale(scheme='tableau10')), tooltip=['region', 'sales'] ).properties( title='Top Sales per Region' ).transform_calculate( text='datum.sales' ).mark_bar( cornerRadiusTopLeft=3, cornerRadiusTopRight=3 ) """ promt_generate_plot_code = promt_generate_plot_code.format(plot_instruction=plot_instruction) try: response = llm_plot.invoke([HumanMessage(content=promt_generate_plot_code)]) code = response.content.strip() # Remove markdown fences if present if code.startswith("```"): lines = code.split("\n") if lines[0].startswith("```"): lines = lines[1:] if lines[-1].startswith("```"): lines = lines[:-1] code = "\n".join(lines).strip() interpretation = assistant_analysis(code,plot_instruction) return { "plot_instruction": plot_instruction, "code": code, "interpretation" : interpretation, } except Exception as e: raise RuntimeError(f"Failed to generate plot: {e}") @tool def enhance_plot_code(previous_code: str, plot_instruction: str) -> dict: """ Given a previous code and plot_instruction not the direct Python code, enhance Python code for graph that: 1. Performs aggregation/transformation on `df` (store in `df_agg`) 2. Generates a Altair plot from `df_agg` (store in `fig`) 3. Enhances the previous code based on the new plot_instruction Args: plot_instruction (str): A description of the plot to generate, e.g. "Bar chart of total revenue by region". Returns: dict: A dictionary containing: - `plot_instruction`: The original plot instruction. - `code`: The generated Python code as a string. By running this tool, you are assume already show the plot to user, so do not say you cannot display the plot. """ prompt_enhance_plot_code = """ You are a Python assistant. A pandas DataFrame `df` is available. You know the previous code that already generated a plot, "{previous_code}" Your task: Enhance previous code based on this request: "{plot_instruction}" Rules: - Do NOT create fake/sample data. - Use only the real `df`. - must create variable `df_agg` for the aggregated DataFrame. - must create variable `chart` for the Altair chart. - always show title and tooltip in the chart. - No print statements or explanation โ just code. - Be flexible interpreting column names: - If the plot_instruction uses a partial or common term (e.g. "customer"), find the best matching column(s) in schema (like "customer_name"). - Normalize and expand synonyms or abbreviations to match columns. - If multiple columns match, pick the most relevant one. Example result: import altair as alt df_agg = df.groupby('region')['sales'].sum().reset_index().sort_values('sales', ascending=False) chart = alt.Chart(df_agg).mark_bar().encode( x='region:N', y='sales:Q', color=alt.Color('region:N', scale=alt.Scale(scheme='tableau10')), tooltip=['region', 'sales'] ).properties( title='Top Sales per Region' ).transform_calculate( text='datum.sales' ).mark_bar( cornerRadiusTopLeft=3, cornerRadiusTopRight=3 ) """ prompt_enhance_plot_code = prompt_enhance_plot_code.format(previous_code = previous_code, plot_instruction=plot_instruction) try: response = llm_plot.invoke([HumanMessage(content=prompt_enhance_plot_code)]) code = response.content.strip() # Remove markdown fences if present if code.startswith("```"): lines = code.split("\n") if lines[0].startswith("```"): lines = lines[1:] if lines[-1].startswith("```"): lines = lines[:-1] code = "\n".join(lines).strip() return { "plot_instruction": plot_instruction, "code": code, "interpretation":" " } except Exception as e: raise RuntimeError(f"Failed to generate plot: {e}") def generate_plot_from_code(code: str): local_scope = {"df": df, "alt": alt} exec(code, {}, local_scope) if "chart" not in local_scope: raise ValueError("No valid `chart` was generated.") return local_scope["chart"] def generate_df_agg_from_code(code: str): local_scope = {"df": df, "alt": alt} exec(code, {}, local_scope) if "chart" not in local_scope: raise ValueError("No valid `chart` was generated.") return local_scope["df_agg"] tools = [ describe_schema, generate_plot_code, enhance_plot_code, ] # --- LLM Setup --- llm = ChatGoogleGenerativeAI( model="gemini-1.5-flash", temperature=0.5, max_tokens=None, timeout=None, max_retries=2, ) llm = llm.bind_tools(tools) llm_analysis = ChatGoogleGenerativeAI( model="gemini-1.5-flash", temperature=0.5, max_tokens=None, timeout=None, max_retries=2, ) llm_plot = ChatGoogleGenerativeAI( model="gemini-2.0-flash", temperature=0.5, max_tokens=None, timeout=None, max_retries=2, ) # --- LangGraph State Setup --- class AgentState(TypedDict): messages: Annotated[list[AnyMessage], add_messages] assigned_tools: Optional[List[str]] # List of tools assigned to the agent table_schema: Optional[str] # Schema of the DataFrame, assume only one table plots: List[dict] # List of generated plots sys_msg = SystemMessage(content=""" You are a helpful assistant named Terloka Bro who works for creating plots. you can run tools such as `describe_schema` to understand the dataframe schema, and `generate_plot_code` to generate Python code that creates a plot using the Altair library. Please do `describe_schema` first then `generate_plot_code` to create a plot, do not call those two function at the same time. No need to say if the chart cannot be displayed, because it already handled in the application. You already have access to a DataFrame called `df` """) def assistant(state: AgentState) -> AgentState: schema_output = describe_schema.invoke(df) res = llm.invoke([sys_msg] + [HumanMessage(content="show your scheme")] + [AIMessage(content=schema_output)] + [ToolMessage(content=schema_output, name="describe_schema", id=str(uuid.uuid4()), tool_call_id=str(uuid.uuid4()))] + state["messages"]) state["messages"].append(res) assigned_tools = [] if isinstance(res, AIMessage): if res.tool_calls: for tool_call in res.tool_calls: assigned_tools.append(tool_call) return { "messages": state["messages"], "assigned_tools": assigned_tools, "table_schema": state.get("table_schema", []), "plots": state.get("plots", []) } sys_msg_analysis = SystemMessage(content=""" You are given an aggregated `df_agg` dataframe and `instruction`. Your are required to analyze the finding base on the given data. """) def assistant_analysis(plot_code,plot_instruction): df_agg_temp = generate_df_agg_from_code(plot_code) df_agg_result = df_agg_temp.to_dict(orient='list') prompt_analysis = f""" You are given aggregation data result: ``` {df_agg_result} ``` By given analysis requirement : ``` {plot_instruction} ``` The expect output: - Only provide insight and findings base on the instruction and result - Do NOT give suggest plot code - Do NOT explain the technical of the chart information """ res = llm_analysis.invoke([sys_msg_analysis] + [HumanMessage(content=prompt_analysis)]) analysis_str = res.content return analysis_str def clean_runned_tools(state: AgentState, tool_name: str) -> AgentState: """Clean the runned tools from the state""" if state["assigned_tools"]: removed_list = state["assigned_tools"].copy() for tool_call in state["assigned_tools"]: if tool_call.get('name') == tool_name: removed_list.remove(tool_call) break state["assigned_tools"] = removed_list return state def do_describe_chema(state: AgentState) -> AgentState: """Perform the describe schema using the assigned tool""" if state["assigned_tools"]: for tool_call in state["assigned_tools"]: if tool_call.get('name') == "describe_schema": tool_res = describe_schema.invoke(tool_call['args']) # Call the tool with the arguments state["table_schema"] = tool_res tool_message = ToolMessage( content=str(tool_res), # Convert the result to string id =str(uuid.uuid4()), # Generate a unique ID for the tool message name=tool_call['name'], # Use the tool name from the tool call tool_call_id=tool_call['id'] # Use the tool call ID for tracking ) state["messages"].append(tool_message) break """ delete the runned tool call from the state """ state = clean_runned_tools(state, "describe_schema") return state def do_generate_plot_code(state: AgentState) -> AgentState: """Perform the plot generation using the assigned tool""" if state["assigned_tools"]: for tool_call in state["assigned_tools"]: if tool_call.get('name') == "generate_plot_code": tool_res = generate_plot_code.invoke(tool_call['args']) # Call the tool with the arguments if "plots" not in state: state["plots"] = [] state["plots"].append(tool_res) tool_message = ToolMessage( content=str(tool_res['code']), # Convert the result to string, but only the chart id =str(uuid.uuid4()), # Generate a unique ID for the tool message name=tool_call['name'], # Use the tool name from the tool call tool_call_id=tool_call['id'] # Use the tool call ID for tracking ) state["messages"].append(tool_message) break """ delete the runned tool call from the state """ state = clean_runned_tools(state, "generate_plot_code") return state def do_enhance_plot_code(state: AgentState) -> AgentState: """Perform the plot generation using the assigned tool""" if state["assigned_tools"]: for tool_call in state["assigned_tools"]: if tool_call.get('name') == "enhance_plot_code": tool_res = enhance_plot_code.invoke(tool_call['args']) # Call the tool with the arguments if "plots" not in state: state["plots"] = [] state["plots"].append(tool_res) tool_message = ToolMessage( content=str(tool_res['code']), # Convert the result to string, but only the chart id =str(uuid.uuid4()), # Generate a unique ID for the tool message name=tool_call['name'], # Use the tool name from the tool call tool_call_id=tool_call['id'] # Use the tool call ID for tracking ) state["messages"].append(tool_message) break """ delete the runned tool call from the state """ state = clean_runned_tools(state, "enhance_plot_code") return state def route_to_tool(state: AgentState) -> str: """Determine the next step based on assigned tools""" if state["assigned_tools"]: for tool_call in state["assigned_tools"]: if tool_call.get('name') == "describe_schema": return "describe_schema" elif tool_call.get('name') == "generate_plot_code": return "generate_plot_code" elif tool_call.get('name') == "enhance_plot_code": return "enhance_plot_code" return "no_tool_required" def route_from_tool(state: AgentState) -> str: """Determine the next step based on assigned tools""" if state["assigned_tools"]: for tool_call in state["assigned_tools"]: if tool_call.get('name') == "generate_plot_code": return "generate_plot_code" return "assistant" def build_graph(): builder = StateGraph(AgentState) builder.add_node("Assistant", assistant) builder.add_node("Describe Schema", do_describe_chema) builder.add_node("Generate Plot", do_generate_plot_code) builder.add_node("Enhance Plot", do_enhance_plot_code) edges_to_tool = { "describe_schema": "Describe Schema", "generate_plot_code": "Generate Plot", "enhance_plot_code": "Enhance Plot", "no_tool_required": END, } edges_from_tool = { "generate_plot_code": "Generate Plot", "assistant": "Assistant", } builder.add_edge(START, "Assistant") builder.add_conditional_edges("Assistant", route_to_tool, edges_to_tool) builder.add_conditional_edges("Describe Schema", route_from_tool, edges_from_tool) builder.add_conditional_edges("Generate Plot", route_from_tool, edges_from_tool) builder.add_conditional_edges("Enhance Plot", route_from_tool, edges_from_tool) builder.add_edge("Assistant", END) memory = InMemorySaver() return builder.compile(checkpointer=memory) react_graph = build_graph() config = {"configurable": {"thread_id": 123, "session": 100}} # --- Gradio UI --- def respond(message, chat_history): chat_history = [] res = react_graph.invoke( {"messages": [HumanMessage(content=message)]} , config=config) for msg in res["messages"]: msg.pretty_print() if isinstance(msg, HumanMessage): chat_history.append({"role": "user", "content": msg.content}) if isinstance(msg, AIMessage): ai_response = msg.content chat_history.append({"role": "assistant", "content": ai_response}) if isinstance(msg, ToolMessage): if msg.name == "generate_plot_code": plot_result = generate_plot_from_code(msg.content) chat_history.append({"role": "assistant", "content": gr.Plot(plot_result)}) chat_history.append({"role": "assistant", "content": res["plots"][-1].get("interpretation", " ")}) if msg.name == "enhance_plot_code": plot_result = generate_plot_from_code(msg.content) chat_history.append({"role": "assistant", "content": gr.Plot(plot_result)}) time.sleep(1) return "", chat_history my_theme = gr.Theme.from_hub("NoCrypt/miku") def to_snake_case(name): return name.lower().replace(' ', '_').replace('-', '_') def get_info_df(df): info_df = pd.DataFrame({ "column": df.columns, "non_null_count": df.notnull().sum().values, "dtype": df.dtypes.astype(str).values }) return info_df def summarize_nulls(df): null_summary = df.isnull().sum().reset_index() null_summary.columns = ['column', 'null_count'] null_summary['percent'] = (null_summary['null_count'] / len(df)) * 100 return null_summary[null_summary['null_count'] > 0] def summarize_duplicates(df): return pd.DataFrame({ "duplicated_rows": [df.duplicated().sum()], "total_rows": [len(df)], "percent_duplicated": [100 * df.duplicated().sum() / len(df)] }) def load_example_dataset(name): global df try: if name == "iris": df = pd.read_csv("https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv") elif name == "titanic": df = pd.read_csv("https://raw.githubusercontent.com/datasciencedojo/datasets/refs/heads/master/titanic.csv") elif name == "superstore": df = pd.read_excel("https://public.tableau.com/app/sample-data/sample_-_superstore.xls") else: raise ValueError("Unknown dataset name.") df.columns = [col.lower().replace(" ", "_") for col in df.columns] null_summary = summarize_nulls(df) dup_summary = summarize_duplicates(df) return ( gr.update(visible=True), # Show main tabs gr.update(visible=False), # Hide warning gr.update(visible=False), # Hide iris button gr.update(visible=False), # Hide titanic button gr.update(visible=False), # Hide superstore button gr.update(visible=False), # Hide upload button df.describe().reset_index(), get_info_df(df), df.head(), null_summary, dup_summary ) except Exception as e: raise gr.Error(f"Failed to load dataset: {e}") def handle_upload(file): global df if file is None or file.name == "": return ( gr.update(visible=False), # Hide main tabs gr.update(visible=True), # Show warning pd.DataFrame(), "", pd.DataFrame(), pd.DataFrame(), pd.DataFrame() ) try: df = pd.read_csv(file) if file.name.endswith(".csv") else pd.read_excel(file) except Exception as e: raise gr.Error(f"Failed to read the file: {e}") df.columns = [to_snake_case(col) for col in df.columns] df = df null_summary = summarize_nulls(df) dup_summary = summarize_duplicates(df) # Rebuild the graph and reset the config global react_graph react_graph = build_graph() # Rebuild graph to reset state global config config = {"configurable": {"thread_id": str(uuid.uuid4()), "session": str(uuid.uuid4())}} return ( gr.update(visible=True), # Show main tabs gr.update(visible=False), # Hide warning gr.update(visible=False), # Hide iris button gr.update(visible=False), # Hide titanic button gr.update(visible=False), # Hide superstore button gr.update(visible=True), # Hide upload button df.describe().reset_index(), get_info_df(df), df.head(), null_summary, dup_summary ) def refresh_graph(): global react_graph react_graph = build_graph() # Rebuild graph to reset state global config config = {"configurable": {"thread_id": str(uuid.uuid4()), "session": str(uuid.uuid4())}} # Layout with gr.Blocks(theme=my_theme) as demo: demo.load(refresh_graph, inputs=None, outputs=None) gr.HTML("""
Your gateway to smarter decisions through travel data.