File size: 9,067 Bytes
e13d87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import base64
import io
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from llama_index.tools import FunctionTool

# Configure matplotlib for non-interactive environments
matplotlib.use('Agg')

class VisualizationTools:
    """Tools for creating visualizations from CSV data.

    Every ``create_*`` method loads a CSV file from ``csv_directory`` (cached
    per instance), draws a chart with matplotlib, and returns a dict that
    describes the chart together with the rendered PNG as a base64 string.
    The same methods are exposed as LlamaIndex ``FunctionTool`` objects via
    ``get_tools``.
    """

    def __init__(self, csv_directory: str):
        """Initialize with the directory containing the CSV files.

        Args:
            csv_directory: Directory that every ``filename`` argument is
                resolved against.
        """
        self.csv_directory = csv_directory
        # Loaded DataFrames, keyed by the filename exactly as it was requested.
        self.dataframes: Dict[str, pd.DataFrame] = {}
        self.tools = self._create_tools()
        # Figure size in inches and PNG resolution used for every chart.
        self.figure_size = (10, 6)
        self.dpi = 100

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file from ``csv_directory`` as a DataFrame, with caching.

        ``filename`` may omit the ``.csv`` extension. Because tool arguments
        usually come from an LLM, absolute paths and names containing ``..``
        are rejected so that only files under ``csv_directory`` can be read.

        Args:
            filename: File name relative to ``csv_directory``.

        Returns:
            The (cached) DataFrame for ``filename``.

        Raises:
            ValueError: If the name escapes ``csv_directory`` or no matching
                file exists.
        """
        if filename in self.dataframes:
            return self.dataframes[filename]

        requested = Path(filename)
        # `anchor` is non-empty for absolute paths and drive-qualified names.
        if requested.anchor or '..' in requested.parts:
            raise ValueError(f"CSV file must be inside the CSV directory: {filename}")

        base_dir = Path(self.csv_directory)
        file_path = base_dir / requested
        if not file_path.is_file() and not filename.endswith('.csv'):
            file_path = base_dir / f"{filename}.csv"

        # is_file() (not exists()) so a directory is never handed to read_csv.
        if not file_path.is_file():
            raise ValueError(f"CSV file not found: {filename}")

        df = pd.read_csv(file_path)
        self.dataframes[filename] = df
        return df

    @staticmethod
    def _require_columns(df: pd.DataFrame, *columns: Optional[str]) -> None:
        """Raise ``KeyError`` listing the available columns if any is missing.

        ``None`` entries are skipped so optional column arguments can be
        passed straight through.
        """
        missing = [c for c in columns if c is not None and c not in df.columns]
        if missing:
            raise KeyError(
                f"Column(s) {missing} not found; available columns: {list(df.columns)}"
            )

    def _create_tools(self) -> List["FunctionTool"]:
        """Create LlamaIndex function tools for visualizations."""
        tools = [
            FunctionTool.from_defaults(
                name="create_line_chart",
                description="Create a line chart from CSV data",
                fn=self.create_line_chart
            ),
            FunctionTool.from_defaults(
                name="create_bar_chart",
                description="Create a bar chart from CSV data",
                fn=self.create_bar_chart
            ),
            FunctionTool.from_defaults(
                name="create_scatter_plot",
                description="Create a scatter plot from CSV data",
                fn=self.create_scatter_plot
            ),
            FunctionTool.from_defaults(
                name="create_histogram",
                description="Create a histogram from CSV data",
                fn=self.create_histogram
            ),
            FunctionTool.from_defaults(
                name="create_pie_chart",
                description="Create a pie chart from CSV data",
                fn=self.create_pie_chart
            )
        ]
        return tools

    def get_tools(self) -> List["FunctionTool"]:
        """Get all available visualization tools."""
        return self.tools

    @contextmanager
    def _figure(self):
        """Yield a new ``(fig, ax)`` pair and always close the figure afterwards.

        pyplot keeps a reference to every figure until it is closed, so a chart
        that fails half-way (bad column, unplottable data) would otherwise leak
        its figure for the lifetime of the process.
        """
        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            yield fig, ax
        finally:
            plt.close(fig)

    def _figure_to_base64(self, fig) -> str:
        """Render ``fig`` as PNG, close it, and return the base64-encoded bytes."""
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=self.dpi)
        finally:
            # Close even if savefig fails so the figure is not leaked.
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    # Visualization tool implementations
    def create_line_chart(self, filename: str, x_column: str, y_column: str,
                          title: Optional[str] = None, limit: int = 50) -> Dict[str, Any]:
        """Create a line chart of ``y_column`` against ``x_column``.

        Args:
            filename: CSV file name relative to ``csv_directory``.
            x_column: Column used for the x axis.
            y_column: Column used for the y axis.
            title: Chart title; defaults to "<y> vs <x>".
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dict with chart metadata, ``data_points`` and the base64 PNG
            under ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column does not exist.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)

        # Limit data points if needed
        if len(df) > limit:
            df = df.head(limit)

        with self._figure() as (fig, ax):
            ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True)
            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_bar_chart(self, filename: str, x_column: str, y_column: str,
                         title: Optional[str] = None, limit: int = 20) -> Dict[str, Any]:
        """Create a bar chart of ``y_column`` for each value of ``x_column``.

        Args:
            filename: CSV file name relative to ``csv_directory``.
            x_column: Column providing the bar positions / category labels.
            y_column: Column providing the bar heights.
            title: Chart title; defaults to "<y> by <x>".
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dict with chart metadata, ``categories`` and the base64 PNG
            under ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column does not exist.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)

        # Limit categories if needed
        if len(df) > limit:
            df = df.head(limit)

        with self._figure() as (fig, ax):
            ax.bar(df[x_column], df[y_column])
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} by {x_column}")

            # Rotate x labels if there are many categories. Operate on this
            # axes explicitly rather than pyplot's "current" axes.
            if len(df) > 5:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

            fig.tight_layout()
            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str
        }

    def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
                            color_column: Optional[str] = None,
                            title: Optional[str] = None) -> Dict[str, Any]:
        """Create a scatter plot of ``y_column`` against ``x_column``.

        Args:
            filename: CSV file name relative to ``csv_directory``.
            x_column: Column used for the x axis.
            y_column: Column used for the y axis.
            color_column: Optional column mapped through the ``viridis``
                colormap. It is silently ignored if it is not a column of
                the file.
            title: Chart title; defaults to "<y> vs <x>".

        Returns:
            Dict with chart metadata, ``data_points`` and the base64 PNG
            under ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` does not exist.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)

        with self._figure() as (fig, ax):
            if color_column and color_column in df.columns:
                scatter = ax.scatter(df[x_column], df[y_column], c=df[color_column],
                                     cmap='viridis', alpha=0.7)
                fig.colorbar(scatter, ax=ax, label=color_column)
            else:
                ax.scatter(df[x_column], df[y_column], alpha=0.7)

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True, linestyle='--', alpha=0.7)
            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_histogram(self, filename: str, column: str, bins: int = 10,
                         title: Optional[str] = None) -> Dict[str, Any]:
        """Create a histogram of the values in ``column``.

        Args:
            filename: CSV file name relative to ``csv_directory``.
            column: Column whose values are binned.
            bins: Number of bins passed to ``Axes.hist``.
            title: Chart title; defaults to "Distribution of <column>".

        Returns:
            Dict with chart metadata, ``data_points`` (row count of the file)
            and the base64 PNG under ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``column`` does not exist.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, column)

        with self._figure() as (fig, ax):
            ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')
            ax.set_xlabel(column)
            ax.set_ylabel('Frequency')
            ax.set_title(title or f"Distribution of {column}")
            ax.grid(True, linestyle='--', alpha=0.7)
            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str
        }

    def create_pie_chart(self, filename: str, label_column: str,
                         value_column: Optional[str] = None,
                         title: Optional[str] = None, limit: int = 10) -> Dict[str, Any]:
        """Create a pie chart over the values of ``label_column``.

        Without ``value_column`` each slice is the occurrence count of a label.
        With it, the values are summed per label. In both cases only the
        ``limit`` largest slices are kept.

        Args:
            filename: CSV file name relative to ``csv_directory``.
            label_column: Column providing slice labels.
            value_column: Optional column summed per label.
            title: Chart title; defaults to "Distribution of <label_column>".
            limit: Maximum number of slices.

        Returns:
            Dict with chart metadata, ``categories`` and the base64 PNG
            under ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column does not exist.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, label_column, value_column)

        if value_column is None:
            # value_counts() is already sorted by descending count.
            data = df[label_column].value_counts().head(limit)
            labels = data.index.tolist()
            values = data.values.tolist()
        else:
            grouped = df.groupby(label_column)[value_column].sum().reset_index()
            grouped = grouped.nlargest(limit, value_column)
            labels = grouped[label_column].tolist()
            values = grouped[value_column].tolist()

        with self._figure() as (fig, ax):
            ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
            ax.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle
            ax.set_title(title or f"Distribution of {label_column}")
            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str
        }