Spaces:
Runtime error
Runtime error
Upload 6 files
Browse files- src/question-answer/agents.py +221 -0
- src/question-answer/graph.py +169 -0
- src/question-answer/prompts.py +150 -0
- src/question-answer/states.py +13 -0
- src/question-answer/tools.py +264 -0
- src/question-answer/utils.py +177 -0
src/question-answer/agents.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from prompts import coder_prompt, fixer_prompt, analysis_prompt
|
2 |
+
from states import State
|
3 |
+
from langchain_openai import ChatOpenAI
|
4 |
+
from langchain_core.messages import SystemMessage, AIMessage
|
5 |
+
from tools import run_code
|
6 |
+
from utils import create_markdown_report, save_markdown_report
|
7 |
+
|
8 |
+
def coder_agent(state: State) -> State:
    """
    Creates the cleaning code in Python to be executed in a sandbox.

    Builds an LLM context from the coder system prompt plus the recent
    human/analysis turns of the conversation, asks the model for analysis
    code, and appends the generated code to the message history.
    """
    banner = "-----------------------------------------------------------------------------------"
    print(banner)
    print("Creating the analysis code...")
    print(banner)

    dataset_path = state['dataset_path']
    messages = state['messages']

    # Most recent human question, falling back to a generic request.
    question = next(
        (m.content for m in reversed(messages)
         if hasattr(m, 'additional_kwargs') and m.additional_kwargs.get('agent') == 'human'),
        None,
    ) or "Analyze the dataset"

    system_prompt = coder_prompt(dataset_path, question)

    # System prompt first, then the last 10 messages filtered down to the
    # human/analysis turns so the context stays manageable.
    history = [
        m for m in messages[-10:]
        if hasattr(m, 'additional_kwargs') and m.additional_kwargs.get('agent') in ['human', 'analysis_agent']
    ]
    conversation_messages = [
        SystemMessage(content=system_prompt, additional_kwargs={"agent": "system", "node_type": "generation"})
    ] + history

    # Invoke the LLM with the assembled context and record the result.
    llm = ChatOpenAI(model="gpt-4.1", temperature=0)
    code = llm.invoke(conversation_messages).content
    messages.append(AIMessage(content=code, additional_kwargs={"agent": "coder_agent", "node_type": "generation"}))

    return {
        "messages": messages,
        "dataset_path": dataset_path,
    }
|
58 |
+
|
59 |
+
def runner_agent(state: State) -> State:
    """
    Runs the analysis code in a sandbox.

    Persists the most recently generated code to ``output/analysis_code.py``,
    executes it via ``run_code``, and records the execution result plus any
    produced charts in the state.
    """
    import os

    banner = "-----------------------------------------------------------------------------------"
    print(banner)
    print("Running the analysis code...")
    print(banner)

    messages = state['messages']
    code = messages[-1].content
    dataset_path = state['dataset_path']

    # Keep a copy of the executed script for inspection/debugging.
    os.makedirs("output", exist_ok=True)
    with open("output/analysis_code.py", "w") as code_file:
        code_file.write(code)

    result = run_code(code, dataset_path)

    # Record the execution outcome in the shared history.
    outcome = AIMessage(
        content=f"Code execution result: {result['execution']}",
        additional_kwargs={"agent": "runner_agent", "node_type": "generation"},
    )
    messages.append(outcome)

    # Wrap every produced chart path in its own message object.
    chart_messages = [
        AIMessage(content=chart, additional_kwargs={"agent": "runner_agent", "node_type": "chart"})
        for chart in result.get('charts', [])
    ]

    return {
        "messages": messages,
        "codes": state.get('codes', []) + [code],
        "charts": chart_messages,
    }
|
92 |
+
|
93 |
+
def fixer_agent(state: State) -> State:
    """
    Fixes the analysis code.

    Looks up the latest question, generated code, and execution error in the
    message history, asks the LLM for a corrected version of the code, and
    appends the fix to the history.
    """
    print("-----------------------------------------------------------------------------------")
    print("Fixing the analysis code...")
    print("-----------------------------------------------------------------------------------")

    messages = state['messages']

    def _last_content(agent: str):
        """Content of the most recent message tagged with *agent*, or None."""
        for message in reversed(messages):
            if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == agent:
                return message.content
        return None

    # Extract the last human question, the last generated code, and the
    # last runner output (which carries the error text).
    question = _last_content('human')
    code = _last_content('coder_agent')
    error = _last_content('runner_agent')

    # Get the dataset path
    dataset_path = state['dataset_path']

    # Create the system prompt
    system_prompt = fixer_prompt(code, error, question, dataset_path)

    # BUG FIX: previously the SystemMessage was appended to the persistent
    # history while the LLM was invoked with the raw prompt *string* — so
    # the appended message was never used for the call and only polluted the
    # shared history.  Invoke with a proper message list instead, and keep
    # this one-shot prompt out of the history.
    llm = ChatOpenAI(model="gpt-4.1", temperature=0)
    fixed_code = llm.invoke(
        [SystemMessage(content=system_prompt, additional_kwargs={"agent": "system", "node_type": "fixing"})]
    ).content
    messages.append(AIMessage(content=fixed_code, additional_kwargs={"agent": "fixer_agent", "node_type": "fixing"}))

    return {
        "messages": messages,
        "codes": state.get('codes', []) + [fixed_code]
    }
