Upload 12 files
- .gitattributes +0 -35
- app.py +217 -0
- indexes/csv_index_builder.py +51 -0
- indexes/index_manager.py +104 -0
- indexes/query_engine.py +124 -0
- models/llm_setup.py +64 -0
- requirements.txt +14 -0
- tools/data_tools.py +158 -0
- tools/export.py +161 -0
- tools/visualization.py +246 -0
- utils/csv_helper.py +189 -0
- utils/prompt_template.py +0 -0
.gitattributes
CHANGED
@@ -1,35 +0,0 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,217 @@
import os
import gradio as gr
import tempfile
from pathlib import Path
import base64
from PIL import Image
import io
import time

# Import our components
from models.llm_setup import setup_llm
from indexes.csv_index_builder import EnhancedCSVReader
from indexes.index_manager import CSVIndexManager
from indexes.query_engine import CSVQueryEngine
from tools.data_tools import PandasDataTools
from tools.visualization import VisualizationTools
from tools.export import ExportTools

# Setup temporary directory for uploaded files
UPLOAD_DIR = Path(tempfile.mkdtemp())
EXPORT_DIR = Path(tempfile.mkdtemp())

class CSVChatApp:
    """Main application class for CSV chatbot."""

    def __init__(self):
        """Initialize the application components."""
        # Initialize the language model
        self.llm = setup_llm()

        # Initialize the index manager
        self.index_manager = CSVIndexManager()

        # Initialize tools
        self.data_tools = PandasDataTools(str(UPLOAD_DIR))
        self.viz_tools = VisualizationTools(str(UPLOAD_DIR))
        self.export_tools = ExportTools(str(EXPORT_DIR))

        # Initialize query engine with tools
        self.query_engine = self._setup_query_engine()

        # Track conversation history
        self.chat_history = []
        self.uploaded_files = []

    def _setup_query_engine(self):
        """Set up the query engine with tools."""
        # Get all tools
        tools = (
            self.data_tools.get_tools() +
            self.viz_tools.get_tools() +
            self.export_tools.get_tools()
        )

        # Create query engine with tools
        query_engine = CSVQueryEngine(self.index_manager, self.llm)

        return query_engine

    def handle_file_upload(self, files):
        """Process uploaded CSV files."""
        file_info = []

        for file in files:
            if file is None:
                continue

            # Get file path
            file_path = Path(file.name)

            # Only process CSV files
            if not file_path.suffix.lower() == '.csv':
                continue

            # Copy to upload directory
            dest_path = UPLOAD_DIR / file_path.name
            with open(dest_path, 'wb') as f:
                f.write(file_path.read_bytes())

            # Create index for this file
            try:
                self.index_manager.create_index(str(dest_path))
                file_info.append(f"✅ Indexed: {file_path.name}")
                self.uploaded_files.append(str(dest_path))
            except Exception as e:
                file_info.append(f"❌ Failed to index {file_path.name}: {str(e)}")

        # Return information about processed files
        if file_info:
            return "\n".join(file_info)
        else:
            return "No CSV files were uploaded."

    def process_query(self, query, history):
        """Process a user query and generate a response."""
        if not self.uploaded_files:
            return "Please upload CSV files before asking questions."

        # Add user message to history
        self.chat_history.append({"role": "user", "content": query})

        # Process the query
        try:
            response = self.query_engine.query(query)
            answer = response["answer"]

            # Check if response contains an image
            if isinstance(answer, dict) and "image" in answer:
                # Handle image in response
                img_data = answer["image"]
                img = Image.open(io.BytesIO(base64.b64decode(img_data)))
                img_path = EXPORT_DIR / f"viz_{int(time.time())}.png"
                img.save(img_path)

                # Update answer to include image path
                text_response = answer.get("text", "Generated visualization")
                answer = (text_response, str(img_path))

            # Add assistant message to history
            self.chat_history.append({"role": "assistant", "content": answer})

            return answer

        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            self.chat_history.append({"role": "assistant", "content": error_msg})
            return error_msg

    def export_conversation(self):
        """Export the conversation as a report."""
        if not self.chat_history:
            return "No conversation to export."

        # Extract content for report
        title = "CSV Chat Conversation Report"
        content = ""
        images = []

        for msg in self.chat_history:
            role = msg["role"]
            content_text = msg["content"]

            # Handle content that might contain images
            if isinstance(content_text, tuple) and len(content_text) == 2:
                text, img_path = content_text
                content += f"\n\n{'User' if role == 'user' else 'Assistant'}: {text}"

                # Add image to report
                try:
                    with open(img_path, "rb") as img_file:
                        img_data = base64.b64encode(img_file.read()).decode('utf-8')
                        images.append(img_data)
                except Exception:
                    pass
            else:
                content += f"\n\n{'User' if role == 'user' else 'Assistant'}: {content_text}"

        # Generate report
        result = self.export_tools.generate_report(title, content, images)

        if result["success"]:
            return f"Report exported to: {result['report_path']}"
        else:
            return "Failed to export report."

# Create the Gradio interface
def create_interface():
    """Create the Gradio web interface."""
    app = CSVChatApp()

    with gr.Blocks(title="CSV Chat Assistant") as interface:
        gr.Markdown("# CSV Chat Assistant")
        gr.Markdown("Upload CSV files and ask questions in natural language.")

        with gr.Row():
            with gr.Column(scale=1):
                file_upload = gr.File(
                    label="Upload CSV Files",
                    file_count="multiple",
                    type="file"
                )
                upload_button = gr.Button("Process Files")
                file_status = gr.Textbox(label="File Status")

                export_button = gr.Button("Export Conversation")
                export_status = gr.Textbox(label="Export Status")

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Conversation")
                msg = gr.Textbox(label="Your Question")
                submit_button = gr.Button("Submit")

        # Set up event handlers
        upload_button.click(
            fn=app.handle_file_upload,
            inputs=[file_upload],
            outputs=[file_status]
        )

        submit_button.click(
            fn=app.process_query,
            inputs=[msg, chatbot],
            outputs=[chatbot]
        )

        export_button.click(
            fn=app.export_conversation,
            inputs=[],
            outputs=[export_status]
        )

    return interface

# Launch the app
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
indexes/csv_index_builder.py
ADDED
@@ -0,0 +1,51 @@
from typing import Dict, List, Optional
from pathlib import Path
import pandas as pd
from llama_index.readers.file import CSVReader
from llama_index.schema import Document

class EnhancedCSVReader:
    """Enhanced CSV reader with metadata extraction capabilities."""

    def __init__(self):
        self.csv_reader = CSVReader()

    def load_data(self, file_path: str) -> List[Document]:
        """Load CSV file and extract documents with metadata."""
        # Load the CSV file
        documents = self.csv_reader.load_data(file_path)

        # Extract and add metadata
        csv_metadata = self._extract_metadata(file_path)

        # Enhance documents with metadata
        for doc in documents:
            doc.metadata.update(csv_metadata)

        return documents

    def _extract_metadata(self, file_path: str) -> Dict:
        """Extract useful metadata from CSV file."""
        df = pd.read_csv(file_path)
        filename = Path(file_path).name

        # Extract column information
        columns = df.columns.tolist()
        dtypes = {col: str(df[col].dtype) for col in columns}

        # Extract sample values (first 3 non-null values per column)
        samples = {}
        for col in columns:
            non_null_values = df[col].dropna().head(3).tolist()
            samples[col] = [str(val) for val in non_null_values]

        # Basic statistics
        row_count = len(df)

        return {
            "filename": filename,
            "columns": columns,
            "dtypes": dtypes,
            "samples": samples,
            "row_count": row_count
        }
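
Usage note (reviewer sketch, not part of this commit): a minimal example of exercising EnhancedCSVReader on its own, assuming the llama_index import paths above resolve in the installed version. The path data/sales.csv and the column names in the comments are hypothetical.

from indexes.csv_index_builder import EnhancedCSVReader

reader = EnhancedCSVReader()
docs = reader.load_data("data/sales.csv")   # hypothetical local CSV

# Every document now carries the CSV-level metadata added above.
first = docs[0]
print(first.metadata["filename"])    # sales.csv
print(first.metadata["columns"])     # e.g. ['date', 'region', 'revenue']
print(first.metadata["row_count"])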
indexes/index_manager.py
ADDED
@@ -0,0 +1,104 @@
from typing import Dict, List, Optional
from pathlib import Path
import os

from llama_index import VectorStoreIndex, StorageContext
from llama_index.vector_stores import ChromaVectorStore
from llama_index.embeddings import HuggingFaceEmbedding
import chromadb

from indexes.csv_index_builder import EnhancedCSVReader

class CSVIndexManager:
    """Manages creation and retrieval of indexes for CSV files."""

    def __init__(self, embedding_model_name: str = "all-MiniLM-L6-v2"):
        self.csv_reader = EnhancedCSVReader()
        self.embed_model = HuggingFaceEmbedding(model_name=embedding_model_name)
        self.chroma_client = chromadb.Client()
        self.indexes = {}

    def create_index(self, file_path: str) -> VectorStoreIndex:
        """Create vector index for a CSV file."""
        # Extract filename as identifier
        file_id = Path(file_path).stem

        # Load documents with metadata
        documents = self.csv_reader.load_data(file_path)

        # Create Chroma collection for this CSV
        chroma_collection = self.chroma_client.create_collection(file_id)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        # Create vector index with our embedding model
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            embed_model=self.embed_model
        )

        # Store in our registry
        self.indexes[file_id] = {
            "index": index,
            "metadata": documents[0].metadata if documents else {}
        }

        return index

    def index_directory(self, directory_path: str) -> Dict[str, VectorStoreIndex]:
        """Index all CSV files in a directory."""
        indexed_files = {}

        # Get all CSV files in directory
        csv_files = [f for f in os.listdir(directory_path)
                     if f.lower().endswith('.csv')]

        # Create index for each CSV file
        for csv_file in csv_files:
            file_path = os.path.join(directory_path, csv_file)
            file_id = Path(file_path).stem
            index = self.create_index(file_path)
            indexed_files[file_id] = index

        return indexed_files

    def find_relevant_csvs(self, query: str, top_k: int = 3) -> List[str]:
        """Find most relevant CSV files for a given query."""
        if not self.indexes:
            return []

        # Create a document from the query
        query_embedding = self.embed_model.get_text_embedding(query)

        # Calculate similarity with each CSV's metadata
        similarities = {}
        for file_id, index_info in self.indexes.items():
            # Get metadata description
            metadata = index_info["metadata"]

            # Create a rich description of the CSV
            csv_description = f"CSV file {metadata['filename']} with columns: {', '.join(metadata['columns'])}. "
            csv_description += f"Contains {metadata['row_count']} rows. "
            csv_description += "Sample data: "
            for col, samples in metadata['samples'].items():
                if samples:
                    csv_description += f"{col}: {', '.join(str(s) for s in samples[:2])}; "

            # Get embedding for this description
            csv_embedding = self.embed_model.get_text_embedding(csv_description)

            # Calculate cosine similarity
            similarity = self._cosine_similarity(query_embedding, csv_embedding)
            similarities[file_id] = similarity

        # Sort by similarity and return top_k
        sorted_files = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        return [file_id for file_id, _ in sorted_files[:top_k]]

    def _cosine_similarity(self, vec1, vec2):
        """Calculate cosine similarity between two vectors."""
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm_a = sum(a * a for a in vec1) ** 0.5
        norm_b = sum(b * b for b in vec2) ** 0.5
        return dot_product / (norm_a * norm_b) if norm_a * norm_b != 0 else 0
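
Usage note (reviewer sketch, not part of this commit): how CSVIndexManager might be driven directly, assuming CSV files exist at the hypothetical paths below and the all-MiniLM-L6-v2 embedding model can be downloaded.

from indexes.index_manager import CSVIndexManager

manager = CSVIndexManager()                    # loads the sentence-transformers embedding model
manager.create_index("uploads/sales.csv")      # one in-memory Chroma collection per CSV
manager.create_index("uploads/customers.csv")

# Rank the indexed files against a question using their metadata descriptions.
print(manager.find_relevant_csvs("Which region had the highest revenue?", top_k=1))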
indexes/query_engine.py
ADDED
@@ -0,0 +1,124 @@
from typing import Dict, List, Optional, Any
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import VectorIndexRetriever
from llama_index.response_synthesizers import ResponseMode
from llama_index.llms import HuggingFaceLLM
from llama_index import ServiceContext, QueryBundle
from llama_index.prompts import PromptTemplate

class CSVQueryEngine:
    """Query engine for CSV data with multi-file support."""

    def __init__(self, index_manager, llm, response_mode="compact"):
        """Initialize with index manager and language model."""
        self.index_manager = index_manager
        self.llm = llm
        self.service_context = ServiceContext.from_defaults(llm=llm)
        self.response_mode = response_mode

        # Set up custom prompts
        self._setup_prompts()

    def _setup_prompts(self):
        """Set up custom prompts for CSV querying."""
        self.csv_query_prompt = PromptTemplate(
            """You are an AI assistant specialized in analyzing CSV data.
Answer the following query using the provided CSV information.
If calculations are needed, explain your process.

CSV Context: {context_str}
Query: {query_str}

Answer:"""
        )

    def query(self, query_text: str) -> Dict[str, Any]:
        """Process a natural language query across CSV files."""
        # Find relevant CSV files
        relevant_csvs = self.index_manager.find_relevant_csvs(query_text)

        if not relevant_csvs:
            return {
                "answer": "No relevant CSV files found for your query.",
                "sources": []
            }

        # Prepare response
        responses = []
        sources = []

        # Query each relevant CSV
        for csv_id in relevant_csvs:
            index_info = self.index_manager.indexes.get(csv_id)
            if not index_info:
                continue

            index = index_info["index"]
            metadata = index_info["metadata"]

            # Create retriever for this index
            retriever = VectorIndexRetriever(
                index=index,
                similarity_top_k=5
            )

            # Create query engine
            query_engine = RetrieverQueryEngine.from_args(
                retriever=retriever,
                service_context=self.service_context,
                text_qa_template=self.csv_query_prompt,
                response_mode=self.response_mode
            )

            # Execute query
            response = query_engine.query(query_text)

            responses.append({
                "csv_id": csv_id,
                "filename": metadata["filename"],
                "response": response
            })

            # Collect source information
            if hasattr(response, "source_nodes"):
                for node in response.source_nodes:
                    sources.append({
                        "csv": metadata["filename"],
                        "content": node.node.get_content()[:100] + "..."
                    })

        # Combine responses if multiple CSVs were queried
        if len(responses) > 1:
            combined_response = self._combine_responses(query_text, responses)
            return {
                "answer": combined_response,
                "sources": sources
            }
        elif len(responses) == 1:
            return {
                "answer": responses[0]["response"],
                "sources": sources
            }
        else:
            return {
                "answer": "Failed to process query with the available CSV data.",
                "sources": []
            }

    def _combine_responses(self, query_text: str, responses: List[Dict]) -> str:
        """Combine responses from multiple CSV files."""
        # Create a prompt for combining multiple CSV responses
        combine_prompt = f"""
I need to answer this question: {query_text}

I've analyzed multiple CSV files and found these results:

{chr(10).join([f"From {r['filename']}: {str(r['response'])}" for r in responses])}

Please provide a unified answer that combines these insights.
"""

        # Use the LLM to generate a combined response
        combined_response = self.llm.complete(combine_prompt)

        return combined_response.text
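
Usage note (reviewer sketch, not part of this commit): the end-to-end path through the pieces above, assuming a directory of CSV files at the hypothetical path uploads/ and enough memory to load the default model from models/llm_setup.py.

from models.llm_setup import setup_llm
from indexes.index_manager import CSVIndexManager
from indexes.query_engine import CSVQueryEngine

llm = setup_llm()
manager = CSVIndexManager()
manager.index_directory("uploads")    # index every CSV in the directory

engine = CSVQueryEngine(manager, llm)
result = engine.query("What is the average revenue per region?")
print(result["answer"])
for source in result["sources"]:
    print(source["csv"], "->", source["content"])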
models/llm_setup.py
ADDED
@@ -0,0 +1,64 @@
from typing import Optional
from llama_index.llms import HuggingFaceLLM
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

def setup_llm(model_name: str = "microsoft/phi-3-mini-4k-instruct",
              device: str = None,
              context_window: int = 4096,
              max_new_tokens: int = 512) -> HuggingFaceLLM:
    """
    Set up the language model for the CSV chatbot.

    Args:
        model_name: Name of the Hugging Face model to use
        device: Device to run the model on ('cuda', 'cpu', etc.)
        context_window: Maximum context window size
        max_new_tokens: Maximum number of new tokens to generate

    Returns:
        Configured LLM instance
    """
    # Determine device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Configure quantization for memory efficiency
    if device == "cuda":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )
    else:
        quantization_config = None

    # Configure tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    # Configure model with appropriate parameters for HF Spaces
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.float16,
    }

    if quantization_config:
        model_kwargs["quantization_config"] = quantization_config

    # Initialize LLM
    llm = HuggingFaceLLM(
        model_name=model_name,
        tokenizer_name=model_name,
        context_window=context_window,
        max_new_tokens=max_new_tokens,
        generate_kwargs={"temperature": 0.7, "top_p": 0.95},
        device_map=device,
        tokenizer_kwargs={"trust_remote_code": True},
        model_kwargs=model_kwargs,
        # Cache the model to avoid reloading
        cache_folder="./model_cache"
    )

    return llm
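
Usage note (reviewer sketch, not part of this commit): calling setup_llm directly; on a CPU-only machine the 4-bit quantization branch above is skipped and the model loads in float16. The prompt below is only illustrative, and .complete()/.text follows the same pattern already used in indexes/query_engine.py.

from models.llm_setup import setup_llm

llm = setup_llm(max_new_tokens=256)
response = llm.complete("List two things a CSV file can store.")
print(response.text)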
requirements.txt
ADDED
@@ -0,0 +1,14 @@
llama-index
transformers
gradio
pandas
numpy
matplotlib
plotly
sentence-transformers
chromadb
torch
pillow
chardet
bitsandbytes
accelerate
tools/data_tools.py
ADDED
@@ -0,0 +1,158 @@
from typing import Dict, List, Any, Optional, Callable
import pandas as pd
import numpy as np
from llama_index.tools import FunctionTool
from pathlib import Path

class PandasDataTools:
    """Tools for data analysis operations on CSV files."""

    def __init__(self, csv_directory: str):
        """Initialize with directory containing CSV files."""
        self.csv_directory = csv_directory
        self.dataframes = {}
        self.tools = self._create_tools()

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as DataFrame, with caching."""
        if filename not in self.dataframes:
            file_path = Path(self.csv_directory) / filename
            if not file_path.exists() and not filename.endswith('.csv'):
                file_path = Path(self.csv_directory) / f"{filename}.csv"

            if file_path.exists():
                self.dataframes[filename] = pd.read_csv(file_path)
            else:
                raise ValueError(f"CSV file not found: {filename}")

        return self.dataframes[filename]

    def _create_tools(self) -> List[FunctionTool]:
        """Create LlamaIndex function tools for data operations."""
        tools = [
            FunctionTool.from_defaults(
                name="describe_csv",
                description="Get statistical description of a CSV file",
                fn=self.describe_csv
            ),
            FunctionTool.from_defaults(
                name="filter_data",
                description="Filter CSV data based on conditions",
                fn=self.filter_data
            ),
            FunctionTool.from_defaults(
                name="group_and_aggregate",
                description="Group data and calculate aggregate statistics",
                fn=self.group_and_aggregate
            ),
            FunctionTool.from_defaults(
                name="sort_data",
                description="Sort data by specified columns",
                fn=self.sort_data
            ),
            FunctionTool.from_defaults(
                name="calculate_correlation",
                description="Calculate correlation between columns",
                fn=self.calculate_correlation
            )
        ]
        return tools

    def get_tools(self) -> List[FunctionTool]:
        """Get all available data tools."""
        return self.tools

    # Tool implementations
    def describe_csv(self, filename: str) -> Dict[str, Any]:
        """Get statistical description of CSV data."""
        df = self._load_dataframe(filename)
        description = df.describe().to_dict()

        # Add additional info
        result = {
            "statistics": description,
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}
        }

        return result

    def filter_data(self, filename: str, column: str, condition: str, value: Any) -> Dict[str, Any]:
        """Filter data based on condition (==, >, <, >=, <=, !=, contains)."""
        df = self._load_dataframe(filename)

        if condition == "==":
            filtered = df[df[column] == value]
        elif condition == ">":
            filtered = df[df[column] > float(value)]
        elif condition == "<":
            filtered = df[df[column] < float(value)]
        elif condition == ">=":
            filtered = df[df[column] >= float(value)]
        elif condition == "<=":
            filtered = df[df[column] <= float(value)]
        elif condition == "!=":
            filtered = df[df[column] != value]
        elif condition.lower() == "contains":
            filtered = df[df[column].astype(str).str.contains(str(value))]
        else:
            return {"error": f"Unsupported condition: {condition}"}

        return {
            "result_count": len(filtered),
            "results": filtered.head(10).to_dict(orient="records"),
            "total_count": len(df)
        }

    def group_and_aggregate(self, filename: str, group_by: str, agg_column: str,
                            agg_function: str = "mean") -> Dict[str, Any]:
        """Group by column and calculate aggregate statistic."""
        df = self._load_dataframe(filename)

        agg_functions = {
            "mean": np.mean,
            "sum": np.sum,
            "min": np.min,
            "max": np.max,
            "count": len,
            "median": np.median
        }

        if agg_function not in agg_functions:
            return {"error": f"Unsupported aggregation function: {agg_function}"}

        grouped = df.groupby(group_by)[agg_column].agg(agg_functions[agg_function])

        return {
            "group_by": group_by,
            "aggregated_column": agg_column,
            "aggregation": agg_function,
            "results": grouped.to_dict()
        }

    def sort_data(self, filename: str, sort_by: str, ascending: bool = True) -> Dict[str, Any]:
        """Sort data by column."""
        df = self._load_dataframe(filename)

        sorted_df = df.sort_values(by=sort_by, ascending=ascending)

        return {
            "sorted_by": sort_by,
            "ascending": ascending,
            "results": sorted_df.head(10).to_dict(orient="records")
        }

    def calculate_correlation(self, filename: str, column1: str, column2: str) -> Dict[str, Any]:
        """Calculate correlation between two columns."""
        df = self._load_dataframe(filename)

        try:
            correlation = df[column1].corr(df[column2])
            return {
                "correlation": correlation,
                "column1": column1,
                "column2": column2
            }
        except Exception as e:
            return {"error": f"Could not calculate correlation: {str(e)}"}
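
Usage note (reviewer sketch, not part of this commit): the tool methods can also be called directly, outside the LLM tool loop. The directory uploads/ and the columns region and revenue are hypothetical.

from tools.data_tools import PandasDataTools

data_tools = PandasDataTools("uploads")

# Inspect, filter, and aggregate a CSV without going through the agent.
print(data_tools.describe_csv("sales.csv")["shape"])
print(data_tools.filter_data("sales.csv", column="revenue", condition=">", value=1000)["result_count"])
print(data_tools.group_and_aggregate("sales.csv", group_by="region",
                                     agg_column="revenue", agg_function="sum")["results"])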
tools/export.py
ADDED
@@ -0,0 +1,161 @@
from typing import Dict, List, Any, Optional, Union
import pandas as pd
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.image import MIMEImage
import base64
import io
from pathlib import Path
import json
import datetime
from llama_index.tools import FunctionTool

class ExportTools:
    """Tools for exporting data, generating reports, and sending emails."""

    def __init__(self, output_directory: str = "./exports"):
        """Initialize with directory for saved exports."""
        self.output_directory = Path(output_directory)
        self.output_directory.mkdir(exist_ok=True, parents=True)
        self.tools = self._create_tools()

    def _create_tools(self) -> List[FunctionTool]:
        """Create LlamaIndex function tools for export operations."""
        tools = [
            FunctionTool.from_defaults(
                name="generate_report",
                description="Generate a report from conversation and results",
                fn=self.generate_report
            ),
            FunctionTool.from_defaults(
                name="save_results_to_csv",
                description="Save query results to a CSV file",
                fn=self.save_results_to_csv
            ),
            FunctionTool.from_defaults(
                name="send_email",
                description="Send results via email",
                fn=self.send_email
            )
        ]
        return tools

    def get_tools(self) -> List[FunctionTool]:
        """Get all available export tools."""
        return self.tools

    def generate_report(self, title: str, content: str,
                        images: List[str] = None) -> Dict[str, Any]:
        """Generate HTML report from content and images."""
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{title.replace(' ', '_')}_{timestamp}.html"
        file_path = self.output_directory / filename

        # Convert newlines outside the f-string (backslashes inside f-string
        # expressions are a syntax error before Python 3.12)
        content_html = content.replace('\n', '<br>')

        # Basic HTML template
        html = f"""
<!DOCTYPE html>
<html>
<head>
    <title>{title}</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 40px; line-height: 1.6; }}
        h1 {{ color: #333366; }}
        .report-container {{ max-width: 800px; margin: 0 auto; }}
        .timestamp {{ color: #666; font-size: 0.8em; }}
        img {{ max-width: 100%; height: auto; margin: 20px 0; }}
    </style>
</head>
<body>
    <div class="report-container">
        <h1>{title}</h1>
        <div class="timestamp">Generated on: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</div>
        <div class="content">
            {content_html}
        </div>
"""

        # Add images if provided
        if images and len(images) > 0:
            html += "<div class='images'>"
            for i, img_base64 in enumerate(images):
                html += f"<img src='data:image/png;base64,{img_base64}' alt='Figure {i+1}'>"
            html += "</div>"

        html += """
    </div>
</body>
</html>
"""

        # Write to file
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(html)

        return {
            "success": True,
            "report_path": str(file_path),
            "title": title,
            "timestamp": timestamp
        }

    def save_results_to_csv(self, data: List[Dict[str, Any]],
                            filename: str = None) -> Dict[str, Any]:
        """Save query results to a CSV file."""
        if not data or len(data) == 0:
            return {"success": False, "error": "No data provided"}

        # Create DataFrame from data
        df = pd.DataFrame(data)

        # Generate filename if not provided
        if not filename:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"query_results_{timestamp}.csv"

        # Ensure filename has .csv extension
        if not filename.lower().endswith('.csv'):
            filename += '.csv'

        file_path = self.output_directory / filename

        # Save to CSV
        df.to_csv(file_path, index=False)

        return {
            "success": True,
            "file_path": str(file_path),
            "row_count": len(df),
            "column_count": len(df.columns)
        }

    def send_email(self, to_email: str, subject: str, body: str,
                   from_email: str = None, smtp_server: str = None,
                   smtp_port: int = 587, username: str = None,
                   password: str = None, images: List[str] = None) -> Dict[str, Any]:
        """
        Send email with results.
        Note: In production, credentials should be securely managed.
        For demo purposes, this will log the email content instead.
        """
        # For safety in a demo app, don't actually send emails.
        # Just log what would be sent and return success.

        email_content = {
            "to": to_email,
            "subject": subject,
            "body": body[:100] + "..." if len(body) > 100 else body,
            "images": f"{len(images) if images else 0} images would be attached",
            "note": "Email sending is simulated for demo purposes"
        }

        # Log the email content
        print(f"SIMULATED EMAIL: {json.dumps(email_content, indent=2)}")

        return {
            "success": True,
            "to": to_email,
            "subject": subject,
            "simulated": True,
            "timestamp": datetime.datetime.now().isoformat()
        }
tools/visualization.py
ADDED
@@ -0,0 +1,246 @@
from typing import Dict, List, Any, Optional, Tuple, Union
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import io
import base64
import numpy as np
from llama_index.tools import FunctionTool
from pathlib import Path

# Configure matplotlib for non-interactive environments
matplotlib.use('Agg')

class VisualizationTools:
    """Tools for creating visualizations from CSV data."""

    def __init__(self, csv_directory: str):
        """Initialize with directory containing CSV files."""
        self.csv_directory = csv_directory
        self.dataframes = {}
        self.tools = self._create_tools()
        self.figure_size = (10, 6)
        self.dpi = 100

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as DataFrame, with caching."""
        if filename not in self.dataframes:
            file_path = Path(self.csv_directory) / filename
            if not file_path.exists() and not filename.endswith('.csv'):
                file_path = Path(self.csv_directory) / f"{filename}.csv"

            if file_path.exists():
                self.dataframes[filename] = pd.read_csv(file_path)
            else:
                raise ValueError(f"CSV file not found: {filename}")

        return self.dataframes[filename]

    def _create_tools(self) -> List[FunctionTool]:
        """Create LlamaIndex function tools for visualizations."""
        tools = [
            FunctionTool.from_defaults(
                name="create_line_chart",
                description="Create a line chart from CSV data",
                fn=self.create_line_chart
            ),
            FunctionTool.from_defaults(
                name="create_bar_chart",
                description="Create a bar chart from CSV data",
                fn=self.create_bar_chart
            ),
            FunctionTool.from_defaults(
                name="create_scatter_plot",
                description="Create a scatter plot from CSV data",
                fn=self.create_scatter_plot
            ),
            FunctionTool.from_defaults(
                name="create_histogram",
                description="Create a histogram from CSV data",
                fn=self.create_histogram
            ),
            FunctionTool.from_defaults(
                name="create_pie_chart",
                description="Create a pie chart from CSV data",
                fn=self.create_pie_chart
            )
        ]
        return tools

    def get_tools(self) -> List[FunctionTool]:
        """Get all available visualization tools."""
        return self.tools

    def _figure_to_base64(self, fig) -> str:
        """Convert matplotlib figure to base64 encoded string."""
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=self.dpi)
        buf.seek(0)
        img_str = base64.b64encode(buf.read()).decode('utf-8')
        plt.close(fig)
        return img_str

    # Visualization tool implementations
    def create_line_chart(self, filename: str, x_column: str, y_column: str,
                          title: str = None, limit: int = 50) -> Dict[str, Any]:
        """Create a line chart visualization."""
        df = self._load_dataframe(filename)

        # Limit data points if needed
        if len(df) > limit:
            df = df.head(limit)

        fig, ax = plt.subplots(figsize=self.figure_size)

        # Create line chart
        ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')

        # Set labels and title
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(title or f"{y_column} vs {x_column}")
        ax.grid(True)

        # Convert to base64
        img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_bar_chart(self, filename: str, x_column: str, y_column: str,
                         title: str = None, limit: int = 20) -> Dict[str, Any]:
        """Create a bar chart visualization."""
        df = self._load_dataframe(filename)

        # Limit categories if needed
        if len(df) > limit:
            df = df.head(limit)

        fig, ax = plt.subplots(figsize=self.figure_size)

        # Create bar chart
        ax.bar(df[x_column], df[y_column])

        # Set labels and title
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(title or f"{y_column} by {x_column}")

        # Rotate x labels if there are many categories
        if len(df) > 5:
            plt.xticks(rotation=45, ha='right')

        plt.tight_layout()

        # Convert to base64
        img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str
        }

    def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
                            color_column: str = None, title: str = None) -> Dict[str, Any]:
        """Create a scatter plot visualization."""
        df = self._load_dataframe(filename)

        fig, ax = plt.subplots(figsize=self.figure_size)

        # Create scatter plot
        if color_column and color_column in df.columns:
            scatter = ax.scatter(df[x_column], df[y_column], c=df[color_column], cmap='viridis', alpha=0.7)
            plt.colorbar(scatter, ax=ax, label=color_column)
        else:
            ax.scatter(df[x_column], df[y_column], alpha=0.7)

        # Set labels and title
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(title or f"{y_column} vs {x_column}")
        ax.grid(True, linestyle='--', alpha=0.7)

        # Convert to base64
        img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_histogram(self, filename: str, column: str, bins: int = 10,
                         title: str = None) -> Dict[str, Any]:
        """Create a histogram visualization."""
        df = self._load_dataframe(filename)

        fig, ax = plt.subplots(figsize=self.figure_size)

        # Create histogram
        ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')

        # Set labels and title
        ax.set_xlabel(column)
        ax.set_ylabel('Frequency')
        ax.set_title(title or f"Distribution of {column}")
        ax.grid(True, linestyle='--', alpha=0.7)

        # Convert to base64
        img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str
        }

    def create_pie_chart(self, filename: str, label_column: str, value_column: str = None,
                         title: str = None, limit: int = 10) -> Dict[str, Any]:
        """Create a pie chart visualization."""
        df = self._load_dataframe(filename)

        # If value column not provided, count occurrences of each label
        if value_column is None:
            data = df[label_column].value_counts().head(limit)
            labels = data.index.tolist()
            values = data.values.tolist()
        else:
            # Group by label and sum values
            grouped = df.groupby(label_column)[value_column].sum().reset_index()
            # Limit to top categories
            grouped = grouped.nlargest(limit, value_column)
            labels = grouped[label_column].tolist()
            values = grouped[value_column].tolist()

        fig, ax = plt.subplots(figsize=self.figure_size)

        # Create pie chart
        ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
        ax.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle

        # Set title
        ax.set_title(title or f"Distribution of {label_column}")

        # Convert to base64
        img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str
        }
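
Usage note (reviewer sketch, not part of this commit): the visualization tools return base64-encoded PNG data; a quick way to inspect a figure locally, assuming a hypothetical uploads/sales.csv with region and revenue columns.

import base64
from tools.visualization import VisualizationTools

viz = VisualizationTools("uploads")
chart = viz.create_bar_chart("sales.csv", x_column="region", y_column="revenue",
                             title="Revenue by region")

# Decode the returned base64 string and write the PNG to disk for inspection.
with open("revenue_by_region.png", "wb") as f:
    f.write(base64.b64decode(chart["image"]))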
utils/csv_helper.py
ADDED
@@ -0,0 +1,189 @@
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
import numpy as np
from pathlib import Path
import os
import chardet
import csv

class CSVHelpers:
    """Helper utilities for CSV preprocessing and analysis."""

    @staticmethod
    def detect_encoding(file_path: str, sample_size: int = 10000) -> str:
        """Detect the encoding of a CSV file."""
        with open(file_path, 'rb') as f:
            raw_data = f.read(sample_size)
            result = chardet.detect(raw_data)
            return result['encoding']

    @staticmethod
    def detect_delimiter(file_path: str, encoding: str = 'utf-8') -> str:
        """Detect the delimiter used in a CSV file."""
        with open(file_path, 'r', encoding=encoding) as csvfile:
            sample = csvfile.read(4096)

        # Check common delimiters
        for delimiter in [',', ';', '\t', '|']:
            sniffer = csv.Sniffer()
            try:
                if delimiter in sample:
                    dialect = sniffer.sniff(sample, delimiters=delimiter)
                    return dialect.delimiter
            except Exception:
                continue

        # Default to comma if detection fails
        return ','

    @staticmethod
    def preprocess_csv(file_path: str) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """
        Preprocess a CSV file with automatic encoding and delimiter detection.
        Returns the DataFrame and metadata about the preprocessing.
        """
        # Detect encoding
        try:
            encoding = CSVHelpers.detect_encoding(file_path)
        except Exception:
            encoding = 'utf-8'  # Default to UTF-8 if detection fails

        # Detect delimiter
        try:
            delimiter = CSVHelpers.detect_delimiter(file_path, encoding)
        except Exception:
            delimiter = ','  # Default to comma if detection fails

        # Read the CSV with detected parameters
        df = pd.read_csv(file_path, encoding=encoding, delimiter=delimiter, low_memory=False)

        # Basic preprocessing
        metadata = {
            "original_shape": df.shape,
            "encoding": encoding,
            "delimiter": delimiter,
            "columns": list(df.columns),
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}
        }

        # Handle missing values
        missing_counts = df.isna().sum()
        metadata["missing_values"] = {col: int(count) for col, count in missing_counts.items() if count > 0}

        # Handle duplicate rows
        duplicates = df.duplicated().sum()
        metadata["duplicate_rows"] = int(duplicates)

        return df, metadata

    @staticmethod
    def infer_column_types(df: pd.DataFrame) -> Dict[str, str]:
        """
        Infer semantic types of columns (beyond pandas dtypes).
        Examples: date, categorical, numeric, text, etc.
        """
        column_types = {}

        for column in df.columns:
            # Skip columns with all missing values
            if df[column].isna().all():
                column_types[column] = "unknown"
                continue

            # Get pandas dtype
            dtype = df[column].dtype

            # Check if datetime
            if pd.api.types.is_datetime64_dtype(df[column]):
                column_types[column] = "datetime"

            # Try to convert to datetime if string
            elif dtype == 'object':
                try:
                    # Sample non-null values
                    sample = df[column].dropna().head(10)
                    pd.to_datetime(sample)
                    column_types[column] = "potential_datetime"
                except Exception:
                    # Check if categorical (few unique values)
                    unique_ratio = df[column].nunique() / len(df)
                    if unique_ratio < 0.1:  # Less than 10% unique values
                        column_types[column] = "categorical"
                    else:
                        column_types[column] = "text"

            # Numeric types
            elif pd.api.types.is_numeric_dtype(dtype):
                # Check if potential ID column
                if df[column].nunique() == len(df) and df[column].min() >= 0:
                    column_types[column] = "id"
                # Check if binary
                elif df[column].nunique() <= 2:
                    column_types[column] = "binary"
                # Check if integer
                elif pd.api.types.is_integer_dtype(dtype):
                    column_types[column] = "integer"
                else:
                    column_types[column] = "float"

            # Boolean type
            elif pd.api.types.is_bool_dtype(dtype):
                column_types[column] = "boolean"

            # Fallback
            else:
                column_types[column] = "unknown"

        return column_types

    @staticmethod
    def suggest_visualizations(df: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Suggest appropriate visualizations based on data types.
        Returns a list of visualization suggestions.
        """
        suggestions = []
        column_types = CSVHelpers.infer_column_types(df)
        numeric_columns = [col for col, col_type in column_types.items()
                           if col_type in ["integer", "float"]]
        categorical_columns = [col for col, col_type in column_types.items()
                               if col_type in ["categorical", "binary"]]
        datetime_columns = [col for col, col_type in column_types.items()
                            if col_type in ["datetime", "potential_datetime"]]

        # Histogram for numeric columns
        for col in numeric_columns[:3]:  # Limit to first 3
            suggestions.append({
                "chart_type": "histogram",
                "column": col,
                "title": f"Distribution of {col}"
            })

        # Bar charts for categorical columns
        for col in categorical_columns[:3]:  # Limit to first 3
            suggestions.append({
                "chart_type": "bar",
                "x_column": col,
                "y_column": "count",
                "title": f"Count by {col}"
            })

        # Time series for datetime + numeric combinations
        if datetime_columns and numeric_columns:
            suggestions.append({
                "chart_type": "line",
                "x_column": datetime_columns[0],
                "y_column": numeric_columns[0],
                "title": f"{numeric_columns[0]} over Time"
            })

        # Scatter plots for numeric pairs
        if len(numeric_columns) >= 2:
            suggestions.append({
                "chart_type": "scatter",
                "x_column": numeric_columns[0],
                "y_column": numeric_columns[1],
                "title": f"{numeric_columns[1]} vs {numeric_columns[0]}"
            })

        return suggestions
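
Usage note (reviewer sketch, not part of this commit): CSVHelpers is not imported anywhere in app.py yet; a sketch of how its static methods could be used directly, assuming a hypothetical uploads/sales.csv.

from utils.csv_helper import CSVHelpers

# Detect encoding/delimiter, load the frame, and report preprocessing metadata.
df, meta = CSVHelpers.preprocess_csv("uploads/sales.csv")
print(meta["encoding"], meta["delimiter"], meta["original_shape"])

# Semantic column types drive the visualization suggestions.
print(CSVHelpers.infer_column_types(df))
for suggestion in CSVHelpers.suggest_visualizations(df):
    print(suggestion["chart_type"], "->", suggestion["title"])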
utils/prompt_template.py
ADDED
File without changes