Chamin09 commited on
Commit
f78ea13
·
verified ·
1 Parent(s): bbd9a98

Create visualization.py

Browse files
Files changed (1) hide show
  1. tools/visualization.py +240 -0
tools/visualization.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Optional, Tuple, Union
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib
5
+ import io
6
+ import base64
7
+ import numpy as np
8
+ from pathlib import Path
9
+
10
+ # Configure matplotlib for non-interactive environments
11
+ matplotlib.use('Agg')
12
+
13
+ class VisualizationTools:
14
+ """Tools for creating visualizations from CSV data."""
15
+
16
+ def __init__(self, csv_directory: str):
17
+ """Initialize with directory containing CSV files."""
18
+ self.csv_directory = csv_directory
19
+ self.dataframes = {}
20
+ self.figure_size = (10, 6)
21
+ self.dpi = 100
22
+
23
+ def _load_dataframe(self, filename: str) -> pd.DataFrame:
24
+ """Load a CSV file as DataFrame, with caching."""
25
+ if filename not in self.dataframes:
26
+ file_path = Path(self.csv_directory) / filename
27
+ if not file_path.exists() and not filename.endswith('.csv'):
28
+ file_path = Path(self.csv_directory) / f"{filename}.csv"
29
+
30
+ if file_path.exists():
31
+ self.dataframes[filename] = pd.read_csv(file_path)
32
+ else:
33
+ raise ValueError(f"CSV file not found: {filename}")
34
+
35
+ return self.dataframes[filename]
36
+
37
+ def get_tools(self) -> List[Dict[str, Any]]:
38
+ """Get all available visualization tools."""
39
+ tools = [
40
+ {
41
+ "name": "create_line_chart",
42
+ "description": "Create a line chart from CSV data",
43
+ "function": self.create_line_chart
44
+ },
45
+ {
46
+ "name": "create_bar_chart",
47
+ "description": "Create a bar chart from CSV data",
48
+ "function": self.create_bar_chart
49
+ },
50
+ {
51
+ "name": "create_scatter_plot",
52
+ "description": "Create a scatter plot from CSV data",
53
+ "function": self.create_scatter_plot
54
+ },
55
+ {
56
+ "name": "create_histogram",
57
+ "description": "Create a histogram from CSV data",
58
+ "function": self.create_histogram
59
+ },
60
+ {
61
+ "name": "create_pie_chart",
62
+ "description": "Create a pie chart from CSV data",
63
+ "function": self.create_pie_chart
64
+ }
65
+ ]
66
+ return tools
67
+
68
+ def _figure_to_base64(self, fig) -> str:
69
+ """Convert matplotlib figure to base64 encoded string."""
70
+ buf = io.BytesIO()
71
+ fig.savefig(buf, format='png', dpi=self.dpi)
72
+ buf.seek(0)
73
+ img_str = base64.b64encode(buf.read()).decode('utf-8')
74
+ plt.close(fig)
75
+ return img_str
76
+
77
+ # Visualization tool implementations
78
+ def create_line_chart(self, filename: str, x_column: str, y_column: str,
79
+ title: str = None, limit: int = 50) -> Dict[str, Any]:
80
+ """Create a line chart visualization."""
81
+ df = self._load_dataframe(filename)
82
+
83
+ # Limit data points if needed
84
+ if len(df) > limit:
85
+ df = df.head(limit)
86
+
87
+ fig, ax = plt.subplots(figsize=self.figure_size)
88
+
89
+ # Create line chart
90
+ ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')
91
+
92
+ # Set labels and title
93
+ ax.set_xlabel(x_column)
94
+ ax.set_ylabel(y_column)
95
+ ax.set_title(title or f"{y_column} vs {x_column}")
96
+ ax.grid(True)
97
+
98
+ # Convert to base64
99
+ img_str = self._figure_to_base64(fig)
100
+
101
+ return {
102
+ "chart_type": "line",
103
+ "x_column": x_column,
104
+ "y_column": y_column,
105
+ "data_points": len(df),
106
+ "image": img_str
107
+ }
108
+
109
+ def create_bar_chart(self, filename: str, x_column: str, y_column: str,
110
+ title: str = None, limit: int = 20) -> Dict[str, Any]:
111
+ """Create a bar chart visualization."""
112
+ df = self._load_dataframe(filename)
113
+
114
+ # Limit categories if needed
115
+ if len(df) > limit:
116
+ df = df.head(limit)
117
+
118
+ fig, ax = plt.subplots(figsize=self.figure_size)
119
+
120
+ # Create bar chart
121
+ ax.bar(df[x_column], df[y_column])
122
+
123
+ # Set labels and title
124
+ ax.set_xlabel(x_column)
125
+ ax.set_ylabel(y_column)
126
+ ax.set_title(title or f"{y_column} by {x_column}")
127
+
128
+ # Rotate x labels if there are many categories
129
+ if len(df) > 5:
130
+ plt.xticks(rotation=45, ha='right')
131
+
132
+ plt.tight_layout()
133
+
134
+ # Convert to base64
135
+ img_str = self._figure_to_base64(fig)
136
+
137
+ return {
138
+ "chart_type": "bar",
139
+ "x_column": x_column,
140
+ "y_column": y_column,
141
+ "categories": len(df),
142
+ "image": img_str
143
+ }
144
+
145
+ def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
146
+ color_column: str = None, title: str = None) -> Dict[str, Any]:
147
+ """Create a scatter plot visualization."""
148
+ df = self._load_dataframe(filename)
149
+
150
+ fig, ax = plt.subplots(figsize=self.figure_size)
151
+
152
+ # Create scatter plot
153
+ if color_column and color_column in df.columns:
154
+ scatter = ax.scatter(df[x_column], df[y_column], c=df[color_column], cmap='viridis', alpha=0.7)
155
+ plt.colorbar(scatter, ax=ax, label=color_column)
156
+ else:
157
+ ax.scatter(df[x_column], df[y_column], alpha=0.7)
158
+
159
+ # Set labels and title
160
+ ax.set_xlabel(x_column)
161
+ ax.set_ylabel(y_column)
162
+ ax.set_title(title or f"{y_column} vs {x_column}")
163
+ ax.grid(True, linestyle='--', alpha=0.7)
164
+
165
+ # Convert to base64
166
+ img_str = self._figure_to_base64(fig)
167
+
168
+ return {
169
+ "chart_type": "scatter",
170
+ "x_column": x_column,
171
+ "y_column": y_column,
172
+ "color_column": color_column,
173
+ "data_points": len(df),
174
+ "image": img_str
175
+ }
176
+
177
+ def create_histogram(self, filename: str, column: str, bins: int = 10,
178
+ title: str = None) -> Dict[str, Any]:
179
+ """Create a histogram visualization."""
180
+ df = self._load_dataframe(filename)
181
+
182
+ fig, ax = plt.subplots(figsize=self.figure_size)
183
+
184
+ # Create histogram
185
+ ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')
186
+
187
+ # Set labels and title
188
+ ax.set_xlabel(column)
189
+ ax.set_ylabel('Frequency')
190
+ ax.set_title(title or f"Distribution of {column}")
191
+ ax.grid(True, linestyle='--', alpha=0.7)
192
+
193
+ # Convert to base64
194
+ img_str = self._figure_to_base64(fig)
195
+
196
+ return {
197
+ "chart_type": "histogram",
198
+ "column": column,
199
+ "bins": bins,
200
+ "data_points": len(df),
201
+ "image": img_str
202
+ }
203
+
204
+ def create_pie_chart(self, filename: str, label_column: str, value_column: str = None,
205
+ title: str = None, limit: int = 10) -> Dict[str, Any]:
206
+ """Create a pie chart visualization."""
207
+ df = self._load_dataframe(filename)
208
+
209
+ # If value column not provided, count occurrences of each label
210
+ if value_column is None:
211
+ data = df[label_column].value_counts().head(limit)
212
+ labels = data.index.tolist()
213
+ values = data.values.tolist()
214
+ else:
215
+ # Group by label and sum values
216
+ grouped = df.groupby(label_column)[value_column].sum().reset_index()
217
+ # Limit to top categories
218
+ grouped = grouped.nlargest(limit, value_column)
219
+ labels = grouped[label_column].tolist()
220
+ values = grouped[value_column].tolist()
221
+
222
+ fig, ax = plt.subplots(figsize=self.figure_size)
223
+
224
+ # Create pie chart
225
+ ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
226
+ ax.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle
227
+
228
+ # Set title
229
+ ax.set_title(title or f"Distribution of {label_column}")
230
+
231
+ # Convert to base64
232
+ img_str = self._figure_to_base64(fig)
233
+
234
+ return {
235
+ "chart_type": "pie",
236
+ "label_column": label_column,
237
+ "value_column": value_column,
238
+ "categories": len(labels),
239
+ "image": img_str
240
+ }