|
143 |
+
|
144 |
+
def analysis_agent(state: State) -> State:
    """
    Analyzes the question, the result of the execution, and the charts to answer the question.

    Collects the latest human question and runner output from the history,
    de-duplicates chart paths, asks the LLM for a written analysis, and
    saves a markdown report.  Returns the updated messages plus the analysis
    text and the report filename.
    """
    print("-----------------------------------------------------------------------------------")
    print("Analyzing the question, the result of the execution, and the charts to answer the question...")
    print("-----------------------------------------------------------------------------------")

    # Get the messages
    messages = state['messages']

    # last human message
    question = None
    for message in reversed(messages):
        if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'human':
            question = message.content
            break

    # Get the dataset path
    dataset_path = state['dataset_path']

    # Get the execution result (the most recent runner_agent message;
    # this includes the "Code execution result: ..." prefix).
    execution_result = None
    for message in reversed(messages):
        if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'runner_agent':
            execution_result = message.content
            break

    # Get the charts from state and ensure they are strings
    charts = state.get('charts', [])

    # Convert any message objects to strings and filter out duplicates.
    # Charts may be AIMessage objects (from runner_agent) or raw strings.
    chart_paths = []
    seen_charts = set()
    for chart in charts:
        if hasattr(chart, 'content'):
            chart_path = chart.content
        else:
            chart_path = str(chart)

        # Only add unique chart paths (preserve first-seen order)
        if chart_path not in seen_charts:
            chart_paths.append(chart_path)
            seen_charts.add(chart_path)

    # Create the system prompt
    system_prompt = analysis_prompt(question, dataset_path, execution_result, chart_paths)

    # Build conversation context for the LLM
    conversation_messages = []

    # Add system prompt
    conversation_messages.append(SystemMessage(content=system_prompt, additional_kwargs={"agent": "system", "node_type": "analysis"}))

    # Add recent conversation history (last 10 messages to keep context manageable)
    recent_messages = messages[-10:] if len(messages) > 10 else messages
    for msg in recent_messages:
        if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('agent') in ['human', 'analysis_agent']:
            conversation_messages.append(msg)

    # Create the LLM and invoke it to analyze the question, the result of the execution, and the charts to answer the question.
    llm = ChatOpenAI(model="gpt-4.1", temperature=0)
    analysis = llm.invoke(conversation_messages).content
    messages.append(AIMessage(content=analysis, additional_kwargs={"agent": "analysis_agent", "node_type": "analysis"}))

    # Create a markdown report
    report_content = create_markdown_report(question, analysis, chart_paths, execution_result)

    # Save the report to a file
    report_filename = save_markdown_report(report_content)

    # Report filename is returned in the state, no need to add to messages

    return {
        "messages": messages,
        "analysis": analysis,
        "report": report_filename
    }
|
src/question-answer/graph.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
import os
|
5 |
+
from typing import TypedDict
|
6 |
+
from langgraph.graph import START, END
|
7 |
+
from langgraph.graph import StateGraph
|
8 |
+
from langchain_core.messages import HumanMessage
|
9 |
+
|
10 |
+
from states import State
|
11 |
+
from agents import coder_agent, runner_agent, fixer_agent, analysis_agent
|
12 |
+
|
13 |
+
class Context(TypedDict):
    """Context parameters for the agent.

    Set these when creating assistants OR when invoking the graph.
    See: https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/
    """

    # NOTE(review): not referenced by build_graph()/chat_interface() in this
    # file — confirm whether this is consumed by the LangGraph runtime or is
    # template leftover.
    my_configurable_param: str
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == "__main__":
    # NOTE(review): this `state` is never used — a second __main__ guard at
    # the bottom of the file runs chat_interface() instead, which builds its
    # own state.  This looks like leftover scaffolding; consider removing.
    # The dataset path is also machine-specific.
    state = {
        "messages": [],
        "dataset_path": "/Users/beyzaerdogan/Desktop/ai-analyst/cereal.csv"
    }
|
28 |
+
|
29 |
+
def build_graph():
    """Build and compile the coder → runner → (fixer loop) → analysis graph.

    Returns:
        A compiled LangGraph runnable.
    """
    graph_builder = StateGraph(State)
    graph_builder.add_node("coder_agent", coder_agent)
    graph_builder.add_node("runner_agent", runner_agent)
    graph_builder.add_node("fixer_agent", fixer_agent)
    graph_builder.add_node("analysis_agent", analysis_agent)

    graph_builder.add_edge(START, "coder_agent")
    graph_builder.add_edge("coder_agent", "runner_agent")

    def should_fix_code(state: State) -> str:
        """Route to the fixer on execution failure, else to analysis."""
        messages = state.get("messages", [])
        if not messages:
            return "success"

        # Count fix attempts to prevent infinite loops
        fix_attempts = 0
        for message in messages:
            if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'fixer_agent':
                fix_attempts += 1

        # Limit to 3 fix attempts to prevent quota exceeded errors
        if fix_attempts >= 3:
            return "success"

        # Get the last runner_agent message to check for errors
        for message in reversed(messages):
            if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'runner_agent':
                content = message.content
                # BUG FIX: the old heuristic routed to the fixer whenever the
                # substrings "error" or "failed" appeared *anywhere* in the
                # output, so a successful run whose stdout merely mentioned
                # those words entered the fix loop.  Match the explicit
                # failure marker emitted by run_code ("Execution failed ...")
                # instead, plus a Python traceback as a fallback for raw
                # exception strings.
                if "Execution failed" in content or "Traceback" in content:
                    return "error"
                break
        return "success"

    graph_builder.add_conditional_edges(
        "runner_agent",
        should_fix_code,
        {
            "error": "fixer_agent",
            "success": "analysis_agent"
        }
    )

    graph_builder.add_edge("fixer_agent", "runner_agent")
    graph_builder.add_edge("analysis_agent", END)

    return graph_builder.compile()
