Chamin09 committed on
Commit
df5683f
·
verified ·
1 Parent(s): fb82bd1

Update indexes/query_engine.py

Browse files
Files changed (1) hide show
  1. indexes/query_engine.py +190 -0
indexes/query_engine.py CHANGED
@@ -10,6 +10,82 @@ class CSVQueryEngine:
10
  self.index_manager = index_manager
11
  self.llm = llm
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def query(self, query_text: str) -> Dict[str, Any]:
14
  """Process a natural language query across CSV files."""
15
  # Find relevant CSV files
@@ -55,10 +131,124 @@ class CSVQueryEngine:
55
 
56
  return sources
57
 
 
58
  def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
59
  """Handle direct statistical queries without using the LLM."""
60
  query_lower = query.lower()
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # Detect query type
63
  is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
64
  is_max_query = "maximum" in query_lower or "max" in query_lower
 
10
  self.index_manager = index_manager
11
  self.llm = llm
12
 
13
+ def _prepare_context(self, query: str, csv_ids: List[str]) -> str:
14
+ """Prepare context from relevant CSV files."""
15
+ context_parts = []
16
+
17
+ for csv_id in csv_ids:
18
+ # Get metadata
19
+ if csv_id not in self.index_manager.indexes:
20
+ continue
21
+
22
+ metadata = self.index_manager.indexes[csv_id]["metadata"]
23
+ file_path = self.index_manager.indexes[csv_id]["path"]
24
+
25
+ # Add CSV metadata
26
+ context_parts.append(f"CSV File: {metadata['filename']}")
27
+ context_parts.append(f"Columns: {', '.join(metadata['columns'])}")
28
+ context_parts.append(f"Row Count: {metadata['row_count']}")
29
+
30
+ # Add sample data
31
+ try:
32
+ df = pd.read_csv(file_path)
33
+ context_parts.append("\nSample Data:")
34
+ context_parts.append(df.head(5).to_string())
35
+
36
+ # Add some basic statistics for numeric columns
37
+ context_parts.append("\nNumeric Column Statistics:")
38
+ numeric_cols = df.select_dtypes(include=['number']).columns
39
+ for col in numeric_cols:
40
+ stats = df[col].describe()
41
+ context_parts.append(f"{col} - mean: {stats['mean']:.2f}, min: {stats['min']:.2f}, max: {stats['max']:.2f}")
42
+
43
+ # Add categorical column information
44
+ categorical_cols = df.select_dtypes(include=['object', 'category']).columns
45
+ if len(categorical_cols) > 0:
46
+ context_parts.append("\nCategorical Column Information:")
47
+ for col in categorical_cols:
48
+ value_counts = df[col].value_counts().head(5)
49
+ context_parts.append(f"{col} - unique values: {df[col].nunique()}, top values: {', '.join(value_counts.index.astype(str))}")
50
+
51
+ # Add date information if present
52
+ date_cols = []
53
+ for col in df.columns:
54
+ try:
55
+ if pd.api.types.is_datetime64_dtype(df[col]) or pd.to_datetime(df[col], errors='coerce').notna().all():
56
+ date_cols.append(col)
57
+ except:
58
+ pass
59
+
60
+ if date_cols:
61
+ context_parts.append("\nDate Column Information:")
62
+ for col in date_cols:
63
+ if not pd.api.types.is_datetime64_dtype(df[col]):
64
+ df[col] = pd.to_datetime(df[col], errors='coerce')
65
+ context_parts.append(f"{col} - range: {df[col].min()} to {df[col].max()}")
66
+
67
+ except Exception as e:
68
+ context_parts.append(f"Error reading CSV: {str(e)}")
69
+
70
+ return "\n\n".join(context_parts)
71
+
72
+ def _generate_prompt(self, query: str, context: str) -> str:
73
+ """Generate a prompt for the LLM."""
74
+ return f"""You are an AI assistant specialized in analyzing CSV data.
75
+ Your goal is to help users understand their data and extract insights.
76
+
77
+ Below is information about CSV files that might help answer the query:
78
+
79
+ {context}
80
+
81
+ User Query: {query}
82
+
83
+ Please provide a comprehensive and accurate answer based on the data.
84
+ If calculations are needed, explain your process.
85
+ If the data doesn't contain information to answer the query, say so clearly.
86
+
87
+ Answer:"""
88
+
89
  def query(self, query_text: str) -> Dict[str, Any]:
90
  """Process a natural language query across CSV files."""
91
  # Find relevant CSV files
 
131
 
132
  return sources
133
 
134
+
135
  def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
136
  """Handle direct statistical queries without using the LLM."""
137
  query_lower = query.lower()
138
 
