import glob
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from smolagents import tool
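
# Follow-up analysis tools for a smolagents agent. Each function below is
# registered via @tool and works against artifacts saved by a previous run
# in the generated_data/ folder (CSV datasets, DOCX reports, PNG charts).
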
@tool
def load_previous_dataset() -> pd.DataFrame:
    """
    Load the dataset that was used in the previous analysis.

    Returns:
        The pandas DataFrame that was used in the previous report generation
    """
    try:
        # '*data*.csv' also matches any name containing 'dataset', so a
        # single pattern covers both naming conventions.
        dataset_files = glob.glob('generated_data/*data*.csv')

        # Fall back to any CSV in the folder.
        if not dataset_files:
            dataset_files = glob.glob('generated_data/*.csv')

        if not dataset_files:
            raise FileNotFoundError("No dataset found in generated_data folder")

        # Use the most recently created file.
        latest_file = max(dataset_files, key=os.path.getctime)
        df = pd.read_csv(latest_file)

        print(f"✅ Loaded dataset from {latest_file} with {len(df)} rows and {len(df.columns)} columns")
        return df

    except Exception as e:
        raise Exception(f"Error loading previous dataset: {str(e)}") from e
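
# Example usage (a sketch; assumes at least one CSV already exists under
# generated_data/ and that the smolagents @tool wrapper remains directly
# callable like a plain function):
#
#     df = load_previous_dataset()
#     print(df.shape)
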
@tool
def get_dataset_summary(df: pd.DataFrame) -> str:
    """
    Get a comprehensive summary of the dataset structure and content.

    Args:
        df: The pandas DataFrame to analyze

    Returns:
        A formatted string with dataset summary information
    """
    try:
        summary_lines = []
        summary_lines.append("=== DATASET SUMMARY ===")
        summary_lines.append(f"Shape: {df.shape[0]} rows × {df.shape[1]} columns")
        summary_lines.append("")

        summary_lines.append("Column Information:")
        for col in df.columns:
            dtype = str(df[col].dtype)
            non_null = df[col].count()
            null_count = df[col].isnull().sum()
            unique_count = df[col].nunique()

            summary_lines.append(f"  • {col}: {dtype}, {non_null} non-null, {null_count} null, {unique_count} unique")

            if df[col].dtype == 'object' and unique_count <= 10:
                sample_values = df[col].value_counts().head(5).index.tolist()
                summary_lines.append(f"    Sample values: {sample_values}")

        summary_lines.append("")
        summary_lines.append("First 3 rows:")
        summary_lines.append(df.head(3).to_string())

        return "\n".join(summary_lines)

    except Exception as e:
        return f"Error analyzing dataset: {str(e)}"
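
# Example usage (same assumptions as the sketch above):
#
#     print(get_dataset_summary(load_previous_dataset()))
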
@tool
def create_followup_visualization(df: pd.DataFrame, chart_type: str, x_column: str, y_column: str = None, title: str = "Follow-up Analysis", filename: str = "followup_chart.png") -> str:
    """
    Create a visualization for follow-up analysis.

    Args:
        df: The pandas DataFrame to visualize
        chart_type: Type of chart ('bar', 'line', 'scatter', 'histogram', 'box', 'pie')
        x_column: Column name for x-axis
        y_column: Column name for y-axis (optional for some chart types)
        title: Title for the chart
        filename: Name of the file to save (should end with .png)

    Returns:
        Path to the saved visualization file
    """
    try:
        plt.figure(figsize=(12, 8))

        if chart_type == 'bar':
            if y_column:
                df_grouped = df.groupby(x_column)[y_column].sum().sort_values(ascending=False)
                plt.bar(range(len(df_grouped)), df_grouped.values)
                plt.xticks(range(len(df_grouped)), df_grouped.index, rotation=45)
                plt.ylabel(y_column)
            else:
                value_counts = df[x_column].value_counts().head(10)
                plt.bar(range(len(value_counts)), value_counts.values)
                plt.xticks(range(len(value_counts)), value_counts.index, rotation=45)
                plt.ylabel('Count')

        elif chart_type == 'line':
            if y_column:
                df_sorted = df.sort_values(x_column)
                plt.plot(df_sorted[x_column], df_sorted[y_column])
                plt.ylabel(y_column)
            else:
                value_counts = df[x_column].value_counts().sort_index()
                plt.plot(value_counts.index, value_counts.values)
                plt.ylabel('Count')

        elif chart_type == 'scatter':
            if y_column:
                plt.scatter(df[x_column], df[y_column], alpha=0.6)
                plt.ylabel(y_column)
            else:
                raise ValueError("Scatter plot requires both x_column and y_column")

        elif chart_type == 'histogram':
            plt.hist(df[x_column].dropna(), bins=30, alpha=0.7)
            plt.ylabel('Frequency')

        elif chart_type == 'box':
            if y_column:
                # DataFrame.boxplot(by=...) opens its own figure unless an
                # axes is supplied, so draw on the current axes explicitly.
                df.boxplot(column=y_column, by=x_column, ax=plt.gca())
            else:
                plt.boxplot(df[x_column].dropna())
                plt.ylabel(x_column)

        elif chart_type == 'pie':
            value_counts = df[x_column].value_counts().head(10)
            plt.pie(value_counts.values, labels=value_counts.index, autopct='%1.1f%%')

        else:
            raise ValueError(f"Unsupported chart type: {chart_type}")

        plt.xlabel(x_column)
        plt.title(title)
        plt.tight_layout()

        if not filename.endswith('.png'):
            filename += '.png'

        os.makedirs('generated_data', exist_ok=True)
        filepath = os.path.join('generated_data', filename)
        plt.savefig(filepath, dpi=300, bbox_inches='tight')
        plt.close()

        return f"Visualization saved to: {filepath}"

    except Exception as e:
        plt.close()
        return f"Error creating visualization: {str(e)}"
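
# Example usage (a sketch; 'region' and 'sales' are hypothetical column
# names used for illustration, not columns guaranteed to exist):
#
#     create_followup_visualization(df, chart_type='bar', x_column='region',
#                                   y_column='sales', title='Sales by Region',
#                                   filename='sales_by_region.png')
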
@tool
def get_previous_report_content() -> str:
    """
    Get the content of the previously generated report.

    Returns:
        The text content of the previous report for context
    """
    try:
        report_files = glob.glob('generated_data/*.docx')

        if not report_files:
            return "No previous report found in generated_data folder"

        # Use the most recently created report.
        latest_report = max(report_files, key=os.path.getctime)

        # Extract the opening paragraphs (requires the python-docx package).
        docx_content = ""
        try:
            from docx import Document
            doc = Document(latest_report)
            paragraphs = []
            for para in doc.paragraphs:
                if para.text.strip():
                    paragraphs.append(para.text.strip())
            docx_content = "\n".join(paragraphs[:10])
        except Exception as e:
            docx_content = f"Could not extract text from DOCX: {str(e)}"

        file_size = os.path.getsize(latest_report)
        created = datetime.fromtimestamp(os.path.getctime(latest_report))

        # Pick up any companion text files from the same analysis.
        text_files = glob.glob('generated_data/*.txt')
        text_content = ""

        if text_files:
            latest_text = max(text_files, key=os.path.getctime)
            with open(latest_text, 'r', encoding='utf-8') as f:
                text_content = f.read()

        summary = f"""=== PREVIOUS REPORT CONTEXT ===
Report file: {latest_report}
File size: {file_size} bytes
Created: {created:%Y-%m-%d %H:%M:%S}

DOCX Report Content (first 10 paragraphs):
{docx_content}

Additional analysis content:
{text_content if text_content else 'No additional text content found'}

The report was generated from the dataset in the previous analysis.
You can use load_previous_dataset() to access the same data.
"""

        return summary

    except Exception as e:
        return f"Error accessing previous report: {str(e)}"
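
# Note: the DOCX extraction above needs the python-docx package
# (pip install python-docx). If it is not installed, the tool still returns
# a summary and reports the extraction error instead of raising.
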
@tool
def analyze_column_correlation(df: pd.DataFrame, column1: str, column2: str) -> str:
    """
    Analyze correlation between two columns in the dataset.

    Args:
        df: The pandas DataFrame
        column1: First column name
        column2: Second column name

    Returns:
        Correlation analysis results
    """
    try:
        if column1 not in df.columns or column2 not in df.columns:
            return f"Error: One or both columns not found. Available columns: {list(df.columns)}"

        # Coerce both columns to numeric; non-convertible values become NaN.
        col1_numeric = pd.to_numeric(df[column1], errors='coerce')
        col2_numeric = pd.to_numeric(df[column2], errors='coerce')

        if col1_numeric.isna().all() or col2_numeric.isna().all():
            return "Error: Cannot convert columns to numeric for correlation analysis"

        correlation = col1_numeric.corr(col2_numeric)

        plt.figure(figsize=(10, 6))
        plt.scatter(col1_numeric, col2_numeric, alpha=0.6)
        plt.xlabel(column1)
        plt.ylabel(column2)
        plt.title(f'Correlation between {column1} and {column2}\nCorrelation coefficient: {correlation:.3f}')

        # Fit the trend line only on rows where both values are present, so
        # the x and y arrays passed to polyfit stay aligned.
        valid = col1_numeric.notna() & col2_numeric.notna()
        if valid.any():
            z = np.polyfit(col1_numeric[valid], col2_numeric[valid], 1)
            p = np.poly1d(z)
            x_line = np.sort(col1_numeric[valid].values)
            plt.plot(x_line, p(x_line), "r--", alpha=0.8)

        plt.tight_layout()

        filename = f'correlation_{column1}_{column2}.png'
        os.makedirs('generated_data', exist_ok=True)
        filepath = os.path.join('generated_data', filename)
        plt.savefig(filepath, dpi=300, bbox_inches='tight')
        plt.close()

        # Describe the coefficient with conventional thresholds.
        if abs(correlation) > 0.7:
            strength = "strong"
        elif abs(correlation) > 0.4:
            strength = "moderate"
        elif abs(correlation) > 0.2:
            strength = "weak"
        else:
            strength = "very weak"

        direction = "positive" if correlation > 0 else "negative"

        result = f"""=== CORRELATION ANALYSIS ===
Columns: {column1} vs {column2}
Correlation coefficient: {correlation:.3f}
Strength: {strength} {direction} correlation

Interpretation:
- The correlation is {strength} and {direction}
- Values closer to 1 or -1 indicate stronger linear relationships
- Values closer to 0 indicate weaker linear relationships

Visualization saved to: {filepath}
"""

        return result

    except Exception as e:
        plt.close()
        return f"Error in correlation analysis: {str(e)}"
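
# Note: Series.corr defaults to the Pearson coefficient, so the strength
# labels above describe linear association only. For monotonic but
# non-linear relationships, passing method='spearman' to .corr() may be
# more informative.
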
@tool
def create_statistical_summary(df: pd.DataFrame, column_name: str) -> str:
    """
    Create a comprehensive statistical summary with visualization for a specific column.

    Args:
        df: The pandas DataFrame
        column_name: Name of the column to analyze

    Returns:
        Statistical summary and saves a visualization
    """
    try:
        if column_name not in df.columns:
            return f"Error: Column '{column_name}' not found. Available columns: {list(df.columns)}"

        column_data = df[column_name]

        summary_lines = [f"=== STATISTICAL SUMMARY: {column_name} ==="]

        if pd.api.types.is_numeric_dtype(column_data):
            # Numeric column: descriptive statistics plus histogram and box plot.
            stats = column_data.describe()
            summary_lines.extend([
                f"Count: {stats['count']:.0f}",
                f"Mean: {stats['mean']:.2f}",
                f"Median: {stats['50%']:.2f}",
                f"Standard Deviation: {stats['std']:.2f}",
                f"Min: {stats['min']:.2f}",
                f"Max: {stats['max']:.2f}",
                f"25th Percentile: {stats['25%']:.2f}",
                f"75th Percentile: {stats['75%']:.2f}",
            ])

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

            ax1.hist(column_data.dropna(), bins=30, alpha=0.7, color='skyblue', edgecolor='black')
            ax1.set_title(f'Distribution of {column_name}')
            ax1.set_xlabel(column_name)
            ax1.set_ylabel('Frequency')
            ax1.grid(True, alpha=0.3)

            ax2.boxplot(column_data.dropna())
            ax2.set_title(f'Box Plot of {column_name}')
            ax2.set_ylabel(column_name)
            ax2.grid(True, alpha=0.3)

        else:
            # Categorical column: frequency counts plus bar and pie charts.
            value_counts = column_data.value_counts()
            summary_lines.extend([
                f"Total unique values: {column_data.nunique()}",
                f"Most frequent value: {value_counts.index[0]} ({value_counts.iloc[0]} times)",
                f"Least frequent value: {value_counts.index[-1]} ({value_counts.iloc[-1]} times)",
                "",
                "Top 10 values:"
            ])

            for value, count in value_counts.head(10).items():
                percentage = (count / len(column_data)) * 100
                summary_lines.append(f"  {value}: {count} ({percentage:.1f}%)")

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

            top_values = value_counts.head(10)
            ax1.bar(range(len(top_values)), top_values.values, color='lightcoral')
            ax1.set_title(f'Top 10 Values in {column_name}')
            ax1.set_xlabel('Categories')
            ax1.set_ylabel('Count')
            ax1.set_xticks(range(len(top_values)))
            ax1.set_xticklabels(top_values.index, rotation=45, ha='right')
            ax1.grid(True, alpha=0.3)

            # Collapse everything beyond the top 8 categories into 'Others'
            # so the pie chart stays readable.
            top_8 = value_counts.head(8)
            others_count = value_counts.iloc[8:].sum() if len(value_counts) > 8 else 0

            if others_count > 0:
                pie_data = list(top_8.values) + [others_count]
                pie_labels = list(top_8.index) + ['Others']
            else:
                pie_data = top_8.values
                pie_labels = top_8.index

            ax2.pie(pie_data, labels=pie_labels, autopct='%1.1f%%', startangle=90)
            ax2.set_title(f'Distribution of {column_name}')

        plt.tight_layout()

        filename = f'statistical_summary_{column_name}.png'
        os.makedirs('generated_data', exist_ok=True)
        filepath = os.path.join('generated_data', filename)
        plt.savefig(filepath, dpi=300, bbox_inches='tight')
        plt.close()

        summary_lines.append(f"\nVisualization saved to: {filepath}")

        return "\n".join(summary_lines)

    except Exception as e:
        plt.close()
        return f"Error in statistical analysis: {str(e)}"
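
# Example usage (a sketch; 'category' is a hypothetical column name):
#
#     print(create_statistical_summary(df, 'category'))
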
@tool
def filter_and_visualize_data(df: pd.DataFrame, filter_column: str, filter_value: str, analysis_column: str, chart_type: str = "bar") -> str:
    """
    Filter the dataset and create a visualization of the filtered data.

    Args:
        df: The pandas DataFrame
        filter_column: Column to filter by
        filter_value: Value to filter for (can be partial match for string columns)
        analysis_column: Column to analyze in the filtered data
        chart_type: Type of chart to create ('bar', 'line', 'histogram', 'pie')

    Returns:
        Analysis results and saves a visualization
    """
    try:
        if filter_column not in df.columns:
            return f"Error: Filter column '{filter_column}' not found. Available columns: {list(df.columns)}"

        if analysis_column not in df.columns:
            return f"Error: Analysis column '{analysis_column}' not found. Available columns: {list(df.columns)}"

        # String columns use a case-insensitive partial match; numeric
        # columns require an exact match.
        if df[filter_column].dtype == 'object':
            filtered_df = df[df[filter_column].str.contains(filter_value, case=False, na=False)]
        else:
            try:
                filter_value_numeric = float(filter_value)
                filtered_df = df[df[filter_column] == filter_value_numeric]
            except ValueError:
                return f"Error: Cannot convert '{filter_value}' to numeric for filtering"

        if filtered_df.empty:
            return f"No data found matching filter: {filter_column} = '{filter_value}'"

        result_lines = [
            "=== FILTERED DATA ANALYSIS ===",
            f"Filter: {filter_column} contains/equals '{filter_value}'",
            f"Filtered dataset size: {len(filtered_df)} rows (from {len(df)} total)",
            f"Analysis column: {analysis_column}",
            ""
        ]

        analysis_data = filtered_df[analysis_column]

        plt.figure(figsize=(12, 8))

        if chart_type == "bar":
            if pd.api.types.is_numeric_dtype(analysis_data):
                # Numeric data: fall back to a histogram of the values.
                analysis_data.hist(bins=20, alpha=0.7, color='lightblue', edgecolor='black')
                plt.ylabel('Frequency')
            else:
                value_counts = analysis_data.value_counts().head(15)
                plt.bar(range(len(value_counts)), value_counts.values, color='lightcoral')
                plt.xticks(range(len(value_counts)), value_counts.index, rotation=45, ha='right')
                plt.ylabel('Count')

                result_lines.extend([
                    f"Top value: {value_counts.index[0]} ({value_counts.iloc[0]} occurrences)",
                    f"Total unique values: {analysis_data.nunique()}"
                ])

        elif chart_type == "line":
            if pd.api.types.is_numeric_dtype(analysis_data):
                sorted_data = analysis_data.sort_values()
                plt.plot(range(len(sorted_data)), sorted_data.values, marker='o', alpha=0.7)
                plt.ylabel(analysis_column)
                plt.xlabel('Sorted Index')
            else:
                plt.close()
                return "Line chart requires numeric data for analysis column"

        elif chart_type == "histogram":
            if pd.api.types.is_numeric_dtype(analysis_data):
                plt.hist(analysis_data.dropna(), bins=30, alpha=0.7, color='green', edgecolor='black')
                plt.ylabel('Frequency')

                result_lines.extend([
                    f"Mean: {analysis_data.mean():.2f}",
                    f"Median: {analysis_data.median():.2f}",
                    f"Standard Deviation: {analysis_data.std():.2f}"
                ])
            else:
                plt.close()
                return "Histogram requires numeric data for analysis column"

        elif chart_type == "pie":
            value_counts = analysis_data.value_counts().head(10)
            plt.pie(value_counts.values, labels=value_counts.index, autopct='%1.1f%%', startangle=90)

        else:
            plt.close()
            return f"Error: Unsupported chart type '{chart_type}'. Use 'bar', 'line', 'histogram', or 'pie'."

        plt.title(f'{chart_type.title()} Chart: {analysis_column}\nFiltered by {filter_column} = "{filter_value}"')
        # Only bar and histogram use the generic x-axis label; line sets its
        # own and a pie chart has no x-axis.
        if chart_type in ("bar", "histogram"):
            plt.xlabel(analysis_column)
        plt.tight_layout()

        # Strip characters that are unsafe in filenames.
        filename = f'filtered_{filter_column}_{filter_value}_{analysis_column}_{chart_type}.png'
        filename = "".join(c for c in filename if c.isalnum() or c in ('_', '-', '.'))
        os.makedirs('generated_data', exist_ok=True)
        filepath = os.path.join('generated_data', filename)
        plt.savefig(filepath, dpi=300, bbox_inches='tight')
        plt.close()

        result_lines.append(f"\nVisualization saved to: {filepath}")

        return "\n".join(result_lines)

    except Exception as e:
        plt.close()
        return f"Error in filtered analysis: {str(e)}"
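

# Minimal local smoke test (a sketch, not part of the agent's tool set).
# It assumes a CSV already exists under generated_data/ and that the
# smolagents @tool wrapper remains directly callable.
if __name__ == "__main__":
    df = load_previous_dataset()
    print(get_dataset_summary(df))
    # Summarize the first column as a quick end-to-end check.
    print(create_statistical_summary(df, str(df.columns[0])))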