|
78 |
+
|
79 |
+
def chat_interface():
    """Interactive chat interface for data analysis.

    Runs a REPL loop: each user question is appended to the shared state and
    the whole graph is re-invoked, so conversation context accumulates across
    turns.  Exits on 'quit'/'exit'/'bye'/'goodbye' or Ctrl-C.
    """
    graph = build_graph()

    # Initialize state
    state = {
        "messages": [],
        # NOTE(review): machine-specific path — should come from config/CLI.
        "dataset_path": "/Users/beyzaerdogan/Desktop/ai-analyst/cereal.csv",
        "charts": [],
        "report": "",
        "codes": []
    }

    print("🤖 AI Data Analyst Chat")
    print("=" * 50)
    print("Ask me anything about your dataset! Type 'quit' or 'exit' to end the conversation.")
    print("=" * 50)

    while True:
        try:
            # Get user input
            user_input = input("\n👤 You: ").strip()

            # Check for exit commands
            if user_input.lower() in ['quit', 'exit', 'bye', 'goodbye']:
                print("\n🤖 AI: Goodbye! Thanks for chatting with me.")
                break

            if not user_input:
                print("🤖 AI: Please enter a question or message.")
                continue

            # Add user message to state
            user_message = HumanMessage(
                content=user_input,
                additional_kwargs={"agent": "human", "node_type": "question"}
            )
            state["messages"].append(user_message)

            print("\n🤖 AI: Let me analyze that for you...")

            # Run the graph with current state
            state = graph.invoke(state)

            # Extract and display the analysis response (most recent
            # analysis_agent message of node_type 'analysis').
            analysis_response = None
            for message in reversed(state["messages"]):
                if (hasattr(message, 'additional_kwargs') and
                    message.additional_kwargs.get('agent') == 'analysis_agent' and
                    message.additional_kwargs.get('node_type') == 'analysis'):
                    analysis_response = message.content
                    break

            if analysis_response:
                print(f"\n🤖 AI: {analysis_response}")
            else:
                print("\n🤖 AI: I've processed your request. Check the output folder for results.")

            # Show any generated reports
            # NOTE(review): analysis_agent returns the report filename in
            # state["report"] and never emits a node_type == 'report'
            # message, so this loop appears to never print anything —
            # confirm and either drop it or read state["report"] instead.
            for message in state["messages"]:
                if (hasattr(message, 'additional_kwargs') and
                    message.additional_kwargs.get('agent') == 'analysis_agent' and
                    message.additional_kwargs.get('node_type') == 'report'):
                    print(f"📊 Report: {message.content}")

        except KeyboardInterrupt:
            print("\n\n🤖 AI: Goodbye! Thanks for chatting with me.")
            break
        except Exception as e:
            # Broad catch keeps the REPL alive on any per-turn failure.
            print(f"\n🤖 AI: Sorry, I encountered an error: {str(e)}")
            print("Please try again with a different question.")
|
150 |
+
|
151 |
+
def main():
    """Main function for single question mode (backward compatibility)."""
    graph = build_graph()

    # Seed the graph with a single hard-coded question.
    initial_question = HumanMessage(
        content="What are the top 5 cereals by their protein amount?",
        additional_kwargs={"agent": "human", "node_type": "question"},
    )
    state = {
        "messages": [initial_question],
        "dataset_path": "/Users/beyzaerdogan/Desktop/ai-analyst/cereal.csv",
        "charts": [],
        "report": "",
    }

    state = graph.invoke(state)
    print(state)
|
166 |
+
|
167 |
+
if __name__ == "__main__":
    # Run chat interface by default
    # NOTE(review): a second, earlier __main__ guard in this file builds an
    # unused `state` dict — only this guard has an effect.
    chat_interface()
