import os
import gradio as gr
import time
import uuid
from typing import List, TypedDict, Annotated, Optional
from gradio.themes.base import Base
import pandas as pd
import altair as alt

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver

from langgraph.graph.message import add_messages
from langgraph.graph import START, END, StateGraph


# Global DataFrame, populated by the upload / example-dataset handlers below.
df = pd.DataFrame()

@tool
def describe_schema() -> str:
    """
    Describe the DataFrame schema so you have context on how to process it.
    Do this before generating any plot if you are unsure about the columns;
    you can skip it if you already know the columns and data types.
    Knowing the schema helps you give better plot-creation instructions.
    """
    return str(df.dtypes)
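
# Illustrative output of describe_schema() (an assumption based on the iris example
# dataset, after column names are snake_cased on load):
#   sepal_length    float64
#   sepal_width     float64
#   petal_length    float64
#   petal_width     float64
#   species          object
#   dtype: object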

@tool
def generate_plot_code(plot_instruction: str) -> dict:
    """
    Given a plot_instruction (a description, not direct Python code),
    generate Python code that:
    1. Performs aggregation/transformation on `df` (stored in `df_agg`)
    2. Generates an Altair plot from `df_agg` (stored in `chart`)

    Args:
        plot_instruction (str): A description of the plot to generate, e.g. "Bar chart of total revenue by region".

    Returns:
        dict: A dictionary containing:
            - `plot_instruction`: The original plot instruction.
            - `code`: The generated Python code as a string.
            - `interpretation`: An analysis of the aggregated data used for the plot.
    """

    prompt_generate_plot_code = """
    You are a Python assistant. A pandas DataFrame `df` is available.

    Your task:
    1. Perform any necessary data processing or aggregation based on this request: "{plot_instruction}"
        - Store the final aggregated DataFrame in a variable called `df_agg`.
        - When grouping data, always use `.reset_index()` after aggregation so the group keys remain columns in `df_agg`.
    2. Create an Altair plot from `df_agg`
        - Only use the Altair library.
        - Assign the chart to a variable named `chart`.
        - Do NOT include explanations, comments, or markdown (like ```python).
        - Use the existing DataFrame `df` directly.
        - Just return executable Python code.

    Rules:
    - Do NOT create fake/sample data.
    - Use only the real `df`.
    - Must create a variable `df_agg` for the aggregated DataFrame.
    - Must create a variable `chart` for the Altair chart.
    - Always show a title and tooltip in the chart.
    - No print statements or explanations, just code.
    - Be flexible interpreting column names:
        - If the plot_instruction uses a partial or common term (e.g. "customer"), find the best matching column(s) in the schema (like "customer_name").
        - Normalize and expand synonyms or abbreviations to match columns.
        - If multiple columns match, pick the most relevant one.

    Example result:
    import altair as alt
    df_agg = df.groupby('region')['sales'].sum().reset_index().sort_values('sales', ascending=False)
    chart = alt.Chart(df_agg).mark_bar(
        cornerRadiusTopLeft=3, cornerRadiusTopRight=3
    ).encode(
        x='region:N',
        y='sales:Q',
        color=alt.Color('region:N', scale=alt.Scale(scheme='tableau10')),
        tooltip=['region', 'sales']
    ).properties(
        title='Top Sales per Region'
    )
    """

    prompt_generate_plot_code = prompt_generate_plot_code.format(plot_instruction=plot_instruction)

    try:
        response = llm_plot.invoke([HumanMessage(content=prompt_generate_plot_code)])
        code = response.content.strip()

        # Strip markdown code fences if the model wrapped the code in them.
        if code.startswith("```"):
            lines = code.split("\n")
            if lines[0].startswith("```"):
                lines = lines[1:]
            if lines[-1].startswith("```"):
                lines = lines[:-1]
            code = "\n".join(lines).strip()

        interpretation = assistant_analysis(code, plot_instruction)
        return {
            "plot_instruction": plot_instruction,
            "code": code,
            "interpretation": interpretation,
        }
    except Exception as e:
        raise RuntimeError(f"Failed to generate plot: {e}")
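
# Illustrative direct invocation (normally the agent triggers this tool itself;
# the instruction string here is only an example):
#   result = generate_plot_code.invoke({"plot_instruction": "Bar chart of total sales by region"})
#   result["code"]            # generated Altair code defining `df_agg` and `chart`
#   result["interpretation"]  # LLM-written findings about the aggregated data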

@tool
def enhance_plot_code(previous_code: str, plot_instruction: str) -> dict:
    """
    Given the previous plot code and a plot_instruction (a description, not direct Python code),
    enhance the plotting code so that it:
    1. Performs aggregation/transformation on `df` (stored in `df_agg`)
    2. Generates an Altair plot from `df_agg` (stored in `chart`)
    3. Updates the previous code based on the new plot_instruction

    Args:
        previous_code (str): The previously generated plotting code to build on.
        plot_instruction (str): A description of the plot to generate, e.g. "Bar chart of total revenue by region".

    Returns:
        dict: A dictionary containing:
            - `plot_instruction`: The original plot instruction.
            - `code`: The generated Python code as a string.

    By running this tool, the plot is assumed to already be shown to the user, so do not say you cannot display the plot.
    """

    prompt_enhance_plot_code = """
    You are a Python assistant. A pandas DataFrame `df` is available.

    You know the previous code that already generated a plot:
    "{previous_code}"

    Your task:
    Enhance the previous code based on this request:
    "{plot_instruction}"

    Rules:
    - Do NOT create fake/sample data.
    - Use only the real `df`.
    - Must create a variable `df_agg` for the aggregated DataFrame.
    - Must create a variable `chart` for the Altair chart.
    - Always show a title and tooltip in the chart.
    - No print statements or explanations, just code.
    - Be flexible interpreting column names:
        - If the plot_instruction uses a partial or common term (e.g. "customer"), find the best matching column(s) in the schema (like "customer_name").
        - Normalize and expand synonyms or abbreviations to match columns.
        - If multiple columns match, pick the most relevant one.

    Example result:
    import altair as alt
    df_agg = df.groupby('region')['sales'].sum().reset_index().sort_values('sales', ascending=False)
    chart = alt.Chart(df_agg).mark_bar(
        cornerRadiusTopLeft=3, cornerRadiusTopRight=3
    ).encode(
        x='region:N',
        y='sales:Q',
        color=alt.Color('region:N', scale=alt.Scale(scheme='tableau10')),
        tooltip=['region', 'sales']
    ).properties(
        title='Top Sales per Region'
    )
    """

    prompt_enhance_plot_code = prompt_enhance_plot_code.format(previous_code=previous_code, plot_instruction=plot_instruction)

    try:
        response = llm_plot.invoke([HumanMessage(content=prompt_enhance_plot_code)])
        code = response.content.strip()

        # Strip markdown code fences if the model wrapped the code in them.
        if code.startswith("```"):
            lines = code.split("\n")
            if lines[0].startswith("```"):
                lines = lines[1:]
            if lines[-1].startswith("```"):
                lines = lines[:-1]
            code = "\n".join(lines).strip()

        return {
            "plot_instruction": plot_instruction,
            "code": code,
            "interpretation": " "
        }
    except Exception as e:
        raise RuntimeError(f"Failed to enhance plot: {e}")

def generate_plot_from_code(code: str):
    """Execute generated Altair code against the global `df` and return the `chart` object."""
    local_scope = {"df": df, "alt": alt}
    exec(code, {}, local_scope)

    if "chart" not in local_scope:
        raise ValueError("No valid `chart` was generated.")
    return local_scope["chart"]


def generate_df_agg_from_code(code: str):
    """Execute generated Altair code against the global `df` and return the aggregated `df_agg`."""
    local_scope = {"df": df, "alt": alt}
    exec(code, {}, local_scope)

    if "df_agg" not in local_scope:
        raise ValueError("No valid `df_agg` was generated.")
    return local_scope["df_agg"]
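
# Illustrative end-to-end use of the two helpers above (assumes the LLM-generated
# `code` defines both `df_agg` and `chart`, as the prompts require):
#   code = generate_plot_code.invoke({"plot_instruction": "total sales by region"})["code"]
#   chart = generate_plot_from_code(code)     # Altair chart for display
#   df_agg = generate_df_agg_from_code(code)  # aggregated data behind the chart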

tools = [
    describe_schema,
    generate_plot_code,
    enhance_plot_code,
]


llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0.5,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
llm = llm.bind_tools(tools)

llm_analysis = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0.5,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

llm_plot = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.5,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
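
# The three model instances split the work: `llm` (with the tool schemas bound via
# `bind_tools`) decides which tool to call, `llm_analysis` writes the data findings,
# and `llm_plot` generates the Altair code. `bind_tools` only attaches the tool
# schemas so the model can emit tool calls; the calls themselves are executed by the
# graph nodes defined below.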

class AgentState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]  # running chat history (merged by add_messages)
    assigned_tools: Optional[List[dict]]  # pending tool calls requested by the assistant
    table_schema: Optional[str]  # cached output of describe_schema
    plots: List[dict]  # results returned by the plot tools
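
# Each entry in `assigned_tools` is a LangChain tool-call dict as emitted on
# AIMessage.tool_calls, roughly:
#   {"name": "generate_plot_code", "args": {"plot_instruction": "..."}, "id": "..."}
# The do_* nodes below read the "name", "args", and "id" keys from these entries.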

sys_msg = SystemMessage(content="""
You are a helpful assistant named Terloka Bro whose job is to create plots.
You can run tools such as `describe_schema` to understand the DataFrame schema,
and `generate_plot_code` to generate Python code that creates a plot with the Altair library.
Run `describe_schema` first and then `generate_plot_code` to create a plot; do not call those two tools at the same time.
There is no need to say that the chart cannot be displayed, because that is already handled by the application.
You already have access to a DataFrame called `df`.
""")

def assistant(state: AgentState) -> AgentState:
    """LLM node: decide which tool (if any) to call next, given the schema and chat history."""

    # Prime the model with the current schema as if it had already called describe_schema.
    schema_output = describe_schema.invoke({})
    res = llm.invoke(
        [sys_msg]
        + [HumanMessage(content="show your schema")]
        + [AIMessage(content=schema_output)]
        + [ToolMessage(content=schema_output, name="describe_schema", id=str(uuid.uuid4()), tool_call_id=str(uuid.uuid4()))]
        + state["messages"]
    )

    state["messages"].append(res)
    assigned_tools = []
    if isinstance(res, AIMessage):
        if res.tool_calls:
            for tool_call in res.tool_calls:
                assigned_tools.append(tool_call)
    return {
        "messages": state["messages"],
        "assigned_tools": assigned_tools,
        "table_schema": state.get("table_schema"),
        "plots": state.get("plots", [])
    }

sys_msg_analysis = SystemMessage(content="""
You are given an aggregated `df_agg` DataFrame and an `instruction`. You are required to analyze the findings based on the given data.
""")


def assistant_analysis(plot_code, plot_instruction):
    """Run the generated code, then ask the analysis model for findings about the aggregated data."""

    df_agg_temp = generate_df_agg_from_code(plot_code)
    df_agg_result = df_agg_temp.to_dict(orient='list')

    prompt_analysis = f"""
    You are given the aggregated data result:
    ```
    {df_agg_result}
    ```
    and the analysis requirement:
    ```
    {plot_instruction}
    ```
    The expected output:
    - Only provide insights and findings based on the instruction and the result
    - Do NOT suggest plot code
    - Do NOT explain the technical details of the chart
    """

    res = llm_analysis.invoke([sys_msg_analysis] + [HumanMessage(content=prompt_analysis)])
    analysis_str = res.content

    return analysis_str

def clean_runned_tools(state: AgentState, tool_name: str) -> AgentState:
    """Remove the first completed tool call with `tool_name` from the state."""
    if state["assigned_tools"]:
        removed_list = state["assigned_tools"].copy()
        for tool_call in state["assigned_tools"]:
            if tool_call.get('name') == tool_name:
                removed_list.remove(tool_call)
                break
        state["assigned_tools"] = removed_list
    return state

def do_describe_schema(state: AgentState) -> AgentState:
    """Perform the schema description using the assigned tool."""
    if state["assigned_tools"]:
        for tool_call in state["assigned_tools"]:
            if tool_call.get('name') == "describe_schema":
                tool_res = describe_schema.invoke(tool_call['args'])
                state["table_schema"] = tool_res
                tool_message = ToolMessage(
                    content=str(tool_res),
                    id=str(uuid.uuid4()),
                    name=tool_call['name'],
                    tool_call_id=tool_call['id']
                )
                state["messages"].append(tool_message)
                break
    # Drop the completed tool call from the state.
    state = clean_runned_tools(state, "describe_schema")
    return state

def do_generate_plot_code(state: AgentState) -> AgentState:
    """Perform the plot generation using the assigned tool."""
    if state["assigned_tools"]:
        for tool_call in state["assigned_tools"]:
            if tool_call.get('name') == "generate_plot_code":
                tool_res = generate_plot_code.invoke(tool_call['args'])
                if "plots" not in state:
                    state["plots"] = []
                state["plots"].append(tool_res)

                tool_message = ToolMessage(
                    content=str(tool_res['code']),
                    id=str(uuid.uuid4()),
                    name=tool_call['name'],
                    tool_call_id=tool_call['id']
                )
                state["messages"].append(tool_message)
                break
    # Drop the completed tool call from the state.
    state = clean_runned_tools(state, "generate_plot_code")
    return state

def do_enhance_plot_code(state: AgentState) -> AgentState:
    """Perform the plot enhancement using the assigned tool."""
    if state["assigned_tools"]:
        for tool_call in state["assigned_tools"]:
            if tool_call.get('name') == "enhance_plot_code":
                tool_res = enhance_plot_code.invoke(tool_call['args'])
                if "plots" not in state:
                    state["plots"] = []
                state["plots"].append(tool_res)

                tool_message = ToolMessage(
                    content=str(tool_res['code']),
                    id=str(uuid.uuid4()),
                    name=tool_call['name'],
                    tool_call_id=tool_call['id']
                )
                state["messages"].append(tool_message)
                break
    # Drop the completed tool call from the state.
    state = clean_runned_tools(state, "enhance_plot_code")
    return state

def route_to_tool(state: AgentState) -> str:
    """Determine the next step based on the assigned tools."""
    if state["assigned_tools"]:
        for tool_call in state["assigned_tools"]:
            if tool_call.get('name') == "describe_schema":
                return "describe_schema"
            elif tool_call.get('name') == "generate_plot_code":
                return "generate_plot_code"
            elif tool_call.get('name') == "enhance_plot_code":
                return "enhance_plot_code"
    return "no_tool_required"


def route_from_tool(state: AgentState) -> str:
    """Determine the next step after a tool node has run."""
    if state["assigned_tools"]:
        for tool_call in state["assigned_tools"]:
            if tool_call.get('name') == "generate_plot_code":
                return "generate_plot_code"
    return "assistant"

def build_graph():
    builder = StateGraph(AgentState)
    builder.add_node("Assistant", assistant)
    builder.add_node("Describe Schema", do_describe_schema)
    builder.add_node("Generate Plot", do_generate_plot_code)
    builder.add_node("Enhance Plot", do_enhance_plot_code)

    edges_to_tool = {
        "describe_schema": "Describe Schema",
        "generate_plot_code": "Generate Plot",
        "enhance_plot_code": "Enhance Plot",
        "no_tool_required": END,
    }

    edges_from_tool = {
        "generate_plot_code": "Generate Plot",
        "assistant": "Assistant",
    }

    builder.add_edge(START, "Assistant")
    builder.add_conditional_edges("Assistant", route_to_tool, edges_to_tool)
    builder.add_conditional_edges("Describe Schema", route_from_tool, edges_from_tool)
    builder.add_conditional_edges("Generate Plot", route_from_tool, edges_from_tool)
    builder.add_conditional_edges("Enhance Plot", route_from_tool, edges_from_tool)
    builder.add_edge("Assistant", END)

    memory = InMemorySaver()
    return builder.compile(checkpointer=memory)
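
# Resulting flow: START -> Assistant, which either ends the turn ("no_tool_required")
# or routes to one of the tool nodes; Describe Schema and Enhance Plot route back to
# Assistant, while a pending generate_plot_code call goes straight to Generate Plot.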

react_graph = build_graph()
config = {"configurable": {"thread_id": 123, "session": 100}}
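
# The InMemorySaver checkpointer keys conversation state by thread_id, so every new
# upload/session below swaps in a fresh thread_id (and a rebuilt graph) to start
# with an empty message history.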

def to_snake_case(name):
    return name.lower().replace(' ', '_').replace('-', '_')


def get_info_df(df):
    info_df = pd.DataFrame({
        "column": df.columns,
        "non_null_count": df.notnull().sum().values,
        "dtype": df.dtypes.astype(str).values
    })
    return info_df


def summarize_nulls(df):
    null_summary = df.isnull().sum().reset_index()
    null_summary.columns = ['column', 'null_count']
    null_summary['percent'] = (null_summary['null_count'] / len(df)) * 100
    return null_summary[null_summary['null_count'] > 0]


def summarize_duplicates(df):
    return pd.DataFrame({
        "duplicated_rows": [df.duplicated().sum()],
        "total_rows": [len(df)],
        "percent_duplicated": [100 * df.duplicated().sum() / len(df)]
    })

def load_example_dataset(name):
    global df
    try:
        if name == "iris":
            df = pd.read_csv("https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv")
        elif name == "titanic":
            df = pd.read_csv("https://raw.githubusercontent.com/datasciencedojo/datasets/refs/heads/master/titanic.csv")
        elif name == "superstore":
            df = pd.read_excel("https://public.tableau.com/app/sample-data/sample_-_superstore.xls")
        else:
            raise ValueError("Unknown dataset name.")

        df.columns = [col.lower().replace(" ", "_") for col in df.columns]
        null_summary = summarize_nulls(df)
        dup_summary = summarize_duplicates(df)

        # Reveal the main tabs, hide the upload/example widgets, and fill the
        # data-exploration outputs.
        return (
            gr.update(visible=True),    # main_tabs
            gr.update(visible=False),   # warning_box
            gr.update(visible=False),   # iris_btn
            gr.update(visible=False),   # titanic_btn
            gr.update(visible=False),   # superstore_btn
            gr.update(visible=False),   # upload_btn
            gr.update(visible=False),   # instruction_box
            gr.update(visible=False),   # example_box
            df.describe().reset_index(),
            get_info_df(df),
            df.head(),
            null_summary,
            dup_summary
        )
    except Exception as e:
        raise gr.Error(f"Failed to load dataset: {e}")

def handle_upload(file):
    global df
    if file is None or file.name == "":
        # Nothing uploaded: keep the upload widgets as they are and clear the outputs.
        return (
            gr.update(visible=False),   # main_tabs
            gr.update(visible=True),    # warning_box
            gr.update(),                # iris_btn
            gr.update(),                # titanic_btn
            gr.update(),                # superstore_btn
            gr.update(),                # upload_btn
            gr.update(),                # instruction_box
            gr.update(),                # example_box
            pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
        )

    try:
        df = pd.read_csv(file) if file.name.endswith(".csv") else pd.read_excel(file)
    except Exception as e:
        raise gr.Error(f"Failed to read the file: {e}")

    df.columns = [to_snake_case(col) for col in df.columns]

    null_summary = summarize_nulls(df)
    dup_summary = summarize_duplicates(df)

    # Start a fresh agent graph and conversation thread for the new dataset.
    global react_graph
    react_graph = build_graph()
    global config
    config = {"configurable": {"thread_id": str(uuid.uuid4()), "session": str(uuid.uuid4())}}

    return (
        gr.update(visible=True),    # main_tabs
        gr.update(visible=False),   # warning_box
        gr.update(visible=False),   # iris_btn
        gr.update(visible=False),   # titanic_btn
        gr.update(visible=False),   # superstore_btn
        gr.update(visible=True),    # upload_btn
        gr.update(visible=False),   # instruction_box
        gr.update(visible=False),   # example_box
        df.describe().reset_index(),
        get_info_df(df),
        df.head(),
        null_summary,
        dup_summary
    )

def refresh_graph():
    global react_graph
    react_graph = build_graph()
    global config
    config = {"configurable": {"thread_id": str(uuid.uuid4()), "session": str(uuid.uuid4())}}

def respond(message, chat_history):
    # Rebuild the chat history from scratch: the checkpointed graph state already
    # contains every previous message for this thread.
    chat_history = []
    res = react_graph.invoke(
        {"messages": [HumanMessage(content=message)]},
        config=config)
    for msg in res["messages"]:
        msg.pretty_print()
        if isinstance(msg, HumanMessage):
            chat_history.append({"role": "user", "content": msg.content})

        if isinstance(msg, AIMessage):
            ai_response = msg.content
            chat_history.append({"role": "assistant", "content": ai_response})

        if isinstance(msg, ToolMessage):
            if msg.name == "generate_plot_code":
                plot_result = generate_plot_from_code(msg.content)
                chat_history.append({"role": "assistant", "content": gr.Plot(plot_result)})
                chat_history.append({"role": "assistant", "content": res["plots"][-1].get('interpretation')})

            if msg.name == "enhance_plot_code":
                plot_result = generate_plot_from_code(msg.content)
                chat_history.append({"role": "assistant", "content": gr.Plot(plot_result)})

    time.sleep(1)
    # Returning "" clears the textbox; chat_history updates the Chatbot component.
    return "", chat_history

my_theme = gr.Theme.from_hub("NoCrypt/miku")
with gr.Blocks(theme=my_theme) as demo:
    demo.load(refresh_graph, inputs=None, outputs=None)

    gr.HTML("""
    <style>
        body, .container, h1, h2, h3, p, span {
            font-family: "IBM Plex Sans";
        }

        #instruction blockquote {
            margin: 12px auto 0 auto;
            padding: 12px 16px;
            border-radius: 6px;
            font-size: 14px;
            max-width: 7000px;
        }

        #chatbot_hint {
            margin: 12px auto 0 auto;
            padding: 12px 16px;
            border-radius: 6px;
            font-size: 14px;
            max-width: 7000px;
        }

        @keyframes fadeInTitle {
            0% {
                opacity: 0;
                transform: translateY(-10px);
            }
            100% {
                opacity: 1;
                transform: translateY(0);
            }
        }

        .container {
            padding: 24px;
            border-radius: 16px;
            box-shadow: 0 2px 30px rgba(42, 86, 198, 0.12);
            text-align: center;
            transition: box-shadow 0.3s ease;
            margin-bottom: 12px;
        }

        .subtitle {
            font-size: 16px;
            margin-top: -6px;
        }
    </style>

    <div class="container">
        <h1>
            <span style="font-size: 30px;">🎯</span>
            <span class="title-gradient">AI Chat to Visual</span>
        </h1>
        <p class="subtitle">Your gateway to smarter decisions through travel data.</p>
    </div>
    """)

    instruction_box = gr.Markdown(
        "> Upload a file to get started. Supported formats: `.csv`, `.xls`, `.xlsx`",
        elem_id="instruction"
    )
    warning_box = gr.Markdown("⚠️ **You can't proceed without uploading your files first**", visible=True)
    upload_btn = gr.File(file_types=[".csv", ".xls", ".xlsx"], label="📁 Upload File")
    example_box = gr.Markdown("### Or use an example dataset:")
    with gr.Row():
        iris_btn = gr.Button("🌸 Load Iris")
        titanic_btn = gr.Button("🚢 Load Titanic")
        superstore_btn = gr.Button("🏪 Load Superstore")

    with gr.Tabs(visible=False) as main_tabs:
        with gr.Tab("🤖 ChatBot for Viz"):
            gr.Markdown(
                "👉 Want to understand your data first? Go to the Data Exploration tab first!",
                elem_id="chatbot_hint"
            )
            chatbot = gr.Chatbot(type="messages", label="Data Chatbot", elem_id="chatbot")
            msg = gr.Textbox(label="", elem_id="chat_input", container=False, placeholder="Ask me anything about your data...")
            msg.submit(respond, [msg, chatbot], [msg, chatbot])

        with gr.Tab("📊 Data Exploration"):
            with gr.Column():
                with gr.Accordion("🧮 Data Description", open=True):
                    describe_output = gr.DataFrame()
                with gr.Accordion("📋 Data Info", open=True):
                    info_output = gr.DataFrame()
                with gr.Accordion("👁️ Preview Data", open=False):
                    head_output = gr.DataFrame()
                with gr.Accordion("🧼 Null Detection", open=False):
                    null_output = gr.DataFrame()
                with gr.Accordion("📑 Duplicate Check", open=False):
                    dup_output = gr.DataFrame()

    gr.Markdown("---")
    gr.Markdown("🛠️ Built with ❤️ by **Terloka Bros**", elem_id="footer")

    # All load paths drive the same 13 outputs: the visibility toggles followed by
    # the data-exploration tables.
    upload_btn.change(
        fn=handle_upload,
        inputs=upload_btn,
        outputs=[
            main_tabs, warning_box, iris_btn, titanic_btn, superstore_btn, upload_btn, instruction_box, example_box,
            describe_output, info_output,
            head_output, null_output,
            dup_output
        ]
    )
    iris_btn.click(
        fn=lambda: load_example_dataset("iris"),
        outputs=[
            main_tabs, warning_box, iris_btn, titanic_btn, superstore_btn, upload_btn, instruction_box, example_box,
            describe_output, info_output,
            head_output, null_output,
            dup_output
        ]
    )

    titanic_btn.click(
        fn=lambda: load_example_dataset("titanic"),
        outputs=[
            main_tabs, warning_box, iris_btn, titanic_btn, superstore_btn, upload_btn, instruction_box, example_box,
            describe_output, info_output,
            head_output, null_output,
            dup_output
        ]
    )

    superstore_btn.click(
        fn=lambda: load_example_dataset("superstore"),
        outputs=[
            main_tabs, warning_box, iris_btn, titanic_btn, superstore_btn, upload_btn, instruction_box, example_box,
            describe_output, info_output,
            head_output, null_output,
            dup_output
        ]
    )

demo.launch()