Update indexes/query_engine.py
Browse files- indexes/query_engine.py +190 -0
indexes/query_engine.py
CHANGED
@@ -10,6 +10,82 @@ class CSVQueryEngine:
|
|
10 |
self.index_manager = index_manager
|
11 |
self.llm = llm
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def query(self, query_text: str) -> Dict[str, Any]:
|
14 |
"""Process a natural language query across CSV files."""
|
15 |
# Find relevant CSV files
|
@@ -55,10 +131,124 @@ class CSVQueryEngine:
|
|
55 |
|
56 |
return sources
|
57 |
|
|
|
58 |
def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
|
59 |
"""Handle direct statistical queries without using the LLM."""
|
60 |
query_lower = query.lower()
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
# Detect query type
|
63 |
is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
|
64 |
is_max_query = "maximum" in query_lower or "max" in query_lower
|
|
|
10 |
self.index_manager = index_manager
|
11 |
self.llm = llm
|
12 |
|
13 |
+
def _prepare_context(self, query: str, csv_ids: List[str]) -> str:
|
14 |
+
"""Prepare context from relevant CSV files."""
|
15 |
+
context_parts = []
|
16 |
+
|
17 |
+
for csv_id in csv_ids:
|
18 |
+
# Get metadata
|
19 |
+
if csv_id not in self.index_manager.indexes:
|
20 |
+
continue
|
21 |
+
|
22 |
+
metadata = self.index_manager.indexes[csv_id]["metadata"]
|
23 |
+
file_path = self.index_manager.indexes[csv_id]["path"]
|
24 |
+
|
25 |
+
# Add CSV metadata
|
26 |
+
context_parts.append(f"CSV File: {metadata['filename']}")
|
27 |
+
context_parts.append(f"Columns: {', '.join(metadata['columns'])}")
|
28 |
+
context_parts.append(f"Row Count: {metadata['row_count']}")
|
29 |
+
|
30 |
+
# Add sample data
|
31 |
+
try:
|
32 |
+
df = pd.read_csv(file_path)
|
33 |
+
context_parts.append("\nSample Data:")
|
34 |
+
context_parts.append(df.head(5).to_string())
|
35 |
+
|
36 |
+
# Add some basic statistics for numeric columns
|
37 |
+
context_parts.append("\nNumeric Column Statistics:")
|
38 |
+
numeric_cols = df.select_dtypes(include=['number']).columns
|
39 |
+
for col in numeric_cols:
|
40 |
+
stats = df[col].describe()
|
41 |
+
context_parts.append(f"{col} - mean: {stats['mean']:.2f}, min: {stats['min']:.2f}, max: {stats['max']:.2f}")
|
42 |
+
|
43 |
+
# Add categorical column information
|
44 |
+
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
|
45 |
+
if len(categorical_cols) > 0:
|
46 |
+
context_parts.append("\nCategorical Column Information:")
|
47 |
+
for col in categorical_cols:
|
48 |
+
value_counts = df[col].value_counts().head(5)
|
49 |
+
context_parts.append(f"{col} - unique values: {df[col].nunique()}, top values: {', '.join(value_counts.index.astype(str))}")
|
50 |
+
|
51 |
+
# Add date information if present
|
52 |
+
date_cols = []
|
53 |
+
for col in df.columns:
|
54 |
+
try:
|
55 |
+
if pd.api.types.is_datetime64_dtype(df[col]) or pd.to_datetime(df[col], errors='coerce').notna().all():
|
56 |
+
date_cols.append(col)
|
57 |
+
except:
|
58 |
+
pass
|
59 |
+
|
60 |
+
if date_cols:
|
61 |
+
context_parts.append("\nDate Column Information:")
|
62 |
+
for col in date_cols:
|
63 |
+
if not pd.api.types.is_datetime64_dtype(df[col]):
|
64 |
+
df[col] = pd.to_datetime(df[col], errors='coerce')
|
65 |
+
context_parts.append(f"{col} - range: {df[col].min()} to {df[col].max()}")
|
66 |
+
|
67 |
+
except Exception as e:
|
68 |
+
context_parts.append(f"Error reading CSV: {str(e)}")
|
69 |
+
|
70 |
+
return "\n\n".join(context_parts)
|
71 |
+
|
72 |
+
def _generate_prompt(self, query: str, context: str) -> str:
|
73 |
+
"""Generate a prompt for the LLM."""
|
74 |
+
return f"""You are an AI assistant specialized in analyzing CSV data.
|
75 |
+
Your goal is to help users understand their data and extract insights.
|
76 |
+
|
77 |
+
Below is information about CSV files that might help answer the query:
|
78 |
+
|
79 |
+
{context}
|
80 |
+
|
81 |
+
User Query: {query}
|
82 |
+
|
83 |
+
Please provide a comprehensive and accurate answer based on the data.
|
84 |
+
If calculations are needed, explain your process.
|
85 |
+
If the data doesn't contain information to answer the query, say so clearly.
|
86 |
+
|
87 |
+
Answer:"""
|
88 |
+
|
89 |
def query(self, query_text: str) -> Dict[str, Any]:
|
90 |
"""Process a natural language query across CSV files."""
|
91 |
# Find relevant CSV files
|
|
|
131 |
|
132 |
return sources
|
133 |
|
134 |
+
|
135 |
def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
|
136 |
"""Handle direct statistical queries without using the LLM."""
|
137 |
query_lower = query.lower()
|
138 |
|
139 |
+
# Detect query type
|
140 |
+
is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
|
141 |
+
is_max_query = "maximum" in query_lower or "max" in query_lower
|
142 |
+
is_min_query = "minimum" in query_lower or "min" in query_lower
|
143 |
+
is_count_query = "count" in query_lower or "how many" in query_lower
|
144 |
+
is_unique_query = "unique" in query_lower or "distinct" in query_lower
|
145 |
+
|
146 |
+
if not (is_avg_query or is_max_query or is_min_query or is_count_query or is_unique_query):
|
147 |
+
return None # Not a statistical query
|
148 |
+
|
149 |
+
# Extract potential column names from query
|
150 |
+
query_words = set(query_lower.replace("?", "").replace(",", "").split())
|
151 |
+
|
152 |
+
for csv_id in csv_ids:
|
153 |
+
if csv_id not in self.index_manager.indexes:
|
154 |
+
continue
|
155 |
+
|
156 |
+
file_path = self.index_manager.indexes[csv_id]["path"]
|
157 |
+
metadata = self.index_manager.indexes[csv_id]["metadata"]
|
158 |
+
|
159 |
+
try:
|
160 |
+
df = pd.read_csv(file_path)
|
161 |
+
|
162 |
+
# Find relevant columns based on query
|
163 |
+
target_columns = []
|
164 |
+
for col in df.columns:
|
165 |
+
col_lower = col.lower()
|
166 |
+
# Check if column name appears in query
|
167 |
+
if any(word in col_lower for word in query_words) or col_lower in query_lower:
|
168 |
+
target_columns.append(col)
|
169 |
+
|
170 |
+
# If no direct matches but query mentions specific types of data
|
171 |
+
if not target_columns:
|
172 |
+
if any(word in query_lower for word in ["age", "old", "young"]):
|
173 |
+
age_cols = [col for col in df.columns if "age" in col.lower()]
|
174 |
+
if age_cols:
|
175 |
+
target_columns = age_cols
|
176 |
+
elif any(word in query_lower for word in ["class", "category", "type", "grade"]):
|
177 |
+
class_cols = [col for col in df.columns if any(term in col.lower()
|
178 |
+
for term in ["class", "category", "type", "grade"])]
|
179 |
+
if class_cols:
|
180 |
+
target_columns = class_cols
|
181 |
+
elif any(word in query_lower for word in ["income", "salary", "money", "price", "cost"]):
|
182 |
+
income_cols = [col for col in df.columns if any(term in col.lower()
|
183 |
+
for term in ["income", "salary", "wage", "earnings", "price", "cost"])]
|
184 |
+
if income_cols:
|
185 |
+
target_columns = income_cols
|
186 |
+
elif any(word in query_lower for word in ["date", "time", "year", "month", "day"]):
|
187 |
+
date_cols = []
|
188 |
+
for col in df.columns:
|
189 |
+
try:
|
190 |
+
if pd.api.types.is_datetime64_dtype(df[col]) or pd.to_datetime(df[col], errors='coerce').notna().all():
|
191 |
+
date_cols.append(col)
|
192 |
+
except:
|
193 |
+
pass
|
194 |
+
if date_cols:
|
195 |
+
target_columns = date_cols
|
196 |
+
|
197 |
+
# If still no matches, use all columns for count/unique queries,
|
198 |
+
# or numeric columns for other statistical queries
|
199 |
+
if not target_columns:
|
200 |
+
if is_count_query or is_unique_query:
|
201 |
+
target_columns = df.columns.tolist()
|
202 |
+
else:
|
203 |
+
target_columns = df.select_dtypes(include=['number']).columns.tolist()
|
204 |
+
|
205 |
+
# Perform the requested calculation
|
206 |
+
results = []
|
207 |
+
for col in target_columns:
|
208 |
+
if is_avg_query:
|
209 |
+
if pd.api.types.is_numeric_dtype(df[col]):
|
210 |
+
value = df[col].mean()
|
211 |
+
results.append(f"The average {col} is {value:.2f}")
|
212 |
+
elif is_max_query:
|
213 |
+
if pd.api.types.is_numeric_dtype(df[col]):
|
214 |
+
value = df[col].max()
|
215 |
+
results.append(f"The maximum {col} is {value}")
|
216 |
+
else:
|
217 |
+
# For non-numeric columns, show the maximum in alphabetical order
|
218 |
+
value = df[col].max()
|
219 |
+
results.append(f"The maximum (alphabetically) {col} is '{value}'")
|
220 |
+
elif is_min_query:
|
221 |
+
if pd.api.types.is_numeric_dtype(df[col]):
|
222 |
+
value = df[col].min()
|
223 |
+
results.append(f"The minimum {col} is {value}")
|
224 |
+
else:
|
225 |
+
# For non-numeric columns, show the minimum in alphabetical order
|
226 |
+
value = df[col].min()
|
227 |
+
results.append(f"The minimum (alphabetically) {col} is '{value}'")
|
228 |
+
elif is_count_query:
|
229 |
+
value = len(df)
|
230 |
+
results.append(f"The total count of rows is {value}")
|
231 |
+
elif is_unique_query:
|
232 |
+
value = df[col].nunique()
|
233 |
+
unique_values = df[col].unique()
|
234 |
+
unique_str = ", ".join(str(x) for x in unique_values[:5])
|
235 |
+
if len(unique_values) > 5:
|
236 |
+
unique_str += f", ... and {len(unique_values) - 5} more"
|
237 |
+
results.append(f"There are {value} unique values in {col}: {unique_str}")
|
238 |
+
|
239 |
+
if results:
|
240 |
+
return "\n".join(results)
|
241 |
+
|
242 |
+
except Exception as e:
|
243 |
+
print(f"Error processing CSV for statistical query: {e}")
|
244 |
+
|
245 |
+
return None # No results found
|
246 |
+
|
247 |
+
|
248 |
+
def _handle_statistical_query1(self, query: str, csv_ids: List[str]) -> Optional[str]:
|
249 |
+
"""Handle direct statistical queries without using the LLM."""
|
250 |
+
query_lower = query.lower()
|
251 |
+
|
252 |
# Detect query type
|
253 |
is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
|
254 |
is_max_query = "maximum" in query_lower or "max" in query_lower
|