|
src/question-answer/prompts.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
try:
|
3 |
+
from .utils import get_dataset_info
|
4 |
+
except ImportError:
|
5 |
+
from utils import get_dataset_info
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import glob
|
10 |
+
|
11 |
+
load_dotenv()
|
12 |
+
|
13 |
+
def coder_prompt(dataset_path: str, question: str) -> str:

    """
    System prompt for the data analyst.

    Args:
        dataset_path: Local path to the CSV dataset.  Inside the sandbox the
            file is uploaded to ``tmp/<basename>`` (see tools.run_code), so
            only the basename matters for the generated code.
        question: The user question the generated code must answer.

    Returns:
        The full system prompt string for the coder LLM.
    """
    dataset_info = get_dataset_info(dataset_path)
    # BUG FIX: this prompt used to hard-code 'tmp/dataset.csv', but
    # tools.run_code() uploads the dataset to tmp/<original basename>, so the
    # generated code could only find the file if it happened to be named
    # dataset.csv.  Build the real sandbox path instead.
    sandbox_path = f"tmp/{os.path.basename(dataset_path)}"
    return f"""
    You are a senior data analyst.
    You have access to a pandas dataframe `df` that will be available in the sandbox environment.

    Here is the dataset information:
    {dataset_info}

    USER QUESTION: {question}

    Write Python code to answer this specific question both visually and statistically.
    The code will be executed in a secure sandbox environment where the dataset is available as a CSV file.

    IMPORTANT GUIDELINES:
    1. First load the dataset: df = pd.read_csv('{sandbox_path}')
    2. Only use built-in Python libraries, pandas, matplotlib, and seaborn
    3. Write clear, well-commented code
    4. Handle potential errors gracefully
    5. Return meaningful results that directly answer the question
    6. ALWAYS create visualizations when they would help answer the question
    7. Use proper statistical tests and analysis to answer the question

    VISUALIZATION REQUIREMENTS:
    - ALWAYS use matplotlib.pyplot for plotting
    - ALWAYS save plots as files using plt.savefig() before plt.show()
    - Set proper figure size: plt.figure(figsize=(10, 6))
    - Add titles and labels: plt.title(), plt.xlabel(), plt.ylabel()
    - Use plt.tight_layout() for better spacing
    - Use pastel colors for the plots
    - Save each plot with a unique filename: plt.savefig('chart_1.png', dpi=300, bbox_inches='tight')
    - Call plt.show() after saving
    - Example: plt.savefig('chart_1.png', dpi=300, bbox_inches='tight'); plt.show()

    RETURN FORMAT:
    - Return ONLY the Python code without any markdown formatting
    - Do NOT include ```python or ``` markers
    - Do NOT include any explanatory text or comments
    - Start directly with import statements
    """
|
57 |
+
|
58 |
+
def fixer_prompt(code: str, error: str, question: str, dataset_path: str) -> str:
    """
    System prompt for the analysis code fixing agent.

    Args:
        code: The Python code that failed in the sandbox.
        error: The error/output text from the failed execution.
        question: The user question the code was meant to answer.
        dataset_path: Path to the CSV dataset, used to embed dataset info.

    Returns:
        The full system prompt string for the fixer LLM.
    """

    dataset_info = get_dataset_info(dataset_path)

    return f"""
    You are a senior data analyst.
    You have access to a pandas dataframe `df` that will be available in the sandbox environment.

    Here is the dataset information:
    {dataset_info}

    USER QUESTION: {question}

    Here is the code that failed:
    {code}

    Here is the error message:
    {error}

    Your task is to fix the code to resolve the error.

    VISUALIZATION REQUIREMENTS (if the code includes plots):
    - ALWAYS save plots as files using plt.savefig() before plt.show()
    - Save each plot with a unique filename: plt.savefig('chart_1.png', dpi=300, bbox_inches='tight')
    - Call plt.show() after saving
    - Example: plt.savefig('chart_1.png', dpi=300, bbox_inches='tight'); plt.show()

    RETURN FORMAT:
    - Return ONLY the fixed Python code without any markdown formatting
    - Do NOT include ```python or ``` markers
    - Do NOT include any explanatory text or comments
    - Start directly with import statements
    """
|
94 |
+
|
95 |
+
def analysis_prompt(question: str, dataset_path: str, execution_result: str, charts: list = None) -> str:
    """
    System prompt for the analysis agent.

    Args:
        question: The user question that was asked.
        dataset_path: Path to the CSV dataset, used to embed dataset info.
        execution_result: Output of the executed analysis code.
        charts: Chart file paths; when None, falls back to scanning the
            module-local ``output`` directory for image files.

    Returns:
        The full system prompt string for the analysis LLM.
    """
    dataset_info = get_dataset_info(dataset_path)

    # Use provided charts or find them in output directory
    if charts is None:
        charts = []
        # Get the absolute path to the output directory
        # NOTE(review): runner_agent saves charts to "output/" relative to
        # the CWD, while this fallback scans the directory of this module —
        # confirm the two coincide at runtime.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(current_dir, "output")

        if os.path.exists(output_dir):
            # Find all image files in the output directory
            chart_patterns = ['*.png', '*.jpg', '*.jpeg', '*.svg', '*.pdf']
            for pattern in chart_patterns:
                chart_files = glob.glob(os.path.join(output_dir, pattern))
                charts.extend(chart_files)

            # Sort charts by creation time (newest first)
            charts.sort(key=os.path.getctime, reverse=True)

            print(f"Found {len(charts)} charts: {charts}")

    # Create chart information for the prompt
    chart_info = ""
    if charts:
        chart_info = f"\n\nCHARTS GENERATED:\n"
        for i, chart in enumerate(charts, 1):
            chart_info += f"Chart {i}: {chart}\n"
        chart_info += "\nThese charts contain visual representations of the data analysis. Please refer to them when providing your analysis."
    else:
        chart_info = "\n\nNo charts were generated for this analysis."

    return f"""
    You are a senior data analyst.

    Here is the dataset information:
    {dataset_info}

    Here is the question that was asked:
    {question}

    Here is the result of the execution:
    {execution_result}
    {chart_info}

    Your task is to analyze the question, the result of the execution, and the charts to answer the question comprehensively.

    RETURN FORMAT:
    - Return ONLY the analysis in a single string
    - Do NOT include any other text or comments
    - Start directly with the analysis
    - If charts were generated, reference them in your analysis
    """
