beyzacodeway committed
Commit dbd3785 · verified · 1 Parent(s): 2f255d3

Upload 6 files
src/question-answer/agents.py ADDED

import os

from prompts import coder_prompt, fixer_prompt, analysis_prompt
from states import State
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, AIMessage
from tools import run_code
from utils import create_markdown_report, save_markdown_report


def coder_agent(state: State) -> State:
    """
    Creates the analysis code in Python to be executed in a sandbox.
    """
    print("-----------------------------------------------------------------------------------")
    print("Creating the analysis code...")
    print("-----------------------------------------------------------------------------------")

    # Create the LLM
    llm = ChatOpenAI(model="gpt-4.1", temperature=0)

    # Get the dataset path and the message history
    dataset_path = state['dataset_path']
    messages = state['messages']

    # Extract the most recent user question
    question = None
    for message in reversed(messages):
        if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'human':
            question = message.content
            break

    if not question:
        question = "Analyze the dataset"

    # Create the system prompt with conversation context
    system_prompt = coder_prompt(dataset_path, question)

    # Build conversation context for the LLM, starting with the system prompt
    conversation_messages = [SystemMessage(content=system_prompt, additional_kwargs={"agent": "system", "node_type": "generation"})]

    # Add recent conversation history (last 10 messages to keep context manageable)
    for msg in messages[-10:]:
        if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('agent') in ['human', 'analysis_agent']:
            conversation_messages.append(msg)

    # Invoke the LLM with conversation context
    code = llm.invoke(conversation_messages).content
    messages.append(AIMessage(content=code, additional_kwargs={"agent": "coder_agent", "node_type": "generation"}))

    return {
        "messages": messages,
        "dataset_path": dataset_path
    }


def runner_agent(state: State) -> State:
    """
    Runs the analysis code in a sandbox.
    """
    print("-----------------------------------------------------------------------------------")
    print("Running the analysis code...")
    print("-----------------------------------------------------------------------------------")

    code = state['messages'][-1].content
    dataset_path = state['dataset_path']

    # Save the code to a file in the output folder
    os.makedirs("output", exist_ok=True)
    with open("output/analysis_code.py", "w") as f:
        f.write(code)

    result = run_code(code, dataset_path)

    # Get the messages and add the result
    messages = state['messages']
    messages.append(AIMessage(content=f"Code execution result: {result['execution']}",
                              additional_kwargs={"agent": "runner_agent", "node_type": "generation"}))

    # Track generated charts as message objects
    charts = result.get('charts', [])
    chart_messages = [AIMessage(content=chart, additional_kwargs={"agent": "runner_agent", "node_type": "chart"}) for chart in charts]

    return {
        "messages": messages,
        "codes": state.get('codes', []) + [code],
        "charts": chart_messages
    }


def fixer_agent(state: State) -> State:
    """
    Fixes the analysis code.
    """
    print("-----------------------------------------------------------------------------------")
    print("Fixing the analysis code...")
    print("-----------------------------------------------------------------------------------")

    # Extract the last human message (question)
    question = None
    for message in reversed(state['messages']):
        if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'human':
            question = message.content
            break

    # Extract the last coder_agent message (code)
    code = None
    for message in reversed(state['messages']):
        if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'coder_agent':
            code = message.content
            break

    # Extract the last runner_agent message (error)
    error = None
    for message in reversed(state['messages']):
        if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'runner_agent':
            error = message.content
            break

    # Get the dataset path
    dataset_path = state['dataset_path']

    # Create the system prompt and record it in the history
    system_prompt = fixer_prompt(code, error, question, dataset_path)
    system_message = SystemMessage(content=system_prompt, additional_kwargs={"agent": "system", "node_type": "fixing"})

    messages = state['messages']
    messages.append(system_message)

    # Create the LLM and invoke it to fix the code
    llm = ChatOpenAI(model="gpt-4.1", temperature=0)
    fixed_code = llm.invoke([system_message]).content
    messages.append(AIMessage(content=fixed_code, additional_kwargs={"agent": "fixer_agent", "node_type": "fixing"}))

    return {
        "messages": messages,
        "codes": state.get('codes', []) + [fixed_code]
    }


def analysis_agent(state: State) -> State:
    """
    Analyzes the question, the execution result, and the charts to answer the question.
    """
    print("-----------------------------------------------------------------------------------")
    print("Analyzing the question, the execution result, and the charts to answer the question...")
    print("-----------------------------------------------------------------------------------")

    # Get the messages
    messages = state['messages']

    # Extract the most recent user question
    question = None
    for message in reversed(messages):
        if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'human':
            question = message.content
            break

    # Get the dataset path
    dataset_path = state['dataset_path']

    # Get the execution result
    execution_result = None
    for message in reversed(messages):
        if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'runner_agent':
            execution_result = message.content
            break

    # Get the charts from state and ensure they are strings
    charts = state.get('charts', [])

    # Convert any message objects to strings and filter out duplicates
    chart_paths = []
    seen_charts = set()
    for chart in charts:
        chart_path = chart.content if hasattr(chart, 'content') else str(chart)
        # Only add unique chart paths
        if chart_path not in seen_charts:
            chart_paths.append(chart_path)
            seen_charts.add(chart_path)

    # Create the system prompt
    system_prompt = analysis_prompt(question, dataset_path, execution_result, chart_paths)

    # Build conversation context for the LLM, starting with the system prompt
    conversation_messages = [SystemMessage(content=system_prompt, additional_kwargs={"agent": "system", "node_type": "analysis"})]

    # Add recent conversation history (last 10 messages to keep context manageable)
    for msg in messages[-10:]:
        if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('agent') in ['human', 'analysis_agent']:
            conversation_messages.append(msg)

    # Create the LLM and invoke it to produce the answer
    llm = ChatOpenAI(model="gpt-4.1", temperature=0)
    analysis = llm.invoke(conversation_messages).content
    messages.append(AIMessage(content=analysis, additional_kwargs={"agent": "analysis_agent", "node_type": "analysis"}))

    # Create a markdown report and save it to a file
    report_content = create_markdown_report(question, analysis, chart_paths, execution_result)
    report_filename = save_markdown_report(report_content)

    # The report filename is returned in the state, so it is not added to the messages
    return {
        "messages": messages,
        "analysis": analysis,
        "report": report_filename
    }
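
All four agents rely on the same convention: every message carries an `agent` tag in `additional_kwargs`, and each node scans the history in reverse for the newest message with the tag it needs. A minimal sketch of that lookup (the helper `last_message_by_agent` is illustrative, not part of the module):

```python
from langchain_core.messages import HumanMessage, AIMessage

def last_message_by_agent(messages, agent: str):
    """Return the content of the most recent message tagged with `agent`, or None."""
    for message in reversed(messages):
        if message.additional_kwargs.get('agent') == agent:
            return message.content
    return None

history = [
    HumanMessage(content="Which cereal has the most fiber?",
                 additional_kwargs={"agent": "human", "node_type": "question"}),
    AIMessage(content="import pandas as pd\n...",
              additional_kwargs={"agent": "coder_agent", "node_type": "generation"}),
]

print(last_message_by_agent(history, "human"))        # the question
print(last_message_by_agent(history, "coder_agent"))  # the generated code
```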
src/question-answer/graph.py ADDED

from __future__ import annotations

from typing import TypedDict

from langgraph.graph import START, END, StateGraph
from langchain_core.messages import HumanMessage

from states import State
from agents import coder_agent, runner_agent, fixer_agent, analysis_agent


class Context(TypedDict):
    """Context parameters for the agent.

    Set these when creating assistants OR when invoking the graph.
    See: https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/
    """

    my_configurable_param: str


def build_graph():
    graph_builder = StateGraph(State)
    graph_builder.add_node("coder_agent", coder_agent)
    graph_builder.add_node("runner_agent", runner_agent)
    graph_builder.add_node("fixer_agent", fixer_agent)
    graph_builder.add_node("analysis_agent", analysis_agent)

    graph_builder.add_edge(START, "coder_agent")
    graph_builder.add_edge("coder_agent", "runner_agent")

    def should_fix_code(state: State) -> str:
        """Determine whether the code needs fixing based on the execution result."""
        messages = state.get("messages", [])
        if not messages:
            return "success"

        # Count fix attempts to prevent infinite loops
        fix_attempts = 0
        for message in messages:
            if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'fixer_agent':
                fix_attempts += 1

        # Limit to 3 fix attempts to prevent quota-exceeded errors
        if fix_attempts >= 3:
            return "success"

        # Get the last runner_agent message to check for errors
        for message in reversed(messages):
            if hasattr(message, 'additional_kwargs') and message.additional_kwargs.get('agent') == 'runner_agent':
                content = message.content
                # Check if the execution failed
                if "Execution failed" in content or "error" in content.lower() or "failed" in content.lower():
                    return "error"
                break
        return "success"

    graph_builder.add_conditional_edges(
        "runner_agent",
        should_fix_code,
        {
            "error": "fixer_agent",
            "success": "analysis_agent"
        }
    )

    graph_builder.add_edge("fixer_agent", "runner_agent")
    graph_builder.add_edge("analysis_agent", END)

    return graph_builder.compile()


def chat_interface():
    """Interactive chat interface for data analysis."""
    graph = build_graph()

    # Initialize state
    state = {
        "messages": [],
        "dataset_path": "/Users/beyzaerdogan/Desktop/ai-analyst/cereal.csv",
        "charts": [],
        "report": "",
        "codes": []
    }

    print("🤖 AI Data Analyst Chat")
    print("=" * 50)
    print("Ask me anything about your dataset! Type 'quit' or 'exit' to end the conversation.")
    print("=" * 50)

    while True:
        try:
            # Get user input
            user_input = input("\n👤 You: ").strip()

            # Check for exit commands
            if user_input.lower() in ['quit', 'exit', 'bye', 'goodbye']:
                print("\n🤖 AI: Goodbye! Thanks for chatting with me.")
                break

            if not user_input:
                print("🤖 AI: Please enter a question or message.")
                continue

            # Add the user message to the state
            user_message = HumanMessage(
                content=user_input,
                additional_kwargs={"agent": "human", "node_type": "question"}
            )
            state["messages"].append(user_message)

            print("\n🤖 AI: Let me analyze that for you...")

            # Run the graph with the current state
            state = graph.invoke(state)

            # Extract and display the analysis response
            analysis_response = None
            for message in reversed(state["messages"]):
                if (hasattr(message, 'additional_kwargs') and
                        message.additional_kwargs.get('agent') == 'analysis_agent' and
                        message.additional_kwargs.get('node_type') == 'analysis'):
                    analysis_response = message.content
                    break

            if analysis_response:
                print(f"\n🤖 AI: {analysis_response}")
            else:
                print("\n🤖 AI: I've processed your request. Check the output folder for results.")

            # Show the generated report path, if any
            if state.get("report"):
                print(f"📊 Report: {state['report']}")

        except KeyboardInterrupt:
            print("\n\n🤖 AI: Goodbye! Thanks for chatting with me.")
            break
        except Exception as e:
            print(f"\n🤖 AI: Sorry, I encountered an error: {str(e)}")
            print("Please try again with a different question.")


def main():
    """Main function for single-question mode (backward compatibility)."""
    graph = build_graph()
    state = {
        "messages": [
            HumanMessage(content="What are the top 5 cereals by their protein amount?",
                         additional_kwargs={"agent": "human", "node_type": "question"})
        ],
        "dataset_path": "/Users/beyzaerdogan/Desktop/ai-analyst/cereal.csv",
        "charts": [],
        "report": ""
    }

    state = graph.invoke(state)
    print(state)


if __name__ == "__main__":
    # Run the chat interface by default
    chat_interface()
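
The conditional edge is the only branching in the graph, so it is worth sanity-checking in isolation. A test sketch, with `route` as a condensed standalone stand-in for the nested `should_fix_code` (not importable from the module as written):

```python
from langchain_core.messages import AIMessage

def route(messages) -> str:
    """Condensed copy of the should_fix_code routing rule, for testing."""
    for message in reversed(messages):
        if message.additional_kwargs.get('agent') == 'runner_agent':
            content = message.content.lower()
            if "execution failed" in content or "error" in content or "failed" in content:
                return "error"
            break
    return "success"

ok = [AIMessage(content="Code execution result: top 5 rows printed",
                additional_kwargs={"agent": "runner_agent"})]
bad = [AIMessage(content="Code execution result: Execution failed with error: NameError",
                 additional_kwargs={"agent": "runner_agent"})]

assert route(ok) == "success"
assert route(bad) == "error"
```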
src/question-answer/prompts.py ADDED

import glob
import os

from dotenv import load_dotenv

try:
    from .utils import get_dataset_info
except ImportError:
    from utils import get_dataset_info

load_dotenv()


def coder_prompt(dataset_path: str, question: str) -> str:
    """
    System prompt for the data analyst.
    """
    dataset_info = get_dataset_info(dataset_path)
    return f"""
    You are a senior data analyst.
    You have access to a pandas dataframe `df` that will be available in the sandbox environment.

    Here is the dataset information:
    {dataset_info}

    USER QUESTION: {question}

    Write Python code to answer this specific question both visually and statistically.
    The code will be executed in a secure sandbox environment where the dataset is available as a CSV file.

    IMPORTANT GUIDELINES:
    1. First load the dataset: df = pd.read_csv('tmp/dataset.csv')
    2. Only use built-in Python libraries, pandas, matplotlib, and seaborn
    3. Write clear, well-commented code
    4. Handle potential errors gracefully
    5. Return meaningful results that directly answer the question
    6. ALWAYS create visualizations when they would help answer the question
    7. Use proper statistical tests and analysis to answer the question

    VISUALIZATION REQUIREMENTS:
    - ALWAYS use matplotlib.pyplot for plotting
    - ALWAYS save plots as files using plt.savefig() before plt.show()
    - Set a proper figure size: plt.figure(figsize=(10, 6))
    - Add titles and labels: plt.title(), plt.xlabel(), plt.ylabel()
    - Use plt.tight_layout() for better spacing
    - Use pastel colors for the plots
    - Save each plot with a unique filename: plt.savefig('chart_1.png', dpi=300, bbox_inches='tight')
    - Call plt.show() after saving
    - Example: plt.savefig('chart_1.png', dpi=300, bbox_inches='tight'); plt.show()

    RETURN FORMAT:
    - Return ONLY the Python code without any markdown formatting
    - Do NOT include ```python or ``` markers
    - Do NOT include any explanatory text outside the code
    - Start directly with import statements
    """


def fixer_prompt(code: str, error: str, question: str, dataset_path: str) -> str:
    """
    System prompt for the analysis code fixing agent.
    """
    dataset_info = get_dataset_info(dataset_path)

    return f"""
    You are a senior data analyst.
    You have access to a pandas dataframe `df` that will be available in the sandbox environment.

    Here is the dataset information:
    {dataset_info}

    USER QUESTION: {question}

    Here is the code that failed:
    {code}

    Here is the error message:
    {error}

    Your task is to fix the code to resolve the error.

    VISUALIZATION REQUIREMENTS (if the code includes plots):
    - ALWAYS save plots as files using plt.savefig() before plt.show()
    - Save each plot with a unique filename: plt.savefig('chart_1.png', dpi=300, bbox_inches='tight')
    - Call plt.show() after saving
    - Example: plt.savefig('chart_1.png', dpi=300, bbox_inches='tight'); plt.show()

    RETURN FORMAT:
    - Return ONLY the fixed Python code without any markdown formatting
    - Do NOT include ```python or ``` markers
    - Do NOT include any explanatory text outside the code
    - Start directly with import statements
    """


def analysis_prompt(question: str, dataset_path: str, execution_result: str, charts: list = None) -> str:
    """
    System prompt for the analysis agent.
    """
    dataset_info = get_dataset_info(dataset_path)

    # Use the provided charts, or find them in the output directory
    if charts is None:
        charts = []
        # Get the absolute path to the output directory
        current_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(current_dir, "output")

        if os.path.exists(output_dir):
            # Find all image files in the output directory
            chart_patterns = ['*.png', '*.jpg', '*.jpeg', '*.svg', '*.pdf']
            for pattern in chart_patterns:
                charts.extend(glob.glob(os.path.join(output_dir, pattern)))

            # Sort charts by creation time (newest first)
            charts.sort(key=os.path.getctime, reverse=True)

    print(f"Found {len(charts)} charts: {charts}")

    # Create chart information for the prompt
    if charts:
        chart_info = "\n\nCHARTS GENERATED:\n"
        for i, chart in enumerate(charts, 1):
            chart_info += f"Chart {i}: {chart}\n"
        chart_info += "\nThese charts contain visual representations of the data analysis. Please refer to them when providing your analysis."
    else:
        chart_info = "\n\nNo charts were generated for this analysis."

    return f"""
    You are a senior data analyst.

    Here is the dataset information:
    {dataset_info}

    Here is the question that was asked:
    {question}

    Here is the result of the execution:
    {execution_result}
    {chart_info}

    Your task is to analyze the question, the result of the execution, and the charts to answer the question comprehensively.

    RETURN FORMAT:
    - Return ONLY the analysis in a single string
    - Do NOT include any other text or comments
    - Start directly with the analysis
    - If charts were generated, reference them in your analysis
    """
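
Since all three prompts embed the `get_dataset_info` profile, the quickest way to review what the model actually sees is to render one against a throwaway file. A sketch (the file `tiny.csv` and its columns are made up for illustration):

```python
import pandas as pd

from prompts import coder_prompt

# Write a throwaway dataset so get_dataset_info has something to profile
pd.DataFrame({"name": ["A", "B"], "protein": [4, 2]}).to_csv("tiny.csv", index=False)

# Print the fully rendered system prompt, dataset profile included
print(coder_prompt("tiny.csv", "Which row has the most protein?"))
```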
src/question-answer/states.py ADDED

"""Type definitions for the LangGraph project."""

from typing import Annotated, TypedDict

from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages


class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    dataset_path: str
    # Plain lists without a reducer: the agents build and return the full
    # updated list themselves, and the add_messages reducer would coerce
    # these entries into chat messages.
    codes: list[str]
    charts: list[AnyMessage]
    report: str
    analysis: str
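
Only `messages` keeps the `add_messages` reducer: it merges updates by message ID instead of blindly concatenating, which is why the agents can return the full (appended) history without duplicating it. A minimal illustration of that merge behavior:

```python
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.message import add_messages

existing = [HumanMessage(content="hi", id="1")]
update = [
    HumanMessage(content="hi", id="1"),    # same ID: replaces, not duplicates
    AIMessage(content="hello", id="2"),    # new ID: appended
]

merged = add_messages(existing, update)
print([m.content for m in merged])  # ['hi', 'hello']
```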
src/question-answer/tools.py ADDED

import os

from dotenv import load_dotenv
from daytona import Daytona

load_dotenv()


def clean_code(code: str) -> str:
    """Clean the code: remove any file paths or non-Python content."""
    lines = code.split('\n')
    python_start = -1

    # Look for Python code starting patterns
    for i, line in enumerate(lines):
        stripped = line.strip()
        # Skip markdown code blocks
        if stripped.startswith('```python') or stripped.startswith('```'):
            continue
        # Look for actual Python code
        if (stripped.startswith('import ') or
                stripped.startswith('from ') or
                stripped.startswith('# ') or
                stripped.startswith('"""') or
                stripped.startswith("'''") or
                (stripped and '/' not in stripped and not stripped.endswith('"'))):
            python_start = i
            break

    if python_start > 0:
        cleaned_code = '\n'.join(lines[python_start:])
        print(f"Cleaned code by removing {python_start} lines from the beginning")
    else:
        cleaned_code = code

    # Remove any remaining markdown markers
    cleaned_code = cleaned_code.replace('```python', '').replace('```', '')

    return cleaned_code


def print_logs(result) -> None:
    """Print the logs of the execution."""
    if hasattr(result, 'stdout') and result.stdout:
        print("STDOUT:")
        print(result.stdout)
    if hasattr(result, 'stderr') and result.stderr:
        print("STDERR:")
        print(result.stderr)


def cleanup_sandboxes():
    """Clean up all sandboxes to free resources."""
    try:
        daytona = Daytona()
        # Try to get the existing sandbox and delete it
        try:
            existing_sandbox = daytona.get_current_sandbox()
            existing_sandbox.delete()
            print(f"Cleaned up existing sandbox: {existing_sandbox.id}")
        except Exception:
            print("No existing sandbox to clean up")
    except Exception as e:
        print(f"Warning: Could not clean up sandboxes: {e}")


def run_code(code: str, dataset_path: str) -> dict:
    """Run code in a Daytona sandbox."""

    cleaned_code = clean_code(code)

    # Initialize the Daytona client
    daytona = Daytona()

    # Try to reuse an existing sandbox before creating a new one
    try:
        sandbox = daytona.get_current_sandbox()
        print("Using existing sandbox")
    except Exception:
        try:
            sandbox = daytona.create()
            print("Created new sandbox")
        except Exception as e:
            if "CPU quota exceeded" in str(e) or "disk quota exceeded" in str(e).lower():
                print("Quota exceeded, cleaning up and trying again...")
                cleanup_sandboxes()
                try:
                    sandbox = daytona.create()
                    print("Created new sandbox after cleanup")
                except Exception as e2:
                    print(f"Still failed after cleanup: {e2}")
                    raise
            else:
                raise

    # Upload the dataset to the sandbox using file system operations
    try:
        # Upload to the fixed path the coder prompt tells the model to read
        sandbox_datapath = "tmp/dataset.csv"
        sandbox.fs.upload_file(dataset_path, sandbox_datapath)
        print(f"Uploaded {dataset_path} to {sandbox_datapath}")

        # Replace the original file path in the code with the sandbox path
        cleaned_code = cleaned_code.replace(dataset_path, sandbox_datapath)
        print(f"Updated code to use sandbox path: {sandbox_datapath}")

    except Exception as e:
        print(f"WARNING: Could not upload {dataset_path} to sandbox: {e}")
        # If the file doesn't exist locally, continue without uploading
        if "does not exist" in str(e) or "not found" in str(e).lower():
            print(f"File {dataset_path} does not exist locally, continuing without upload")
            sandbox_datapath = dataset_path  # Use the original path as a fallback
        else:
            raise

    #########################################################
    ################# Running the code ######################
    #########################################################

    try:
        # Install only essential dependencies to speed up execution
        print("Installing essential dependencies...")
        install_deps_code = """
import subprocess
import sys

# Install only essential packages
packages = ['matplotlib', 'pandas', 'numpy']
for package in packages:
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package, '--quiet'])
        print(f"Installed {package}")
    except Exception as e:
        print(f"Failed to install {package}: {e}")
"""

        try:
            sandbox.process.code_run(install_deps_code)
            print("Dependencies installation completed")
        except Exception as e:
            print(f"Warning: Could not install dependencies: {e}")

        # Run the code in the sandbox
        print("-----------------------------------------------------------------------------------")
        print('Running the analysis code in the sandbox....')

        result = sandbox.process.code_run(cleaned_code)
        print('Code execution finished!')

        # Check for execution errors
        if result.exit_code != 0:
            print(f"EXECUTION ERROR: {result.result}")
            return {
                "success": False,
                "execution": f"Execution failed with error: {result.result}"
            }

    except Exception as e:
        print(f"Error running code: {e}")
        return {
            "success": False,
            "execution": str(e)
        }

    print("-----------------------------------------------------------------------------------")
    print("Checking for files in the sandbox...")
    print("-----------------------------------------------------------------------------------")

    #########################################################
    ############# Post-execution file checking ##############
    #########################################################

    # Check what files were created after code execution
    try:
        post_debug_code = """
import os
print("\\n=== FILES AFTER CODE EXECUTION ===")
print(f"Current working directory: {os.getcwd()}")
print("Files in current directory:")
for f in os.listdir('.'):
    print(f"  {f}")
"""
        post_debug_result = sandbox.process.code_run(post_debug_code)
        print("Post-execution debug info:", post_debug_result.result)
    except Exception as e:
        print(f"Could not check post-execution files: {e}")

    #########################################################
    ############# Checking for charts in the sandbox ########
    #########################################################

    # Ensure the output directory exists
    os.makedirs("output", exist_ok=True)

    # Check for plots: look for saved plot files in the sandbox
    charts_count = 0
    chart_paths = []

    # Look for common plot file patterns in both the working directory and /tmp
    search_directories = ['.', '/tmp']
    plot_patterns = ['*.png', '*.jpg', '*.jpeg', '*.svg', '*.pdf']

    for search_dir in search_directories:
        print(f"Searching for charts in {search_dir}...")
        for pattern in plot_patterns:
            try:
                # List files matching the pattern in the specific directory
                list_cmd = f"import glob; files = glob.glob('{search_dir}/{pattern}'); print('\\n'.join(files))"
                plot_files_result = sandbox.process.code_run(list_cmd)

                if plot_files_result.result.strip():
                    plot_files = plot_files_result.result.strip().split('\n')
                    for plot_file in plot_files:
                        try:
                            # Name the chart locally, keeping the original extension
                            extension = plot_file.strip().split('.')[-1]
                            chart_filename = f"chart_{charts_count + 1}.{extension}"

                            # Download the plot file from the sandbox
                            sandbox.fs.download_file(plot_file.strip(), f"output/{chart_filename}")

                            print(f"Downloaded chart: output/{chart_filename}")
                            chart_paths.append(f"output/{chart_filename}")
                            charts_count += 1

                        except Exception as e:
                            print(f"Error processing chart {plot_file}: {e}")
            except Exception as e:
                print(f"Error searching for {pattern} files in {search_dir}: {e}")

    if charts_count == 0:
        print("WARNING: No charts were downloaded.")
        # Check what files actually exist in the sandbox
        try:
            debug_cmd = """
import os
import glob
print("\\n=== DEBUGGING CHART DETECTION ===")
print("Current working directory:", os.getcwd())
print("Files in current directory:")
for f in os.listdir('.'):
    print(f"  {f}")
print("\\nAll image files found:")
for pattern in ['*.png', '*.jpg', '*.jpeg', '*.svg', '*.pdf']:
    files = glob.glob(pattern)
    if files:
        print(f"  {pattern}: {files}")
    else:
        print(f"  {pattern}: No files found")
"""
            debug_result = sandbox.process.code_run(debug_cmd)
            print("Chart detection debug:", debug_result.result)
        except Exception as e:
            print(f"Could not run chart detection debug: {e}")
    else:
        print(f"Successfully downloaded {charts_count} charts")

    # Clean up the sandbox to free disk space
    try:
        sandbox.delete()
        print("Sandbox cleaned up to free disk space")
    except Exception as e:
        print(f"Warning: Could not clean up sandbox: {e}")

    return {
        "success": True,
        "execution": result.result,
        "charts": chart_paths
    }
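
`clean_code` is the easiest piece to verify in isolation: given an LLM reply that ignores the no-markdown instruction, it should come back as bare Python. A quick check:

````python
from tools import clean_code

llm_reply = """```python
import pandas as pd
df = pd.read_csv('tmp/dataset.csv')
print(df.head())
```"""

print(clean_code(llm_reply))
# -> the three Python lines above, with both fence markers stripped
````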
src/question-answer/utils.py ADDED

import json
import os
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd


def format_execution_results(execution_result: str) -> str:
    """
    Format execution results to be more appealing and concise.
    """
    if not execution_result:
        return "No execution results available."

    # Remove technical warnings and error messages
    lines = execution_result.split('\n')
    cleaned_lines = []
    skip_next_lines = 0

    for line in lines:
        # Skip lines if we're inside a warning block
        if skip_next_lines > 0:
            skip_next_lines -= 1
            continue

        # Skip warning lines and their content
        if any(warning in line.lower() for warning in [
            'futurewarning', 'userwarning', 'deprecationwarning',
            'target_code:', 'warning:', 'error:', 'passing `palette`'
        ]):
            # Skip the warning line and the next few lines that belong to it
            skip_next_lines = 2
            continue

        # Skip empty lines at the beginning
        if not cleaned_lines and not line.strip():
            continue

        # Skip lines that are just dangling function calls
        if line.strip() == 'sns.barplot(':
            continue

        cleaned_lines.append(line)

    # Join and clean up
    result = '\n'.join(cleaned_lines).strip()

    # Remove the "Code execution result:" prefix if present
    prefix = "Code execution result: "
    if result.startswith(prefix):
        result = result[len(prefix):]

    # If the result is too long, truncate it
    if len(result) > 800:
        result = result[:800] + "\n... (truncated for readability)"

    return f"```\n{result}\n```"


def get_dataset_info(dataset_path: str) -> str:
    """
    Get comprehensive information about the dataset.
    """
    try:
        # Load the dataset
        file_path = Path(dataset_path)
        if not file_path.exists():
            return f"Error: Dataset file '{dataset_path}' not found."

        # Read the dataset based on the file extension
        if file_path.suffix.lower() == '.csv':
            df = pd.read_csv(dataset_path)
        elif file_path.suffix.lower() in ['.xlsx', '.xls']:
            df = pd.read_excel(dataset_path)
        elif file_path.suffix.lower() == '.json':
            df = pd.read_json(dataset_path)
        elif file_path.suffix.lower() == '.parquet':
            df = pd.read_parquet(dataset_path)
        else:
            return f"Error: Unsupported file format '{file_path.suffix}'"

        # Gather comprehensive dataset information
        info = {
            'file_name': file_path.name,
            'file_size': f"{file_path.stat().st_size / (1024*1024):.2f} MB",
            'shape': f"{df.shape[0]} rows x {df.shape[1]} columns",
            'columns': list(df.columns),
            'data_types': df.dtypes.to_dict(),
            'missing_values': df.isnull().sum().to_dict(),
            'missing_percentage': (df.isnull().sum() / len(df) * 100).round(2).to_dict(),
            'duplicate_rows': df.duplicated().sum(),
            'memory_usage': f"{df.memory_usage(deep=True).sum() / (1024*1024):.2f} MB",
            'numeric_columns': list(df.select_dtypes(include=[np.number]).columns),
            'categorical_columns': list(df.select_dtypes(include=['object', 'category']).columns),
            'datetime_columns': list(df.select_dtypes(include=['datetime64']).columns)
        }

        # Add a statistical summary for numeric columns
        if info['numeric_columns']:
            info['numeric_statistics'] = df[info['numeric_columns']].describe().to_dict()

        # Add unique value counts for categorical columns (limited to the first 5)
        categorical_info = {}
        for col in info['categorical_columns'][:5]:
            categorical_info[col] = {
                'unique_values': df[col].nunique(),
                'sample_values': list(df[col].dropna().unique()[:10])  # First 10 unique values
            }
        info['categorical_info'] = categorical_info

        # Identify potential data quality issues
        issues = []

        # Check for columns with high missing values
        high_missing = [col for col, pct in info['missing_percentage'].items() if pct > 50]
        if high_missing:
            issues.append(f"High missing values (>50%): {high_missing}")

        # Check for potential outliers in numeric columns using the 1.5 * IQR rule
        for col in info['numeric_columns']:
            q1 = df[col].quantile(0.25)
            q3 = df[col].quantile(0.75)
            iqr = q3 - q1
            outliers = df[(df[col] < (q1 - 1.5 * iqr)) | (df[col] > (q3 + 1.5 * iqr))][col].count()
            if outliers > 0:
                issues.append(f"Potential outliers in '{col}': {outliers} values")

        # Check for inconsistent data types
        for col in info['categorical_columns']:
            if df[col].dtype == 'object':
                # Check if the column mixes numeric and string values
                sample_values = df[col].dropna().astype(str).head(100)
                numeric_count = sum(1 for val in sample_values if val.replace('.', '').replace('-', '').isdigit())
                if 0 < numeric_count < len(sample_values):
                    issues.append(f"Mixed data types in '{col}': contains both numeric and text values")

        info['potential_issues'] = issues

        return json.dumps(info, indent=2, default=str)

    except Exception as e:
        return f"Error analyzing dataset: {str(e)}"


def create_markdown_report(question: str, analysis: str, charts: list, execution_result: str) -> str:
    """
    Create a simple markdown report with the analysis and key findings.
    """
    report = f"""## Analysis
{analysis}

## Key Findings
{format_execution_results(execution_result)}
"""

    return report


def save_markdown_report(report_content: str) -> str:
    """
    Save the markdown report to a file and return the filename.
    """
    # Ensure the output directory exists
    output_dir = "output"
    os.makedirs(output_dir, exist_ok=True)

    # Generate a filename with a timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"analysis_report_{timestamp}.md"
    filepath = os.path.join(output_dir, filename)

    # Write the report
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(report_content)

    print(f"Report saved: {filepath}")
    return filepath
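
End to end, the report helpers compose as below. Toy inputs; the call writes `output/analysis_report_<timestamp>.md` and returns its path:

```python
from utils import create_markdown_report, save_markdown_report

report = create_markdown_report(
    question="Which cereal has the most protein?",
    analysis="Cereal A leads with 6 g of protein per serving.",
    charts=["output/chart_1.png"],
    execution_result="Code execution result: A    6\nB    4",
)
print(save_markdown_report(report))  # output/analysis_report_<timestamp>.md
```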