Chamin09 commited on
Commit
a11e9b5
·
verified ·
1 Parent(s): 9ffd2db

Update indexes/query_engine.py

Browse files
Files changed (1) hide show
  1. indexes/query_engine.py +74 -0
indexes/query_engine.py CHANGED
@@ -38,6 +38,80 @@ class CSVQueryEngine:
38
  }
39
 
40
  def _prepare_context(self, query: str, csv_ids: List[str]) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  """Prepare context from relevant CSV files."""
42
  context_parts = []
43
 
 
38
  }
39
 
40
  def _prepare_context(self, query: str, csv_ids: List[str]) -> str:
41
+ """Prepare context from relevant CSV files with pre-calculated statistics."""
42
+ context_parts = []
43
+ calculated_answers = {}
44
+
45
+ # Check for common statistical questions
46
+ query_lower = query.lower()
47
+ is_avg_question = "average" in query_lower or "mean" in query_lower
48
+ is_max_question = "maximum" in query_lower or "max" in query_lower
49
+ is_min_question = "minimum" in query_lower or "min" in query_lower
50
+
51
+ # Extract potential column names from query
52
+ query_words = set(query_lower.replace("?", "").replace(",", "").split())
53
+
54
+ for csv_id in csv_ids:
55
+ # Get metadata
56
+ if csv_id not in self.index_manager.indexes:
57
+ continue
58
+
59
+ metadata = self.index_manager.indexes[csv_id]["metadata"]
60
+ file_path = self.index_manager.indexes[csv_id]["path"]
61
+
62
+ # Add CSV metadata
63
+ context_parts.append(f"CSV File: {metadata['filename']}")
64
+ context_parts.append(f"Columns: {', '.join(metadata['columns'])}")
65
+ context_parts.append(f"Row Count: {metadata['row_count']}")
66
+
67
+ # Add sample data and calculate statistics
68
+ try:
69
+ df = pd.read_csv(file_path)
70
+ context_parts.append("\nSample Data:")
71
+ context_parts.append(df.head(3).to_string())
72
+
73
+ # Find relevant columns based on query
74
+ column_matches = []
75
+ for col in df.columns:
76
+ col_lower = col.lower()
77
+ # Check if column name appears in query or is similar to words in query
78
+ if col_lower in query_lower or any(word in col_lower for word in query_words):
79
+ column_matches.append(col)
80
+
81
+ # If no direct matches, include all numeric columns
82
+ if not column_matches:
83
+ column_matches = df.select_dtypes(include=['number']).columns.tolist()
84
+
85
+ # Calculate statistics for matched columns
86
+ for col in column_matches:
87
+ if pd.api.types.is_numeric_dtype(df[col]):
88
+ if is_avg_question:
89
+ avg_value = df[col].mean()
90
+ context_parts.append(f"\nThe average {col} is: {avg_value:.2f}")
91
+ calculated_answers[f"average_{col}"] = avg_value
92
+
93
+ if is_max_question:
94
+ max_value = df[col].max()
95
+ context_parts.append(f"\nThe maximum {col} is: {max_value}")
96
+ calculated_answers[f"max_{col}"] = max_value
97
+
98
+ if is_min_question:
99
+ min_value = df[col].min()
100
+ context_parts.append(f"\nThe minimum {col} is: {min_value}")
101
+ calculated_answers[f"min_{col}"] = min_value
102
+
103
+ except Exception as e:
104
+ context_parts.append(f"Error reading CSV: {str(e)}")
105
+
106
+ # Add direct answer if calculated
107
+ if calculated_answers:
108
+ context_parts.append("\nDirect Answer:")
109
+ for key, value in calculated_answers.items():
110
+ context_parts.append(f"{key.replace('_', ' ')}: {value}")
111
+
112
+ return "\n\n".join(context_parts)
113
+
114
+ def _prepare_context1(self, query: str, csv_ids: List[str]) -> str:
115
  """Prepare context from relevant CSV files."""
116
  context_parts = []
117