|
src/question-answer/states.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Type definitions for the LangGraph project."""
|
2 |
+
|
3 |
+
from typing import TypedDict, Annotated
|
4 |
+
from langchain_core.messages import AnyMessage
|
5 |
+
from langgraph.graph.message import add_messages
|
6 |
+
|
7 |
+
class State(TypedDict):
|
8 |
+
|
9 |
+
messages: Annotated[list[AnyMessage], add_messages]
|
10 |
+
dataset_path: str
|
11 |
+
codes: Annotated[list[str], add_messages]
|
12 |
+
charts: Annotated[list[str], add_messages]
|
13 |
+
report: str
|
src/question-answer/tools.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dotenv import load_dotenv
|
2 |
+
import base64
|
3 |
+
import os
|
4 |
+
from daytona import Daytona
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
def clean_code(code: str) -> str:
    """Clean the code - remove any file paths or non-Python content"""

    def _looks_like_python(stripped: str) -> bool:
        # First line that is plausibly Python rather than a path/preamble.
        if stripped.startswith(('import ', 'from ', '# ', '"""', "'''")):
            return True
        return bool(stripped) and '/' not in stripped and not stripped.endswith('"')

    lines = code.split('\n')
    start = -1
    for index, raw_line in enumerate(lines):
        stripped = raw_line.strip()
        # Markdown fences are skipped here and stripped wholesale below.
        if stripped.startswith('```python') or stripped.startswith('```'):
            continue
        if _looks_like_python(stripped):
            start = index
            break

    if start > 0:
        cleaned = '\n'.join(lines[start:])
        print(f"Cleaned code by removing {start} lines from the beginning")
    else:
        cleaned = code

    # Drop any leftover markdown fence markers.
    return cleaned.replace('```python', '').replace('```', '')
|
38 |
+
|
39 |
+
def print_logs(result) -> None:
    """Print the logs of the execution"""
    # Emit each stream only when present and non-empty.
    for label, attr in (("STDOUT:", 'stdout'), ("STDERR:", 'stderr')):
        stream = getattr(result, attr, None)
        if stream:
            print(label)
            print(stream)
|
47 |
+
|
48 |
+
def cleanup_sandboxes():
    """Clean up all sandboxes to free resources.

    Best-effort: all failures are logged, never raised.
    """
    try:
        daytona = Daytona()
        # Try to get existing sandbox and delete it
        try:
            existing_sandbox = daytona.get_current_sandbox()
            existing_sandbox.delete()
            print(f"Cleaned up existing sandbox: {existing_sandbox.id}")
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # SystemExit and KeyboardInterrupt.
            print("No existing sandbox to clean up")
    except Exception as e:
        print(f"Warning: Could not clean up sandboxes: {e}")