139
+ # Detect query type
140
+ is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
141
+ is_max_query = "maximum" in query_lower or "max" in query_lower
142
+ is_min_query = "minimum" in query_lower or "min" in query_lower
143
+ is_count_query = "count" in query_lower or "how many" in query_lower
144
+ is_unique_query = "unique" in query_lower or "distinct" in query_lower
145
+
146
+ if not (is_avg_query or is_max_query or is_min_query or is_count_query or is_unique_query):
147
+ return None # Not a statistical query
148
+
149
+ # Extract potential column names from query
150
+ query_words = set(query_lower.replace("?", "").replace(",", "").split())
151
+
152
+ for csv_id in csv_ids:
153
+ if csv_id not in self.index_manager.indexes:
154
+ continue
155
+
156
+ file_path = self.index_manager.indexes[csv_id]["path"]
157
+ metadata = self.index_manager.indexes[csv_id]["metadata"]
158
+
159
+ try:
160
+ df = pd.read_csv(file_path)
161
+
162
+ # Find relevant columns based on query
163
+ target_columns = []
164
+ for col in df.columns:
165
+ col_lower = col.lower()
166
+ # Check if column name appears in query
167
+ if any(word in col_lower for word in query_words) or col_lower in query_lower:
168
+ target_columns.append(col)
169
+
170
+ # If no direct matches but query mentions specific types of data
171
+ if not target_columns:
172
+ if any(word in query_lower for word in ["age", "old", "young"]):
173
+ age_cols = [col for col in df.columns if "age" in col.lower()]
174
+ if age_cols:
175
+ target_columns = age_cols
176
+ elif any(word in query_lower for word in ["class", "category", "type", "grade"]):
177
+ class_cols = [col for col in df.columns if any(term in col.lower()
178
+ for term in ["class", "category", "type", "grade"])]
179
+ if class_cols:
180
+ target_columns = class_cols
181
+ elif any(word in query_lower for word in ["income", "salary", "money", "price", "cost"]):
182
+ income_cols = [col for col in df.columns if any(term in col.lower()
183
+ for term in ["income", "salary", "wage", "earnings", "price", "cost"])]
184
+ if income_cols:
185
+ target_columns = income_cols
186
+ elif any(word in query_lower for word in ["date", "time", "year", "month", "day"]):
187
+ date_cols = []
188
+ for col in df.columns:
189
+ try:
190
+ if pd.api.types.is_datetime64_dtype(df[col]) or pd.to_datetime(df[col], errors='coerce').notna().all():
191
+ date_cols.append(col)
192
+ except:
193
+ pass
194
+ if date_cols:
195
+ target_columns = date_cols
196
+
197
+ # If still no matches, use all columns for count/unique queries,
198
+ # or numeric columns for other statistical queries
199
+ if not target_columns:
200
+ if is_count_query or is_unique_query:
201
+ target_columns = df.columns.tolist()
202
+ else:
203
+ target_columns = df.select_dtypes(include=['number']).columns.tolist()
204
+
205
+ # Perform the requested calculation
206
+ results = []
207
+ for col in target_columns:
208
+ if is_avg_query:
209
+ if pd.api.types.is_numeric_dtype(df[col]):
210
+ value = df[col].mean()
211
+ results.append(f"The average {col} is {value:.2f}")
212
+ elif is_max_query:
213
+ if pd.api.types.is_numeric_dtype(df[col]):
214
+ value = df[col].max()
215
+ results.append(f"The maximum {col} is {value}")
216
+ else:
217
+ # For non-numeric columns, show the maximum in alphabetical order
218
+ value = df[col].max()
219
+ results.append(f"The maximum (alphabetically) {col} is '{value}'")
220
+ elif is_min_query:
221
+ if pd.api.types.is_numeric_dtype(df[col]):
222
+ value = df[col].min()
223
+ results.append(f"The minimum {col} is {value}")
224
+ else:
225
+ # For non-numeric columns, show the minimum in alphabetical order
226
+ value = df[col].min()
227
+ results.append(f"The minimum (alphabetically) {col} is '{value}'")
228
+ elif is_count_query:
229
+ value = len(df)
230
+ results.append(f"The total count of rows is {value}")
231
+ elif is_unique_query:
232
+ value = df[col].nunique()
233
+ unique_values = df[col].unique()
234
+ unique_str = ", ".join(str(x) for x in unique_values[:5])
235
+ if len(unique_values) > 5:
236
+ unique_str += f", ... and {len(unique_values) - 5} more"
237
+ results.append(f"There are {value} unique values in {col}: {unique_str}")
238
+
239
+ if results:
240
+ return "\n".join(results)
241
+
242
+ except Exception as e:
243
+ print(f"Error processing CSV for statistical query: {e}")
244
+
245
+ return None # No results found
246
+
247
+
248
+ def _handle_statistical_query1(self, query: str, csv_ids: List[str]) -> Optional[str]:
249
+ """Handle direct statistical queries without using the LLM."""
250
+ query_lower = query.lower()
251
+
252
  # Detect query type
253
  is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
254
  is_max_query = "maximum" in query_lower or "max" in query_lower