File size: 6,098 Bytes
e13d87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from typing import Dict, List, Any, Optional, Callable
import pandas as pd
import numpy as np
from llama_index.tools import FunctionTool
from pathlib import Path

class PandasDataTools:
    """Tools for data analysis operations on CSV files.

    Every public method below is exposed to LlamaIndex agents as a
    ``FunctionTool`` (see ``_create_tools``), so its arguments come from model
    output. Unknown columns, conditions, aggregation names and non-numeric
    comparison values are reported as ``{"error": ...}`` dictionaries instead
    of exceptions, so the agent can read the message and retry. A missing or
    invalid CSV filename still raises ``ValueError``.
    """

    # filter_data condition -> name of the pandas Series comparison method.
    _COMPARISON_METHODS = {
        "==": "eq",
        "!=": "ne",
        ">": "gt",
        "<": "lt",
        ">=": "ge",
        "<=": "le",
    }

    # group_and_aggregate name -> pandas GroupBy aggregation name. Named
    # aggregations are used instead of NumPy callables: pandas only maps
    # np.mean & co. onto its own NaN-skipping methods with a FutureWarning and
    # plans to stop doing so; the names pin today's results. "count" maps to
    # "size" so it keeps counting every row of a group (as ``len`` did),
    # including rows whose value is missing.
    _AGGREGATIONS = {
        "mean": "mean",
        "sum": "sum",
        "min": "min",
        "max": "max",
        "count": "size",
        "median": "median",
    }

    def __init__(self, csv_directory: str):
        """Initialize with the directory containing the CSV files.

        Args:
            csv_directory: Directory that every ``filename`` argument is
                resolved against.
        """
        self.csv_directory = csv_directory
        # Parsed DataFrames keyed by the filename string callers passed in.
        self.dataframes: Dict[str, pd.DataFrame] = {}
        self.tools = self._create_tools()

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file from ``csv_directory`` as a DataFrame, with caching.

        ``filename`` may omit the ``.csv`` extension. It must be a relative
        path that stays inside ``csv_directory``: tool arguments come from
        model output, so absolute paths and ``..`` components are rejected.

        Raises:
            ValueError: if ``filename`` escapes the directory or no matching
                regular file exists.
        """
        if filename in self.dataframes:
            return self.dataframes[filename]

        relative = Path(filename)
        # ``anchor`` is non-empty for absolute and drive-qualified paths.
        if relative.anchor or ".." in relative.parts:
            raise ValueError(f"Invalid CSV filename: {filename}")

        base = Path(self.csv_directory)
        file_path = base / relative
        # Only append the extension when the name as given does not exist and
        # does not already end in ".csv".
        if not file_path.exists() and not filename.endswith('.csv'):
            file_path = base / f"{filename}.csv"

        # is_file() rather than exists(): a directory would crash read_csv.
        if not file_path.is_file():
            raise ValueError(f"CSV file not found: {filename}")

        self.dataframes[filename] = pd.read_csv(file_path)
        return self.dataframes[filename]

    @staticmethod
    def _missing_column_error(df: pd.DataFrame, *columns: str) -> Optional[Dict[str, Any]]:
        """Return an error payload naming any of *columns* absent from *df*, else None."""
        missing = [col for col in columns if col not in df.columns]
        if not missing:
            return None
        return {
            "error": f"Column(s) not found: {', '.join(map(str, missing))}",
            "available_columns": df.columns.tolist(),
        }

    @staticmethod
    def _coerce_for_equality(series: pd.Series, value: Any) -> Any:
        """Convert a numeric string to float when *series* has a numeric dtype.

        Tool arguments often arrive as strings, and a numeric column never
        compares equal to ``"7"``; non-numeric strings are returned unchanged.
        """
        if isinstance(value, str) and pd.api.types.is_numeric_dtype(series):
            try:
                return float(value)
            except ValueError:
                pass
        return value

    def _create_tools(self) -> "List[FunctionTool]":
        """Create LlamaIndex function tools for data operations.

        The explicit descriptions are what the agent sees for each tool, so
        they spell out the accepted condition and aggregation names.
        """
        tools = [
            FunctionTool.from_defaults(
                name="describe_csv",
                description=(
                    "Get statistical description of a CSV file: summary "
                    "statistics, shape, column names and dtypes. The filename "
                    "may omit the .csv extension."
                ),
                fn=self.describe_csv
            ),
            FunctionTool.from_defaults(
                name="filter_data",
                description=(
                    "Filter CSV data based on conditions. condition is one of "
                    "==, !=, >, <, >=, <= or contains (literal substring match). "
                    "Returns the number of matches and up to `limit` matching rows."
                ),
                fn=self.filter_data
            ),
            FunctionTool.from_defaults(
                name="group_and_aggregate",
                description=(
                    "Group data by one column and calculate an aggregate statistic "
                    "of another. agg_function is one of mean, sum, min, max, "
                    "count, median."
                ),
                fn=self.group_and_aggregate
            ),
            FunctionTool.from_defaults(
                name="sort_data",
                description=(
                    "Sort data by a column and return the first `limit` rows."
                ),
                fn=self.sort_data
            ),
            FunctionTool.from_defaults(
                name="calculate_correlation",
                description=(
                    "Calculate correlation between two numeric columns. method "
                    "is pearson (default), kendall or spearman."
                ),
                fn=self.calculate_correlation
            )
        ]
        return tools

    def get_tools(self) -> "List[FunctionTool]":
        """Get all available data tools."""
        return self.tools

    # Tool implementations
    def describe_csv(self, filename: str) -> Dict[str, Any]:
        """Get statistical description of CSV data.

        Returns:
            ``statistics`` (``DataFrame.describe()`` as a dict keyed by column,
            then statistic), ``shape`` as ``(rows, columns)``, ``columns`` in
            file order, and ``dtypes`` mapping each column to its dtype name.
        """
        df = self._load_dataframe(filename)
        return {
            "statistics": df.describe().to_dict(),
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        }

    def filter_data(self, filename: str, column: str, condition: str, value: Any,
                    limit: int = 10) -> Dict[str, Any]:
        """Filter data based on condition (==, >, <, >=, <=, !=, contains).

        Args:
            filename: CSV file inside ``csv_directory`` (``.csv`` optional).
            column: Column to test.
            condition: Comparison operator, or ``contains`` (case-insensitive
                keyword) for a literal substring match; missing values never
                match ``contains``.
            value: Value to compare against. Ordered comparisons convert it
                with ``float``; for ``==``/``!=`` on a numeric column a numeric
                string such as ``"7"`` is converted as well.
            limit: Maximum number of matching rows returned under ``results``.

        Returns:
            ``result_count`` (all matches), ``results`` (first ``limit``
            matching rows as records) and ``total_count`` (rows in the file),
            or an ``error`` entry for a bad column, condition or value.
        """
        df = self._load_dataframe(filename)
        error = self._missing_column_error(df, column)
        if error:
            return error

        series = df[column]
        op = condition.strip().lower()

        if op == "contains":
            # regex=False: values like "C++" or "a.b" must not be parsed as
            # patterns. notna(): astype(str) would turn NaN into the text "nan".
            mask = series.notna() & series.astype(str).str.contains(
                str(value), regex=False, na=False
            )
        elif op in self._COMPARISON_METHODS:
            if op in ("==", "!="):
                operand = self._coerce_for_equality(series, value)
            else:
                try:
                    operand = float(value)
                except (TypeError, ValueError):
                    return {"error": f"Value for '{condition}' must be numeric: {value!r}"}
            try:
                mask = getattr(series, self._COMPARISON_METHODS[op])(operand)
            except TypeError as exc:
                return {"error": f"Cannot compare column '{column}' using '{condition}': {exc}"}
        else:
            return {"error": f"Unsupported condition: {condition}"}

        filtered = df[mask]
        return {
            "result_count": len(filtered),
            "results": filtered.head(max(limit, 0)).to_dict(orient="records"),
            "total_count": len(df)
        }

    def group_and_aggregate(self, filename: str, group_by: str, agg_column: str,
                            agg_function: str = "mean") -> Dict[str, Any]:
        """Group by column and calculate aggregate statistic.

        Args:
            filename: CSV file inside ``csv_directory`` (``.csv`` optional).
            group_by: Column whose values define the groups (missing keys are
                dropped by pandas' default ``groupby``).
            agg_column: Column to aggregate within each group.
            agg_function: ``mean``, ``sum``, ``min``, ``max``, ``count`` (rows
                per group, missing values included) or ``median``;
                case-insensitive.

        Returns:
            The request echoed back plus ``results`` mapping each group key to
            its aggregate, or an ``error`` entry.
        """
        df = self._load_dataframe(filename)

        func_name = agg_function.strip().lower()
        if func_name not in self._AGGREGATIONS:
            return {"error": f"Unsupported aggregation function: {agg_function}"}

        error = self._missing_column_error(df, group_by, agg_column)
        if error:
            return error

        try:
            grouped = df.groupby(group_by)[agg_column].agg(self._AGGREGATIONS[func_name])
        except TypeError as exc:
            # e.g. "mean" of a text column.
            return {"error": f"Cannot compute '{agg_function}' of column '{agg_column}': {exc}"}

        return {
            "group_by": group_by,
            "aggregated_column": agg_column,
            "aggregation": agg_function,
            "results": grouped.to_dict()
        }

    def sort_data(self, filename: str, sort_by: str, ascending: bool = True,
                  limit: int = 10) -> Dict[str, Any]:
        """Sort data by column and return the first ``limit`` rows as records.

        Returns an ``error`` entry instead when ``sort_by`` is not a column.
        """
        df = self._load_dataframe(filename)
        error = self._missing_column_error(df, sort_by)
        if error:
            return error

        sorted_df = df.sort_values(by=sort_by, ascending=ascending)

        return {
            "sorted_by": sort_by,
            "ascending": ascending,
            "results": sorted_df.head(max(limit, 0)).to_dict(orient="records")
        }

    def calculate_correlation(self, filename: str, column1: str, column2: str,
                              method: str = "pearson") -> Dict[str, Any]:
        """Calculate correlation between two columns.

        Args:
            filename: CSV file inside ``csv_directory`` (``.csv`` optional).
            column1: First column.
            column2: Second column.
            method: Correlation method forwarded to ``Series.corr``
                (``pearson``, ``kendall`` or ``spearman``).

        Returns:
            ``correlation`` with the column names and method, or an ``error``
            entry if pandas could not compute it.
        """
        df = self._load_dataframe(filename)

        try:
            correlation = df[column1].corr(df[column2], method=method)
        except Exception as e:
            # Tool boundary: unknown columns, non-numeric data or an invalid
            # method are all reported back to the agent rather than raised.
            return {"error": f"Could not calculate correlation: {str(e)}"}

        return {
            "correlation": correlation,
            "column1": column1,
            "column2": column2,
            "method": method
        }