|
61 |
+
|
62 |
+
def run_code(code: str, dataset_path: str) -> dict:
|
63 |
+
"""Run code in a Daytona sandbox"""
|
64 |
+
|
65 |
+
cleaned_code = clean_code(code)
|
66 |
+
|
67 |
+
# initialize daytona client
|
68 |
+
daytona = Daytona()
|
69 |
+
|
70 |
+
# try to get existing sandbox
|
71 |
+
try:
|
72 |
+
sandbox = daytona.get_current_sandbox()
|
73 |
+
print("Using existing sandbox")
|
74 |
+
except:
|
75 |
+
try:
|
76 |
+
sandbox = daytona.create()
|
77 |
+
print("Created new sandbox")
|
78 |
+
except Exception as e:
|
79 |
+
if "CPU quota exceeded" in str(e) or "disk quota exceeded" in str(e).lower():
|
80 |
+
print("Quota exceeded, cleaning up and trying again...")
|
81 |
+
cleanup_sandboxes()
|
82 |
+
try:
|
83 |
+
sandbox = daytona.create()
|
84 |
+
print("Created new sandbox after cleanup")
|
85 |
+
except Exception as e2:
|
86 |
+
print(f"Still failed after cleanup: {e2}")
|
87 |
+
raise e2
|
88 |
+
else:
|
89 |
+
raise e
|
90 |
+
|
91 |
+
# Upload the dataset to the sandbox using file system operations
|
92 |
+
try:
|
93 |
+
# Upload the original dataset to the sandbox
|
94 |
+
sandbox_datapath = f"tmp/{os.path.basename(dataset_path)}"
|
95 |
+
sandbox.fs.upload_file(dataset_path, sandbox_datapath)
|
96 |
+
print(f"Uploaded {dataset_path} to {sandbox_datapath}")
|
97 |
+
|
98 |
+
# Replace the original file path in the code with the sandbox path
|
99 |
+
cleaned_code = cleaned_code.replace(dataset_path, sandbox_datapath)
|
100 |
+
print(f"Updated code to use sandbox path: {sandbox_datapath}")
|
101 |
+
|
102 |
+
except Exception as e:
|
103 |
+
print(f"WARNING: Could not upload {dataset_path} to sandbox: {e}")
|
104 |
+
# If the file doesn't exist locally, continue without uploading
|
105 |
+
if "does not exist" in str(e) or "not found" in str(e).lower():
|
106 |
+
print(f"File {dataset_path} does not exist locally, continuing without upload")
|
107 |
+
sandbox_datapath = dataset_path # Use original path as fallback
|
108 |
+
else:
|
109 |
+
raise e
|
110 |
+
|
111 |
+
#########################################################
|
112 |
+
################# Running the code ######################
|
113 |
+
#########################################################
|
114 |
+
|
115 |
+
try:
|
116 |
+
# Install only essential dependencies to speed up execution
|
117 |
+
print("Installing essential dependencies...")
|
118 |
+
install_deps_code = """
|
119 |
+
import subprocess
|
120 |
+
import sys
|
121 |
+
|
122 |
+
# Install only essential packages
|
123 |
+
packages = ['matplotlib', 'pandas', 'numpy']
|
124 |
+
for package in packages:
|
125 |
+
try:
|
126 |
+
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package, '--quiet'])
|
127 |
+
print(f"Installed {package}")
|
128 |
+
except Exception as e:
|
129 |
+
print(f"Failed to install {package}: {e}")
|
130 |
+
"""
|
131 |
+
|
132 |
+
try:
|
133 |
+
deps_result = sandbox.process.code_run(install_deps_code)
|
134 |
+
print("Dependencies installation completed")
|
135 |
+
except Exception as e:
|
136 |
+
print(f"Warning: Could not install dependencies: {e}")
|
137 |
+
|
138 |
+
# Run the code in the sandbox
|
139 |
+
print("-----------------------------------------------------------------------------------")
|
140 |
+
print('Running the analysis code in the sandbox....')
|
141 |
+
|
142 |
+
result = sandbox.process.code_run(cleaned_code)
|
143 |
+
print('Code execution finished!')
|
144 |
+
|
145 |
+
# Check for execution errors
|
146 |
+
if result.exit_code != 0:
|
147 |
+
print(f"EXECUTION ERROR: {result.result}")
|
148 |
+
return {
|
149 |
+
"success": False,
|
150 |
+
"execution": f"Execution failed with error: {result.result}"
|
151 |
+
}
|
152 |
+
|
153 |
+
except Exception as e:
|
154 |
+
print(f"Error running code: {e}")
|
155 |
+
return {
|
156 |
+
"success": False,
|
157 |
+
"execution": str(e)
|
158 |
+
}
|
159 |
+
|
160 |
+
print("-----------------------------------------------------------------------------------")
|
161 |
+
print("Checking for files in the sandbox...")
|
162 |
+
print("-----------------------------------------------------------------------------------")
|
163 |
+
|
164 |
+
#########################################################
|
165 |
+
############# Post-execution file checking ##############
|
166 |
+
#########################################################
|
167 |
+
|
168 |
+
# Check what files were created after code execution
|
169 |
+
try:
|
170 |
+
post_debug_code = """
|
171 |
+
import os
|
172 |
+
import glob
|
173 |
+
print("\\n=== FILES AFTER CODE EXECUTION ===")
|
174 |
+
print(f"Current working directory: {os.getcwd()}")
|
175 |
+
print("Files in current directory:")
|
176 |
+
for f in os.listdir('.'):
|
177 |
+
print(f" {f}")
|
178 |
+
"""
|
179 |
+
|
180 |
+
post_debug_result = sandbox.process.code_run(post_debug_code)
|
181 |
+
print("Post-execution debug info:", post_debug_result.result)
|
182 |
+
except Exception as e:
|
183 |
+
print(f"Could not check post-execution files: {e}")
|
184 |
+
|
185 |
+
#########################################################
|
186 |
+
############# Checking for charts in the sandbox ########
|
187 |
+
#########################################################
|
188 |
+
|
189 |
+
# Ensure output directory exists
|
190 |
+
os.makedirs("output", exist_ok=True)
|
191 |
+
|
192 |
+
# Check for plots - look for saved plot files in the sandbox
|
193 |
+
charts_count = 0
|
194 |
+
|
195 |
+
# Look for common plot file patterns in both current directory and /tmp/
|
196 |
+
search_directories = ['.', '/tmp']
|
197 |
+
plot_patterns = ['*.png', '*.jpg', '*.jpeg', '*.svg', '*.pdf']
|
198 |
+
|
199 |
+
for search_dir in search_directories:
|
200 |
+
print(f"Searching for charts in {search_dir}...")
|
201 |
+
for pattern in plot_patterns:
|
202 |
+
try:
|
203 |
+
# List files matching the pattern in the specific directory
|
204 |
+
list_cmd = f"import glob; import os; files = glob.glob('{search_dir}/{pattern}'); print('\\n'.join(files))"
|
205 |
+
plot_files_result = sandbox.process.code_run(list_cmd)
|
206 |
+
|
207 |
+
if plot_files_result.result.strip():
|
208 |
+
plot_files = plot_files_result.result.strip().split('\n')
|
209 |
+
for i, plot_file in enumerate(plot_files):
|
210 |
+
try:
|
211 |
+
# Create filename for the chart
|
212 |
+
chart_filename = f"chart_{charts_count + 1}.{plot_file.strip().split('.')[-1]}"
|
213 |
+
|
214 |
+
# Download the plot file from sandbox
|
215 |
+
sandbox.fs.download_file(plot_file.strip(), f"output/{chart_filename}")
|
216 |
+
|
217 |
+
print(f"Downloaded chart: output/{chart_filename}")
|
218 |
+
charts_count += 1
|
219 |
+
|
220 |
+
except Exception as e:
|
221 |
+
print(f"Error processing chart {plot_file}: {e}")
|
222 |
+
except Exception as e:
|
223 |
+
print(f"Error searching for {pattern} files in {search_dir}: {e}")
|
224 |
+
|
225 |
+
if charts_count == 0:
|
226 |
+
print("WARNING: No charts were downloaded.")
|
227 |
+
# Let's also check what files actually exist in the sandbox
|
228 |
+
try:
|
229 |
+
debug_cmd = """
|
230 |
+
import os
|
231 |
+
import glob
|
232 |
+
print("\\n=== DEBUGGING CHART DETECTION ===")
|
233 |
+
print("Current working directory:", os.getcwd())
|
234 |
+
print("Files in current directory:")
|
235 |
+
for f in os.listdir('.'):
|
236 |
+
print(f" {f}")
|
237 |
+
print("\\nAll image files found:")
|
238 |
+
for pattern in ['*.png', '*.jpg', '*.jpeg', '*.svg', '*.pdf']:
|
239 |
+
files = glob.glob(pattern)
|
240 |
+
if files:
|
241 |
+
print(f" {pattern}: {files}")
|
242 |
+
else:
|
243 |
+
print(f" {pattern}: No files found")
|
244 |
+
"""
|
245 |
+
debug_result = sandbox.process.code_run(debug_cmd)
|
246 |
+
print("Chart detection debug:", debug_result.result)
|
247 |
+
except Exception as e:
|
248 |
+
print(f"Could not run chart detection debug: {e}")
|
249 |
+
else:
|
250 |
+
print(f"Successfully downloaded {charts_count} charts")
|
251 |
+
|
252 |
+
# Clean up the sandbox to free disk space
|
253 |
+
try:
|
254 |
+
sandbox.delete()
|
255 |
+
print("Sandbox cleaned up to free disk space")
|
256 |
+
except Exception as e:
|
257 |
+
print(f"Warning: Could not clean up sandbox: {e}")
|
258 |
+
|
259 |
+
return {
|
260 |
+
"success": True,
|
261 |
+
"execution": result.result,
|
262 |
+
"charts": [f"output/chart_{i+1}.png" for i in range(charts_count)]
|
263 |
+
}
|
264 |
+
|
src/question-answer/utils.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from pathlib import Path
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
from datetime import datetime
|
7 |
+
import re
|
8 |
+
|
9 |
+
def format_execution_results(execution_result: str) -> str:
    """
    Format raw sandbox execution output into a concise markdown code block.

    Strips known warning blocks (FutureWarning, UserWarning, etc.) together
    with two trailing context lines each, drops blank lines before the first
    real content, removes the "Code execution result: " prefix if present,
    and truncates output longer than 800 characters for readability.

    Args:
        execution_result: Raw stdout/stderr text captured from code execution.

    Returns:
        A fenced markdown code block with the cleaned output, or a fallback
        message when the input is empty.
    """
    if not execution_result:
        return "No execution results available."

    # Markers that identify noisy warning/technical lines to suppress.
    warning_markers = (
        'futurewarning', 'userwarning', 'deprecationwarning',
        'target_code:', 'warning:', 'error:', 'passing `palette`',
    )

    lines = execution_result.split('\n')
    cleaned_lines = []
    skip_next_lines = 0  # countdown of continuation lines belonging to a warning

    for line in lines:
        # Still consuming the tail of a warning block.
        if skip_next_lines > 0:
            skip_next_lines -= 1
            continue

        lowered = line.lower()
        if any(marker in lowered for marker in warning_markers):
            # Skip the warning line and the next two lines of its context.
            skip_next_lines = 2
            continue

        # Drop blank lines that appear before any real content.
        if not cleaned_lines and not line.strip():
            continue

        # Skip bare plotting-call fragments that leak into the output.
        # (Original listed 'sns.barplot(' twice; one entry is sufficient.)
        if line.strip() == 'sns.barplot(':
            continue

        cleaned_lines.append(line)

    result = '\n'.join(cleaned_lines).strip()

    # Strip the runner's prefix; use the prefix length, not a magic number.
    prefix = "Code execution result: "
    if result.startswith(prefix):
        result = result[len(prefix):]

    # Keep the report readable by capping very long output.
    if len(result) > 800:
        result = result[:800] + "\n... (truncated for readability)"

    return f"```\n{result}\n```"
