|
from typing import Dict, List, Any, Optional, Tuple, Union |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
import io |
|
import base64 |
|
import numpy as np |
|
from pathlib import Path |
|
|
|
|
|
# Select the non-interactive Agg backend: charts are only rendered to PNG bytes,
# never shown in a window, so no GUI toolkit is required.
matplotlib.use("Agg")
|
|
|
class VisualizationTools:
    """Tools for creating visualizations from CSV data.

    Every ``create_*`` method loads a CSV file from ``csv_directory`` (cached
    per filename), draws one chart with matplotlib, and returns a dict of
    chart metadata whose ``"image"`` entry is the PNG encoded as base64 text.
    Requested columns are validated before any figure is created, and every
    figure is closed even when drawing fails, so pyplot does not accumulate
    open figures across calls.
    """

    def __init__(self, csv_directory: str):
        """Initialize with the directory containing the CSV files.

        Args:
            csv_directory: Directory against which filenames passed to the
                chart methods are resolved.
        """
        self.csv_directory = csv_directory
        # Loaded DataFrames keyed by the filename exactly as the caller passed it.
        self.dataframes: Dict[str, pd.DataFrame] = {}
        self.figure_size = (10, 6)
        self.dpi = 100

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as a DataFrame, caching the result per filename.

        ``filename`` is resolved relative to ``csv_directory``. If it does not
        name an existing file and lacks a ``.csv`` suffix, ``<filename>.csv``
        is tried instead.

        Raises:
            ValueError: if no matching regular file exists.
        """
        if filename not in self.dataframes:
            directory = Path(self.csv_directory)
            file_path = directory / filename
            # is_file() (not exists()) so that a directory name is reported as
            # "not found" instead of being handed to read_csv.
            if not file_path.is_file() and not filename.endswith('.csv'):
                file_path = directory / f"{filename}.csv"

            if not file_path.is_file():
                raise ValueError(f"CSV file not found: {filename}")
            self.dataframes[filename] = pd.read_csv(file_path)

        return self.dataframes[filename]

    @staticmethod
    def _require_columns(df: pd.DataFrame, *columns: str) -> None:
        """Raise KeyError listing any of ``columns`` that ``df`` does not have.

        Called before a figure is created so that a bad column name fails fast
        with a readable message and never leaves a half-drawn figure behind.
        """
        missing = [col for col in columns if col not in df.columns]
        if missing:
            raise KeyError(
                f"Column(s) not found: {missing}. "
                f"Available columns: {list(df.columns)}"
            )

    @staticmethod
    def _limit_rows(df: pd.DataFrame, limit: Optional[int]) -> pd.DataFrame:
        """Return the first ``limit`` rows of ``df``; ``None`` keeps every row."""
        return df if limit is None else df.head(limit)

    def get_tools(self) -> List[Dict[str, Any]]:
        """Get all available visualization tools.

        Returns:
            A list of dicts with ``"name"``, ``"description"`` and the bound
            ``"function"`` implementing each tool.
        """
        tools = [
            {
                "name": "create_line_chart",
                "description": "Create a line chart from CSV data",
                "function": self.create_line_chart
            },
            {
                "name": "create_bar_chart",
                "description": "Create a bar chart from CSV data",
                "function": self.create_bar_chart
            },
            {
                "name": "create_scatter_plot",
                "description": "Create a scatter plot from CSV data",
                "function": self.create_scatter_plot
            },
            {
                "name": "create_histogram",
                "description": "Create a histogram from CSV data",
                "function": self.create_histogram
            },
            {
                "name": "create_pie_chart",
                "description": "Create a pie chart from CSV data",
                "function": self.create_pie_chart
            }
        ]
        return tools

    def _figure_to_base64(self, fig) -> str:
        """Render ``fig`` as PNG and return it base64-encoded.

        The figure is closed even if saving fails.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=self.dpi)
        finally:
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    def create_line_chart(self, filename: str, x_column: str, y_column: str,
                          title: Optional[str] = None,
                          limit: Optional[int] = 50) -> Dict[str, Any]:
        """Create a line chart of ``y_column`` against ``x_column``.

        Args:
            filename: CSV file name (the ``.csv`` suffix is optional).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.
            limit: Plot only the first ``limit`` rows; ``None`` plots all rows.

        Returns:
            Chart metadata plus the base64 PNG under ``"image"``.

        Raises:
            ValueError: if the CSV file does not exist.
            KeyError: if a requested column is missing.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)
        df = self._limit_rows(df, limit)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True)
            img_str = self._figure_to_base64(fig)
        finally:
            # Release the figure even when drawing raised.
            plt.close(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_bar_chart(self, filename: str, x_column: str, y_column: str,
                         title: Optional[str] = None,
                         limit: Optional[int] = 20) -> Dict[str, Any]:
        """Create a bar chart of ``y_column`` for each value of ``x_column``.

        Args:
            filename: CSV file name (the ``.csv`` suffix is optional).
            x_column: Column providing the bar positions/categories.
            y_column: Column providing the bar heights.
            title: Chart title; defaults to ``"<y_column> by <x_column>"``.
            limit: Plot only the first ``limit`` rows; ``None`` plots all rows.

        Returns:
            Chart metadata plus the base64 PNG under ``"image"``.

        Raises:
            ValueError: if the CSV file does not exist.
            KeyError: if a requested column is missing.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)
        df = self._limit_rows(df, limit)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.bar(df[x_column], df[y_column])
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} by {x_column}")

            if len(df) > 5:
                # Slant crowded category labels so they do not overlap; operate
                # on this figure's axes rather than pyplot's "current" figure.
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str
        }

    def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
                            color_column: Optional[str] = None,
                            title: Optional[str] = None) -> Dict[str, Any]:
        """Create a scatter plot of ``y_column`` against ``x_column``.

        Args:
            filename: CSV file name (the ``.csv`` suffix is optional).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            color_column: Optional column mapped through the ``viridis``
                colormap (with a colorbar). It is silently ignored when it is
                not a column of the data.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.

        Returns:
            Chart metadata plus the base64 PNG under ``"image"``.

        Raises:
            ValueError: if the CSV file does not exist.
            KeyError: if ``x_column`` or ``y_column`` is missing.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, x_column, y_column)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            if color_column and color_column in df.columns:
                points = ax.scatter(df[x_column], df[y_column], c=df[color_column],
                                    cmap='viridis', alpha=0.7)
                fig.colorbar(points, ax=ax, label=color_column)
            else:
                ax.scatter(df[x_column], df[y_column], alpha=0.7)

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_histogram(self, filename: str, column: str, bins: int = 10,
                         title: Optional[str] = None) -> Dict[str, Any]:
        """Create a histogram of the values in ``column``.

        Args:
            filename: CSV file name (the ``.csv`` suffix is optional).
            column: Column whose distribution is plotted.
            bins: Number of histogram bins.
            title: Chart title; defaults to ``"Distribution of <column>"``.

        Returns:
            Chart metadata plus the base64 PNG under ``"image"``;
            ``"data_points"`` is the number of rows in the file.

        Raises:
            ValueError: if the CSV file does not exist.
            KeyError: if ``column`` is missing.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, column)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')
            ax.set_xlabel(column)
            ax.set_ylabel('Frequency')
            ax.set_title(title or f"Distribution of {column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str
        }

    def create_pie_chart(self, filename: str, label_column: str,
                         value_column: Optional[str] = None,
                         title: Optional[str] = None,
                         limit: Optional[int] = 10) -> Dict[str, Any]:
        """Create a pie chart of the categories in ``label_column``.

        Without ``value_column`` each slice is the row count of a category.
        With ``value_column`` each slice is that column's sum per category.
        Only the ``limit`` largest slices are drawn (``None`` draws all), and
        percentages are relative to the drawn slices.

        Args:
            filename: CSV file name (the ``.csv`` suffix is optional).
            label_column: Column whose distinct values become the slices.
            value_column: Optional column summed per category.
            title: Chart title; defaults to ``"Distribution of <label_column>"``.
            limit: Maximum number of slices; ``None`` means no limit.

        Returns:
            Chart metadata plus the base64 PNG under ``"image"``.

        Raises:
            ValueError: if the CSV file does not exist.
            KeyError: if a requested column is missing.
        """
        df = self._load_dataframe(filename)

        if value_column is None:
            self._require_columns(df, label_column)
            # value_counts() is already sorted by descending count.
            counts = df[label_column].value_counts()
            data = counts if limit is None else counts.head(limit)
            labels = data.index.tolist()
            values = data.values.tolist()
        else:
            self._require_columns(df, label_column, value_column)
            grouped = df.groupby(label_column)[value_column].sum().reset_index()
            if limit is None:
                grouped = grouped.sort_values(value_column, ascending=False)
            else:
                grouped = grouped.nlargest(limit, value_column)
            labels = grouped[label_column].tolist()
            values = grouped[value_column].tolist()

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
            # Equal aspect ratio keeps the pie circular.
            ax.axis('equal')
            ax.set_title(title or f"Distribution of {label_column}")

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str
        }
|
|