|
58 |
+
|
59 |
+
def get_dataset_info(dataset_path: str) -> str:
    """
    Build a JSON summary profiling a tabular dataset on disk.

    Supports CSV, Excel (.xlsx/.xls), JSON, and Parquet files. The summary
    covers shape, columns and dtypes, missing-value counts/percentages,
    duplicate rows, memory usage, per-type column lists, numeric statistics,
    a sample of categorical values, and a list of potential quality issues.

    Args:
        dataset_path: Path to the dataset file.

    Returns:
        A JSON string with the dataset profile, or an "Error: ..." message
        when the file is missing, unsupported, or unreadable.
    """
    try:
        file_path = Path(dataset_path)
        if not file_path.exists():
            return f"Error: Dataset file '{dataset_path}' not found."

        # Pick the pandas reader that matches the file extension.
        suffix = file_path.suffix.lower()
        if suffix == '.csv':
            df = pd.read_csv(dataset_path)
        elif suffix in ['.xlsx', '.xls']:
            df = pd.read_excel(dataset_path)
        elif suffix == '.json':
            df = pd.read_json(dataset_path)
        elif suffix == '.parquet':
            df = pd.read_parquet(dataset_path)
        else:
            return f"Error: Unsupported file format '{file_path.suffix}'"

        # Core profile of the dataset.
        info = {
            'file_name': file_path.name,
            'file_size': f"{file_path.stat().st_size / (1024*1024):.2f} MB",
            'shape': f"{df.shape[0]} rows x {df.shape[1]} columns",
            'columns': list(df.columns),
            'data_types': df.dtypes.to_dict(),
            'missing_values': df.isnull().sum().to_dict(),
            'missing_percentage': (df.isnull().sum() / len(df) * 100).round(2).to_dict(),
            'duplicate_rows': df.duplicated().sum(),
            'memory_usage': f"{df.memory_usage(deep=True).sum() / (1024*1024):.2f} MB",
            'numeric_columns': list(df.select_dtypes(include=[np.number]).columns),
            'categorical_columns': list(df.select_dtypes(include=['object', 'category']).columns),
            'datetime_columns': list(df.select_dtypes(include=['datetime64']).columns)
        }

        # Statistical summary for numeric columns.
        if info['numeric_columns']:
            info['numeric_statistics'] = df[info['numeric_columns']].describe().to_dict()

        # Unique-value samples for at most the first 5 categorical columns.
        categorical_info = {}
        for col in info['categorical_columns'][:5]:
            categorical_info[col] = {
                'unique_values': df[col].nunique(),
                'sample_values': list(df[col].dropna().unique()[:10])  # first 10 unique values
            }
        info['categorical_info'] = categorical_info

        info['potential_issues'] = _detect_quality_issues(df, info)

        # default=str handles numpy scalars and dtypes json cannot encode.
        return json.dumps(info, indent=2, default=str)

    except Exception as e:
        return f"Error analyzing dataset: {str(e)}"


def _detect_quality_issues(df: pd.DataFrame, info: dict) -> list:
    """Return human-readable descriptions of likely data-quality problems."""
    issues = []

    # Columns missing more than half their values.
    high_missing = [col for col, pct in info['missing_percentage'].items() if pct > 50]
    if high_missing:
        issues.append(f"High missing values (>50%): {high_missing}")

    # IQR-based outlier detection for each numeric column.
    for col in info['numeric_columns']:
        q1 = df[col].quantile(0.25)
        q3 = df[col].quantile(0.75)
        iqr = q3 - q1
        outliers = df[(df[col] < (q1 - 1.5 * iqr)) | (df[col] > (q3 + 1.5 * iqr))][col].count()
        if outliers > 0:
            issues.append(f"Potential outliers in '{col}': {outliers} values")

    # Object columns whose sampled values mix numeric strings with text.
    for col in info['categorical_columns']:
        if df[col].dtype == 'object':
            sample_values = df[col].dropna().astype(str).head(100)
            numeric_count = sum(1 for val in sample_values if val.replace('.', '').replace('-', '').isdigit())
            if 0 < numeric_count < len(sample_values):
                issues.append(f"Mixed data types in '{col}': contains both numeric and text values")

    return issues
|
144 |
+
|
145 |
+
def create_markdown_report(question: str, analysis: str, charts: list, execution_result: str) -> str:
    """
    Assemble a simple markdown report from the analysis text and the
    formatted execution output.

    Note: `question` and `charts` are accepted for interface compatibility
    but are not currently included in the report body.
    """
    sections = [
        "## Analysis",
        analysis,
        "",
        "## Key Findings",
        format_execution_results(execution_result),
        "",
    ]
    return "\n".join(sections)
|
157 |
+
|
158 |
+
def save_markdown_report(report_content: str) -> str:
    """
    Write the markdown report into the output/ directory and return its path.

    The filename embeds a timestamp so successive reports do not overwrite
    each other.
    """
    # Make sure the destination directory exists.
    out_dir = Path("output")
    out_dir.mkdir(exist_ok=True)

    # Timestamped filename keeps each report unique.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    destination = out_dir / f"analysis_report_{stamp}.md"

    destination.write_text(report_content, encoding="utf-8")

    print(f"Report saved: {destination}")
    return str(destination)
|
177 |
+
|