Chamin09 committed
Commit e13d87a · verified · 1 Parent(s): 18f7294

Upload 12 files
.gitattributes CHANGED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,217 @@
+ import os
+ import gradio as gr
+ import tempfile
+ from pathlib import Path
+ import base64
+ from PIL import Image
+ import io
+ import time
+
+ # Import our components
+ from models.llm_setup import setup_llm
+ from indexes.csv_index_builder import EnhancedCSVReader
+ from indexes.index_manager import CSVIndexManager
+ from indexes.query_engine import CSVQueryEngine
+ from tools.data_tools import PandasDataTools
+ from tools.visualization import VisualizationTools
+ from tools.export import ExportTools
+
+ # Set up temporary directories for uploaded files and exports
+ UPLOAD_DIR = Path(tempfile.mkdtemp())
+ EXPORT_DIR = Path(tempfile.mkdtemp())
+
+ class CSVChatApp:
+     """Main application class for the CSV chatbot."""
+
+     def __init__(self):
+         """Initialize the application components."""
+         # Initialize the language model
+         self.llm = setup_llm()
+
+         # Initialize the index manager
+         self.index_manager = CSVIndexManager()
+
+         # Initialize tools
+         self.data_tools = PandasDataTools(str(UPLOAD_DIR))
+         self.viz_tools = VisualizationTools(str(UPLOAD_DIR))
+         self.export_tools = ExportTools(str(EXPORT_DIR))
+
+         # Initialize the query engine
+         self.query_engine = self._setup_query_engine()
+
+         # Track conversation history
+         self.chat_history = []
+         self.uploaded_files = []
+
+     def _setup_query_engine(self):
+         """Set up the query engine."""
+         # Collect all tools. NOTE: CSVQueryEngine does not consume these
+         # yet; they are assembled here but not wired into the engine.
+         tools = (
+             self.data_tools.get_tools() +
+             self.viz_tools.get_tools() +
+             self.export_tools.get_tools()
+         )
+
+         # Create the query engine
+         query_engine = CSVQueryEngine(self.index_manager, self.llm)
+
+         return query_engine
+
+     def handle_file_upload(self, files):
+         """Process uploaded CSV files."""
+         file_info = []
+
+         for file in files:
+             if file is None:
+                 continue
+
+             # Get the file path
+             file_path = Path(file.name)
+
+             # Only process CSV files
+             if file_path.suffix.lower() != '.csv':
+                 continue
+
+             # Copy to the upload directory
+             dest_path = UPLOAD_DIR / file_path.name
+             with open(dest_path, 'wb') as f:
+                 f.write(file_path.read_bytes())
+
+             # Create an index for this file
+             try:
+                 self.index_manager.create_index(str(dest_path))
+                 file_info.append(f"✅ Indexed: {file_path.name}")
+                 self.uploaded_files.append(str(dest_path))
+             except Exception as e:
+                 file_info.append(f"❌ Failed to index {file_path.name}: {str(e)}")
+
+         # Return information about the processed files
+         if file_info:
+             return "\n".join(file_info)
+         else:
+             return "No CSV files were uploaded."
+
+     def process_query(self, query, history):
+         """Process a user query and return the updated Chatbot history."""
+         history = history or []
+
+         if not self.uploaded_files:
+             history.append((query, "Please upload CSV files before asking questions."))
+             return history
+
+         # Add the user message to the internal transcript
+         self.chat_history.append({"role": "user", "content": query})
+
+         # Process the query
+         try:
+             response = self.query_engine.query(query)
+             answer = response["answer"]
+
+             # Check whether the response contains an image
+             if isinstance(answer, dict) and "image" in answer:
+                 # Decode and save the image to the export directory
+                 img_data = answer["image"]
+                 img = Image.open(io.BytesIO(base64.b64decode(img_data)))
+                 img_path = EXPORT_DIR / f"viz_{int(time.time())}.png"
+                 img.save(img_path)
+
+                 # Pair the text response with the image path
+                 text_response = answer.get("text", "Generated visualization")
+                 answer = (text_response, str(img_path))
+
+             # Add the assistant message to the internal transcript
+             self.chat_history.append({"role": "assistant", "content": answer})
+
+             # gr.Chatbot expects a list of (user, assistant) pairs; an
+             # image is shown as a follow-up (filepath,) message
+             if isinstance(answer, tuple):
+                 text, img_path = answer
+                 history.append((query, text))
+                 history.append((None, (img_path,)))
+             else:
+                 history.append((query, str(answer)))
+
+             return history
+
+         except Exception as e:
+             error_msg = f"Error processing query: {str(e)}"
+             self.chat_history.append({"role": "assistant", "content": error_msg})
+             history.append((query, error_msg))
+             return history
+
+     def export_conversation(self):
+         """Export the conversation as a report."""
+         if not self.chat_history:
+             return "No conversation to export."
+
+         # Extract content for the report
+         title = "CSV Chat Conversation Report"
+         content = ""
+         images = []
+
+         for msg in self.chat_history:
+             role = msg["role"]
+             content_text = msg["content"]
+
+             # Handle content that might contain images
+             if isinstance(content_text, tuple) and len(content_text) == 2:
+                 text, img_path = content_text
+                 content += f"\n\n{'User' if role == 'user' else 'Assistant'}: {text}"
+
+                 # Add the image to the report
+                 try:
+                     with open(img_path, "rb") as img_file:
+                         img_data = base64.b64encode(img_file.read()).decode('utf-8')
+                         images.append(img_data)
+                 except Exception:
+                     pass
+             else:
+                 content += f"\n\n{'User' if role == 'user' else 'Assistant'}: {content_text}"
+
+         # Generate the report
+         result = self.export_tools.generate_report(title, content, images)
+
+         if result["success"]:
+             return f"Report exported to: {result['report_path']}"
+         else:
+             return "Failed to export report."
+
+ # Create the Gradio interface
+ def create_interface():
+     """Create the Gradio web interface."""
+     app = CSVChatApp()
+
+     with gr.Blocks(title="CSV Chat Assistant") as interface:
+         gr.Markdown("# CSV Chat Assistant")
+         gr.Markdown("Upload CSV files and ask questions in natural language.")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 file_upload = gr.File(
+                     label="Upload CSV Files",
+                     file_count="multiple",
+                     type="file"
+                 )
+                 upload_button = gr.Button("Process Files")
+                 file_status = gr.Textbox(label="File Status")
+
+                 export_button = gr.Button("Export Conversation")
+                 export_status = gr.Textbox(label="Export Status")
+
+             with gr.Column(scale=2):
+                 chatbot = gr.Chatbot(label="Conversation")
+                 msg = gr.Textbox(label="Your Question")
+                 submit_button = gr.Button("Submit")
+
+         # Set up event handlers
+         upload_button.click(
+             fn=app.handle_file_upload,
+             inputs=[file_upload],
+             outputs=[file_status]
+         )
+
+         submit_button.click(
+             fn=app.process_query,
+             inputs=[msg, chatbot],
+             outputs=[chatbot]
+         )
+
+         export_button.click(
+             fn=app.export_conversation,
+             inputs=[],
+             outputs=[export_status]
+         )
+
+     return interface
+
+ # Launch the app
+ if __name__ == "__main__":
+     interface = create_interface()
+     interface.launch()
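
A minimal headless smoke test of the class above (not part of the commit; sales.csv is a hypothetical file, and the package layout above must be importable):

from types import SimpleNamespace
from app import CSVChatApp

app = CSVChatApp()
# handle_file_upload expects objects exposing a .name attribute, as
# Gradio's file wrappers do; SimpleNamespace stands in for one here.
print(app.handle_file_upload([SimpleNamespace(name="sales.csv")]))
print(app.process_query("How many rows does sales.csv have?", []))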
indexes/csv_index_builder.py ADDED
@@ -0,0 +1,51 @@
+ from typing import Dict, List, Optional
+ from pathlib import Path
+ import pandas as pd
+ from llama_index.readers.file import CSVReader
+ from llama_index.schema import Document
+
+ class EnhancedCSVReader:
+     """Enhanced CSV reader with metadata extraction capabilities."""
+
+     def __init__(self):
+         self.csv_reader = CSVReader()
+
+     def load_data(self, file_path: str) -> List[Document]:
+         """Load a CSV file and extract documents with metadata."""
+         # Load the CSV file
+         documents = self.csv_reader.load_data(file_path)
+
+         # Extract metadata from the file
+         csv_metadata = self._extract_metadata(file_path)
+
+         # Enhance documents with metadata.
+         # NOTE: some vector stores only accept flat (scalar) metadata
+         # values, so the lists/dicts below may need serializing first.
+         for doc in documents:
+             doc.metadata.update(csv_metadata)
+
+         return documents
+
+     def _extract_metadata(self, file_path: str) -> Dict:
+         """Extract useful metadata from a CSV file."""
+         df = pd.read_csv(file_path)
+         filename = Path(file_path).name
+
+         # Extract column information
+         columns = df.columns.tolist()
+         dtypes = {col: str(df[col].dtype) for col in columns}
+
+         # Extract sample values (first 3 non-null values per column)
+         samples = {}
+         for col in columns:
+             non_null_values = df[col].dropna().head(3).tolist()
+             samples[col] = [str(val) for val in non_null_values]
+
+         # Basic statistics
+         row_count = len(df)
+
+         return {
+             "filename": filename,
+             "columns": columns,
+             "dtypes": dtypes,
+             "samples": samples,
+             "row_count": row_count
+         }
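
A short usage sketch for the reader (assumes a local sales.csv; the column names in the comments are illustrative):

from indexes.csv_index_builder import EnhancedCSVReader

reader = EnhancedCSVReader()
docs = reader.load_data("sales.csv")
meta = docs[0].metadata
print(meta["filename"], meta["row_count"])
print(meta["columns"])   # e.g. ['date', 'region', 'revenue']
print(meta["samples"])   # first three non-null values per column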
indexes/index_manager.py ADDED
@@ -0,0 +1,104 @@
+ from typing import Dict, List, Optional
+ from pathlib import Path
+ import os
+
+ from llama_index import VectorStoreIndex, StorageContext
+ from llama_index.vector_stores import ChromaVectorStore
+ from llama_index.embeddings import HuggingFaceEmbedding
+ import chromadb
+
+ from indexes.csv_index_builder import EnhancedCSVReader
+
+ class CSVIndexManager:
+     """Manages creation and retrieval of indexes for CSV files."""
+
+     def __init__(self, embedding_model_name: str = "all-MiniLM-L6-v2"):
+         self.csv_reader = EnhancedCSVReader()
+         self.embed_model = HuggingFaceEmbedding(model_name=embedding_model_name)
+         self.chroma_client = chromadb.Client()
+         self.indexes = {}
+
+     def create_index(self, file_path: str) -> VectorStoreIndex:
+         """Create a vector index for a CSV file."""
+         # Extract the filename stem as identifier
+         file_id = Path(file_path).stem
+
+         # Load documents with metadata
+         documents = self.csv_reader.load_data(file_path)
+
+         # Create (or reuse) a Chroma collection for this CSV, so
+         # re-uploading the same file does not raise
+         chroma_collection = self.chroma_client.get_or_create_collection(file_id)
+         vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+         storage_context = StorageContext.from_defaults(vector_store=vector_store)
+
+         # Create the vector index with our embedding model
+         index = VectorStoreIndex.from_documents(
+             documents,
+             storage_context=storage_context,
+             embed_model=self.embed_model
+         )
+
+         # Store in our registry
+         self.indexes[file_id] = {
+             "index": index,
+             "metadata": documents[0].metadata if documents else {}
+         }
+
+         return index
+
+     def index_directory(self, directory_path: str) -> Dict[str, VectorStoreIndex]:
+         """Index all CSV files in a directory."""
+         indexed_files = {}
+
+         # Get all CSV files in the directory
+         csv_files = [f for f in os.listdir(directory_path)
+                      if f.lower().endswith('.csv')]
+
+         # Create an index for each CSV file
+         for csv_file in csv_files:
+             file_path = os.path.join(directory_path, csv_file)
+             file_id = Path(file_path).stem
+             index = self.create_index(file_path)
+             indexed_files[file_id] = index
+
+         return indexed_files
+
+     def find_relevant_csvs(self, query: str, top_k: int = 3) -> List[str]:
+         """Find the most relevant CSV files for a given query."""
+         if not self.indexes:
+             return []
+
+         # Embed the query
+         query_embedding = self.embed_model.get_text_embedding(query)
+
+         # Calculate similarity with each CSV's metadata
+         similarities = {}
+         for file_id, index_info in self.indexes.items():
+             # Get the metadata description; skip files with none
+             metadata = index_info["metadata"]
+             if not metadata:
+                 continue
+
+             # Create a rich description of the CSV
+             csv_description = f"CSV file {metadata['filename']} with columns: {', '.join(metadata['columns'])}. "
+             csv_description += f"Contains {metadata['row_count']} rows. "
+             csv_description += "Sample data: "
+             for col, samples in metadata['samples'].items():
+                 if samples:
+                     csv_description += f"{col}: {', '.join(str(s) for s in samples[:2])}; "
+
+             # Get the embedding for this description
+             csv_embedding = self.embed_model.get_text_embedding(csv_description)
+
+             # Calculate cosine similarity
+             similarity = self._cosine_similarity(query_embedding, csv_embedding)
+             similarities[file_id] = similarity
+
+         # Sort by similarity and return the top_k file ids
+         sorted_files = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
+         return [file_id for file_id, _ in sorted_files[:top_k]]
+
+     def _cosine_similarity(self, vec1, vec2):
+         """Calculate cosine similarity between two vectors."""
+         dot_product = sum(a * b for a, b in zip(vec1, vec2))
+         norm_a = sum(a * a for a in vec1) ** 0.5
+         norm_b = sum(b * b for b in vec2) ** 0.5
+         return dot_product / (norm_a * norm_b) if norm_a * norm_b != 0 else 0
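
A sketch of the manager in isolation (assumes a hypothetical ./data directory holding one or more CSV files):

from indexes.index_manager import CSVIndexManager

manager = CSVIndexManager()
manager.index_directory("./data")
# Returns up to three file stems, ranked by cosine similarity of the
# query embedding against each file's metadata description
print(manager.find_relevant_csvs("Which region had the highest revenue?"))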
indexes/query_engine.py ADDED
@@ -0,0 +1,124 @@
+ from typing import Dict, List, Optional, Any
+ from llama_index.query_engine import RetrieverQueryEngine
+ from llama_index.retrievers import VectorIndexRetriever
+ from llama_index.response_synthesizers import ResponseMode
+ from llama_index.llms import HuggingFaceLLM
+ from llama_index import ServiceContext, QueryBundle
+ from llama_index.prompts import PromptTemplate
+
+ class CSVQueryEngine:
+     """Query engine for CSV data with multi-file support."""
+
+     def __init__(self, index_manager, llm, response_mode="compact"):
+         """Initialize with an index manager and language model."""
+         self.index_manager = index_manager
+         self.llm = llm
+         self.service_context = ServiceContext.from_defaults(llm=llm)
+         self.response_mode = response_mode
+
+         # Set up custom prompts
+         self._setup_prompts()
+
+     def _setup_prompts(self):
+         """Set up custom prompts for CSV querying."""
+         self.csv_query_prompt = PromptTemplate(
+             """You are an AI assistant specialized in analyzing CSV data.
+ Answer the following query using the provided CSV information.
+ If calculations are needed, explain your process.
+
+ CSV Context: {context_str}
+ Query: {query_str}
+
+ Answer:"""
+         )
+
+     def query(self, query_text: str) -> Dict[str, Any]:
+         """Process a natural language query across CSV files."""
+         # Find relevant CSV files
+         relevant_csvs = self.index_manager.find_relevant_csvs(query_text)
+
+         if not relevant_csvs:
+             return {
+                 "answer": "No relevant CSV files found for your query.",
+                 "sources": []
+             }
+
+         # Prepare the response
+         responses = []
+         sources = []
+
+         # Query each relevant CSV
+         for csv_id in relevant_csvs:
+             index_info = self.index_manager.indexes.get(csv_id)
+             if not index_info:
+                 continue
+
+             index = index_info["index"]
+             metadata = index_info["metadata"]
+
+             # Create a retriever for this index
+             retriever = VectorIndexRetriever(
+                 index=index,
+                 similarity_top_k=5
+             )
+
+             # Create a query engine
+             query_engine = RetrieverQueryEngine.from_args(
+                 retriever=retriever,
+                 service_context=self.service_context,
+                 text_qa_template=self.csv_query_prompt,
+                 response_mode=self.response_mode
+             )
+
+             # Execute the query
+             response = query_engine.query(query_text)
+
+             responses.append({
+                 "csv_id": csv_id,
+                 "filename": metadata["filename"],
+                 "response": response
+             })
+
+             # Collect source information
+             if hasattr(response, "source_nodes"):
+                 for node in response.source_nodes:
+                     sources.append({
+                         "csv": metadata["filename"],
+                         "content": node.node.get_content()[:100] + "..."
+                     })
+
+         # Combine responses if multiple CSVs were queried
+         if len(responses) > 1:
+             combined_response = self._combine_responses(query_text, responses)
+             return {
+                 "answer": combined_response,
+                 "sources": sources
+             }
+         elif len(responses) == 1:
+             return {
+                 "answer": responses[0]["response"],
+                 "sources": sources
+             }
+         else:
+             return {
+                 "answer": "Failed to process query with the available CSV data.",
+                 "sources": []
+             }
+
+     def _combine_responses(self, query_text: str, responses: List[Dict]) -> str:
+         """Combine responses from multiple CSV files."""
+         # Build a prompt that merges the per-file answers (chr(10) is a
+         # newline; backslashes are not allowed in f-string expressions
+         # before Python 3.12)
+         combine_prompt = f"""
+ I need to answer this question: {query_text}
+
+ I've analyzed multiple CSV files and found these results:
+
+ {chr(10).join([f"From {r['filename']}: {str(r['response'])}" for r in responses])}
+
+ Please provide a unified answer that combines these insights.
+ """
+
+         # Use the LLM to generate a combined response
+         combined_response = self.llm.complete(combine_prompt)
+
+         return combined_response.text
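
Wiring the pieces together outside the Gradio app might look like this (a sketch, assuming indexed CSVs as above and enough memory for the model):

from models.llm_setup import setup_llm
from indexes.index_manager import CSVIndexManager
from indexes.query_engine import CSVQueryEngine

manager = CSVIndexManager()
manager.index_directory("./data")
engine = CSVQueryEngine(manager, setup_llm())
result = engine.query("What is the average revenue per region?")
print(result["answer"])    # Response object or combined string
print(result["sources"])   # filename plus 100-character snippets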
models/llm_setup.py ADDED
@@ -0,0 +1,64 @@
+ from typing import Optional
+ from llama_index.llms import HuggingFaceLLM
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+ def setup_llm(model_name: str = "microsoft/phi-3-mini-4k-instruct",
+               device: str = None,
+               context_window: int = 4096,
+               max_new_tokens: int = 512) -> HuggingFaceLLM:
+     """
+     Set up the language model for the CSV chatbot.
+
+     Args:
+         model_name: Name of the Hugging Face model to use
+         device: Device to run the model on ('cuda', 'cpu', etc.)
+         context_window: Maximum context window size
+         max_new_tokens: Maximum number of new tokens to generate
+
+     Returns:
+         Configured LLM instance
+     """
+     # Determine the device
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Configure 4-bit quantization for memory efficiency (GPU only)
+     if device == "cuda":
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16
+         )
+     else:
+         quantization_config = None
+
+     # Pre-load the tokenizer to warm the local cache; HuggingFaceLLM
+     # loads its own copy from tokenizer_name below
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         trust_remote_code=True
+     )
+
+     # Configure the model with appropriate parameters for HF Spaces
+     model_kwargs = {
+         "trust_remote_code": True,
+         "torch_dtype": torch.float16,
+     }
+
+     if quantization_config:
+         model_kwargs["quantization_config"] = quantization_config
+
+     # Initialize the LLM
+     llm = HuggingFaceLLM(
+         model_name=model_name,
+         tokenizer_name=model_name,
+         context_window=context_window,
+         max_new_tokens=max_new_tokens,
+         generate_kwargs={"temperature": 0.7, "top_p": 0.95},
+         device_map=device,
+         tokenizer_kwargs={"trust_remote_code": True},
+         model_kwargs=model_kwargs,
+         # Cache the model to avoid reloading
+         cache_folder="./model_cache"
+     )
+
+     return llm
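
A stand-alone sanity check (the first call downloads several GB of weights; CPU inference will be slow):

from models.llm_setup import setup_llm

llm = setup_llm(max_new_tokens=64)
print(llm.complete("List two uses of a CSV file:").text)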
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ llama-index
+ transformers
+ gradio
+ pandas
+ numpy
+ matplotlib
+ plotly
+ sentence-transformers
+ chromadb
+ torch
+ pillow
+ chardet
+ bitsandbytes
+ accelerate
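
None of these are pinned, so pip install -r requirements.txt pulls the latest releases; the llama_index and gradio imports used above may require pinning versions from the era of this code.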
tools/data_tools.py ADDED
@@ -0,0 +1,158 @@
+ from typing import Dict, List, Any, Optional, Callable
+ import pandas as pd
+ from llama_index.tools import FunctionTool
+ from pathlib import Path
+
+ class PandasDataTools:
+     """Tools for data analysis operations on CSV files."""
+
+     def __init__(self, csv_directory: str):
+         """Initialize with the directory containing CSV files."""
+         self.csv_directory = csv_directory
+         self.dataframes = {}
+         self.tools = self._create_tools()
+
+     def _load_dataframe(self, filename: str) -> pd.DataFrame:
+         """Load a CSV file as a DataFrame, with caching."""
+         if filename not in self.dataframes:
+             file_path = Path(self.csv_directory) / filename
+             if not file_path.exists() and not filename.endswith('.csv'):
+                 file_path = Path(self.csv_directory) / f"{filename}.csv"
+
+             if file_path.exists():
+                 self.dataframes[filename] = pd.read_csv(file_path)
+             else:
+                 raise ValueError(f"CSV file not found: {filename}")
+
+         return self.dataframes[filename]
+
+     def _create_tools(self) -> List[FunctionTool]:
+         """Create LlamaIndex function tools for data operations."""
+         tools = [
+             FunctionTool.from_defaults(
+                 name="describe_csv",
+                 description="Get statistical description of a CSV file",
+                 fn=self.describe_csv
+             ),
+             FunctionTool.from_defaults(
+                 name="filter_data",
+                 description="Filter CSV data based on conditions",
+                 fn=self.filter_data
+             ),
+             FunctionTool.from_defaults(
+                 name="group_and_aggregate",
+                 description="Group data and calculate aggregate statistics",
+                 fn=self.group_and_aggregate
+             ),
+             FunctionTool.from_defaults(
+                 name="sort_data",
+                 description="Sort data by specified columns",
+                 fn=self.sort_data
+             ),
+             FunctionTool.from_defaults(
+                 name="calculate_correlation",
+                 description="Calculate correlation between columns",
+                 fn=self.calculate_correlation
+             )
+         ]
+         return tools
+
+     def get_tools(self) -> List[FunctionTool]:
+         """Get all available data tools."""
+         return self.tools
+
+     # Tool implementations
+     def describe_csv(self, filename: str) -> Dict[str, Any]:
+         """Get a statistical description of CSV data."""
+         df = self._load_dataframe(filename)
+         description = df.describe().to_dict()
+
+         # Add additional info
+         result = {
+             "statistics": description,
+             "shape": df.shape,
+             "columns": df.columns.tolist(),
+             "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}
+         }
+
+         return result
+
+     def filter_data(self, filename: str, column: str, condition: str, value: Any) -> Dict[str, Any]:
+         """Filter data based on a condition (==, >, <, >=, <=, !=, contains)."""
+         df = self._load_dataframe(filename)
+
+         if condition == "==":
+             filtered = df[df[column] == value]
+         elif condition == ">":
+             filtered = df[df[column] > float(value)]
+         elif condition == "<":
+             filtered = df[df[column] < float(value)]
+         elif condition == ">=":
+             filtered = df[df[column] >= float(value)]
+         elif condition == "<=":
+             filtered = df[df[column] <= float(value)]
+         elif condition == "!=":
+             filtered = df[df[column] != value]
+         elif condition.lower() == "contains":
+             filtered = df[df[column].astype(str).str.contains(str(value))]
+         else:
+             return {"error": f"Unsupported condition: {condition}"}
+
+         return {
+             "result_count": len(filtered),
+             "results": filtered.head(10).to_dict(orient="records"),
+             "total_count": len(df)
+         }
+
+     def group_and_aggregate(self, filename: str, group_by: str, agg_column: str,
+                             agg_function: str = "mean") -> Dict[str, Any]:
+         """Group by a column and calculate an aggregate statistic."""
+         df = self._load_dataframe(filename)
+
+         # pandas accepts these aggregations by name; passing the string
+         # avoids wrapping numpy functions (deprecated in recent pandas)
+         agg_functions = {"mean", "sum", "min", "max", "count", "median"}
+
+         if agg_function not in agg_functions:
+             return {"error": f"Unsupported aggregation function: {agg_function}"}
+
+         grouped = df.groupby(group_by)[agg_column].agg(agg_function)
+
+         return {
+             "group_by": group_by,
+             "aggregated_column": agg_column,
+             "aggregation": agg_function,
+             "results": grouped.to_dict()
+         }
+
+     def sort_data(self, filename: str, sort_by: str, ascending: bool = True) -> Dict[str, Any]:
+         """Sort data by a column."""
+         df = self._load_dataframe(filename)
+
+         sorted_df = df.sort_values(by=sort_by, ascending=ascending)
+
+         return {
+             "sorted_by": sort_by,
+             "ascending": ascending,
+             "results": sorted_df.head(10).to_dict(orient="records")
+         }
+
+     def calculate_correlation(self, filename: str, column1: str, column2: str) -> Dict[str, Any]:
+         """Calculate the correlation between two columns."""
+         df = self._load_dataframe(filename)
+
+         try:
+             correlation = df[column1].corr(df[column2])
+             return {
+                 "correlation": correlation,
+                 "column1": column1,
+                 "column2": column2
+             }
+         except Exception as e:
+             return {"error": f"Could not calculate correlation: {str(e)}"}
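
Direct use of the tools, bypassing the LLM (assumes a hypothetical sales.csv in the upload directory; column names are illustrative):

from tools.data_tools import PandasDataTools

tools = PandasDataTools("./uploads")
print(tools.describe_csv("sales.csv")["shape"])
print(tools.filter_data("sales.csv", "revenue", ">", 1000)["result_count"])
print(tools.group_and_aggregate("sales.csv", "region", "revenue", "sum"))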
tools/export.py ADDED
@@ -0,0 +1,161 @@
+ from typing import Dict, List, Any, Optional, Union
+ import pandas as pd
+ # smtplib / email.mime are kept for a real SMTP implementation;
+ # sending is simulated in send_email below
+ import smtplib
+ from email.mime.multipart import MIMEMultipart
+ from email.mime.text import MIMEText
+ from email.mime.image import MIMEImage
+ import base64
+ import io
+ from pathlib import Path
+ import json
+ import datetime
+ from llama_index.tools import FunctionTool
+
+ class ExportTools:
+     """Tools for exporting data, generating reports, and sending emails."""
+
+     def __init__(self, output_directory: str = "./exports"):
+         """Initialize with the directory for saved exports."""
+         self.output_directory = Path(output_directory)
+         self.output_directory.mkdir(exist_ok=True, parents=True)
+         self.tools = self._create_tools()
+
+     def _create_tools(self) -> List[FunctionTool]:
+         """Create LlamaIndex function tools for export operations."""
+         tools = [
+             FunctionTool.from_defaults(
+                 name="generate_report",
+                 description="Generate a report from conversation and results",
+                 fn=self.generate_report
+             ),
+             FunctionTool.from_defaults(
+                 name="save_results_to_csv",
+                 description="Save query results to a CSV file",
+                 fn=self.save_results_to_csv
+             ),
+             FunctionTool.from_defaults(
+                 name="send_email",
+                 description="Send results via email",
+                 fn=self.send_email
+             )
+         ]
+         return tools
+
+     def get_tools(self) -> List[FunctionTool]:
+         """Get all available export tools."""
+         return self.tools
+
+     def generate_report(self, title: str, content: str,
+                         images: List[str] = None) -> Dict[str, Any]:
+         """Generate an HTML report from content and images."""
+         timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+         filename = f"{title.replace(' ', '_')}_{timestamp}.html"
+         file_path = self.output_directory / filename
+
+         # Replace newlines outside the f-string: backslashes are not
+         # allowed inside f-string expressions before Python 3.12
+         content_html = content.replace("\n", "<br>")
+
+         # Basic HTML template
+         html = f"""
+ <!DOCTYPE html>
+ <html>
+ <head>
+     <title>{title}</title>
+     <style>
+         body {{ font-family: Arial, sans-serif; margin: 40px; line-height: 1.6; }}
+         h1 {{ color: #333366; }}
+         .report-container {{ max-width: 800px; margin: 0 auto; }}
+         .timestamp {{ color: #666; font-size: 0.8em; }}
+         img {{ max-width: 100%; height: auto; margin: 20px 0; }}
+     </style>
+ </head>
+ <body>
+     <div class="report-container">
+         <h1>{title}</h1>
+         <div class="timestamp">Generated on: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</div>
+         <div class="content">
+             {content_html}
+         </div>
+ """
+
+         # Add images if provided
+         if images and len(images) > 0:
+             html += "<div class='images'>"
+             for i, img_base64 in enumerate(images):
+                 html += f"<img src='data:image/png;base64,{img_base64}' alt='Figure {i+1}'>"
+             html += "</div>"
+
+         html += """
+     </div>
+ </body>
+ </html>
+ """
+
+         # Write to file
+         with open(file_path, "w", encoding="utf-8") as f:
+             f.write(html)
+
+         return {
+             "success": True,
+             "report_path": str(file_path),
+             "title": title,
+             "timestamp": timestamp
+         }
+
+     def save_results_to_csv(self, data: List[Dict[str, Any]],
+                             filename: str = None) -> Dict[str, Any]:
+         """Save query results to a CSV file."""
+         if not data or len(data) == 0:
+             return {"success": False, "error": "No data provided"}
+
+         # Create a DataFrame from the data
+         df = pd.DataFrame(data)
+
+         # Generate a filename if not provided
+         if not filename:
+             timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+             filename = f"query_results_{timestamp}.csv"
+
+         # Ensure the filename has a .csv extension
+         if not filename.lower().endswith('.csv'):
+             filename += '.csv'
+
+         file_path = self.output_directory / filename
+
+         # Save to CSV
+         df.to_csv(file_path, index=False)
+
+         return {
+             "success": True,
+             "file_path": str(file_path),
+             "row_count": len(df),
+             "column_count": len(df.columns)
+         }
+
+     def send_email(self, to_email: str, subject: str, body: str,
+                    from_email: str = None, smtp_server: str = None,
+                    smtp_port: int = 587, username: str = None,
+                    password: str = None, images: List[str] = None) -> Dict[str, Any]:
+         """
+         Send email with results.
+         Note: in production, credentials should be securely managed.
+         For demo purposes, this logs the email content instead.
+         """
+         # For safety in a demo app, don't actually send emails;
+         # just log what would be sent and return success
+
+         email_content = {
+             "to": to_email,
+             "subject": subject,
+             "body": body[:100] + "..." if len(body) > 100 else body,
+             "images": f"{len(images) if images else 0} images would be attached",
+             "note": "Email sending is simulated for demo purposes"
+         }
+
+         # Log the email content
+         print(f"SIMULATED EMAIL: {json.dumps(email_content, indent=2)}")
+
+         return {
+             "success": True,
+             "to": to_email,
+             "subject": subject,
+             "simulated": True,
+             "timestamp": datetime.datetime.now().isoformat()
+         }
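
A quick sketch of the export path (files land under ./exports by default; the rows below are illustrative data):

from tools.export import ExportTools

exporter = ExportTools()
report = exporter.generate_report("Demo Report", "Line one\nLine two")
print(report["report_path"])
rows = [{"region": "EU", "revenue": 1200}, {"region": "US", "revenue": 3400}]
print(exporter.save_results_to_csv(rows)["file_path"])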
tools/visualization.py ADDED
@@ -0,0 +1,246 @@
+ from typing import Dict, List, Any, Optional, Tuple, Union
+ import pandas as pd
+ import matplotlib
+ # Configure matplotlib for non-interactive environments; the backend
+ # must be selected before pyplot is imported
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+ import io
+ import base64
+ import numpy as np
+ from llama_index.tools import FunctionTool
+ from pathlib import Path
+
+ class VisualizationTools:
+     """Tools for creating visualizations from CSV data."""
+
+     def __init__(self, csv_directory: str):
+         """Initialize with the directory containing CSV files."""
+         self.csv_directory = csv_directory
+         self.dataframes = {}
+         self.figure_size = (10, 6)
+         self.dpi = 100
+         self.tools = self._create_tools()
+
+     def _load_dataframe(self, filename: str) -> pd.DataFrame:
+         """Load a CSV file as a DataFrame, with caching."""
+         if filename not in self.dataframes:
+             file_path = Path(self.csv_directory) / filename
+             if not file_path.exists() and not filename.endswith('.csv'):
+                 file_path = Path(self.csv_directory) / f"{filename}.csv"
+
+             if file_path.exists():
+                 self.dataframes[filename] = pd.read_csv(file_path)
+             else:
+                 raise ValueError(f"CSV file not found: {filename}")
+
+         return self.dataframes[filename]
+
+     def _create_tools(self) -> List[FunctionTool]:
+         """Create LlamaIndex function tools for visualizations."""
+         tools = [
+             FunctionTool.from_defaults(
+                 name="create_line_chart",
+                 description="Create a line chart from CSV data",
+                 fn=self.create_line_chart
+             ),
+             FunctionTool.from_defaults(
+                 name="create_bar_chart",
+                 description="Create a bar chart from CSV data",
+                 fn=self.create_bar_chart
+             ),
+             FunctionTool.from_defaults(
+                 name="create_scatter_plot",
+                 description="Create a scatter plot from CSV data",
+                 fn=self.create_scatter_plot
+             ),
+             FunctionTool.from_defaults(
+                 name="create_histogram",
+                 description="Create a histogram from CSV data",
+                 fn=self.create_histogram
+             ),
+             FunctionTool.from_defaults(
+                 name="create_pie_chart",
+                 description="Create a pie chart from CSV data",
+                 fn=self.create_pie_chart
+             )
+         ]
+         return tools
+
+     def get_tools(self) -> List[FunctionTool]:
+         """Get all available visualization tools."""
+         return self.tools
+
+     def _figure_to_base64(self, fig) -> str:
+         """Convert a matplotlib figure to a base64-encoded string."""
+         buf = io.BytesIO()
+         fig.savefig(buf, format='png', dpi=self.dpi)
+         buf.seek(0)
+         img_str = base64.b64encode(buf.read()).decode('utf-8')
+         plt.close(fig)
+         return img_str
+
+     # Visualization tool implementations
+     def create_line_chart(self, filename: str, x_column: str, y_column: str,
+                           title: str = None, limit: int = 50) -> Dict[str, Any]:
+         """Create a line chart visualization."""
+         df = self._load_dataframe(filename)
+
+         # Limit data points if needed
+         if len(df) > limit:
+             df = df.head(limit)
+
+         fig, ax = plt.subplots(figsize=self.figure_size)
+
+         # Create the line chart
+         ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')
+
+         # Set labels and title
+         ax.set_xlabel(x_column)
+         ax.set_ylabel(y_column)
+         ax.set_title(title or f"{y_column} vs {x_column}")
+         ax.grid(True)
+
+         # Convert to base64
+         img_str = self._figure_to_base64(fig)
+
+         return {
+             "chart_type": "line",
+             "x_column": x_column,
+             "y_column": y_column,
+             "data_points": len(df),
+             "image": img_str
+         }
+
+     def create_bar_chart(self, filename: str, x_column: str, y_column: str,
+                          title: str = None, limit: int = 20) -> Dict[str, Any]:
+         """Create a bar chart visualization."""
+         df = self._load_dataframe(filename)
+
+         # Limit categories if needed
+         if len(df) > limit:
+             df = df.head(limit)
+
+         fig, ax = plt.subplots(figsize=self.figure_size)
+
+         # Create the bar chart
+         ax.bar(df[x_column], df[y_column])
+
+         # Set labels and title
+         ax.set_xlabel(x_column)
+         ax.set_ylabel(y_column)
+         ax.set_title(title or f"{y_column} by {x_column}")
+
+         # Rotate x labels if there are many categories
+         if len(df) > 5:
+             plt.xticks(rotation=45, ha='right')
+
+         plt.tight_layout()
+
+         # Convert to base64
+         img_str = self._figure_to_base64(fig)
+
+         return {
+             "chart_type": "bar",
+             "x_column": x_column,
+             "y_column": y_column,
+             "categories": len(df),
+             "image": img_str
+         }
+
+     def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
+                             color_column: str = None, title: str = None) -> Dict[str, Any]:
+         """Create a scatter plot visualization."""
+         df = self._load_dataframe(filename)
+
+         fig, ax = plt.subplots(figsize=self.figure_size)
+
+         # Create the scatter plot
+         if color_column and color_column in df.columns:
+             scatter = ax.scatter(df[x_column], df[y_column], c=df[color_column], cmap='viridis', alpha=0.7)
+             plt.colorbar(scatter, ax=ax, label=color_column)
+         else:
+             ax.scatter(df[x_column], df[y_column], alpha=0.7)
+
+         # Set labels and title
+         ax.set_xlabel(x_column)
+         ax.set_ylabel(y_column)
+         ax.set_title(title or f"{y_column} vs {x_column}")
+         ax.grid(True, linestyle='--', alpha=0.7)
+
+         # Convert to base64
+         img_str = self._figure_to_base64(fig)
+
+         return {
+             "chart_type": "scatter",
+             "x_column": x_column,
+             "y_column": y_column,
+             "color_column": color_column,
+             "data_points": len(df),
+             "image": img_str
+         }
+
+     def create_histogram(self, filename: str, column: str, bins: int = 10,
+                          title: str = None) -> Dict[str, Any]:
+         """Create a histogram visualization."""
+         df = self._load_dataframe(filename)
+
+         fig, ax = plt.subplots(figsize=self.figure_size)
+
+         # Create the histogram
+         ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')
+
+         # Set labels and title
+         ax.set_xlabel(column)
+         ax.set_ylabel('Frequency')
+         ax.set_title(title or f"Distribution of {column}")
+         ax.grid(True, linestyle='--', alpha=0.7)
+
+         # Convert to base64
+         img_str = self._figure_to_base64(fig)
+
+         return {
+             "chart_type": "histogram",
+             "column": column,
+             "bins": bins,
+             "data_points": len(df),
+             "image": img_str
+         }
+
+     def create_pie_chart(self, filename: str, label_column: str, value_column: str = None,
+                          title: str = None, limit: int = 10) -> Dict[str, Any]:
+         """Create a pie chart visualization."""
+         df = self._load_dataframe(filename)
+
+         # If no value column is provided, count occurrences of each label
+         if value_column is None:
+             data = df[label_column].value_counts().head(limit)
+             labels = data.index.tolist()
+             values = data.values.tolist()
+         else:
+             # Group by label and sum values
+             grouped = df.groupby(label_column)[value_column].sum().reset_index()
+             # Limit to the top categories
+             grouped = grouped.nlargest(limit, value_column)
+             labels = grouped[label_column].tolist()
+             values = grouped[value_column].tolist()
+
+         fig, ax = plt.subplots(figsize=self.figure_size)
+
+         # Create the pie chart
+         ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
+         ax.axis('equal')  # Equal aspect ratio ensures the pie is drawn as a circle
+
+         # Set the title
+         ax.set_title(title or f"Distribution of {label_column}")
+
+         # Convert to base64
+         img_str = self._figure_to_base64(fig)
+
+         return {
+             "chart_type": "pie",
+             "label_column": label_column,
+             "value_column": value_column,
+             "categories": len(labels),
+             "image": img_str
+         }
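
Charts come back as base64-encoded PNGs; a sketch of rendering one to disk (assumes a hypothetical sales.csv with region and revenue columns):

import base64
from tools.visualization import VisualizationTools

viz = VisualizationTools("./uploads")
result = viz.create_bar_chart("sales.csv", "region", "revenue")
with open("chart.png", "wb") as f:
    f.write(base64.b64decode(result["image"]))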
utils/csv_helper.py ADDED
@@ -0,0 +1,189 @@
+ from typing import Dict, List, Any, Optional, Tuple
+ import pandas as pd
+ import numpy as np
+ from pathlib import Path
+ import os
+ import chardet
+ import csv
+
+ class CSVHelpers:
+     """Helper utilities for CSV preprocessing and analysis."""
+
+     @staticmethod
+     def detect_encoding(file_path: str, sample_size: int = 10000) -> str:
+         """Detect the encoding of a CSV file."""
+         with open(file_path, 'rb') as f:
+             raw_data = f.read(sample_size)
+             result = chardet.detect(raw_data)
+             return result['encoding']
+
+     @staticmethod
+     def detect_delimiter(file_path: str, encoding: str = 'utf-8') -> str:
+         """Detect the delimiter used in a CSV file."""
+         with open(file_path, 'r', encoding=encoding) as csvfile:
+             sample = csvfile.read(4096)
+
+         # Check common delimiters
+         for delimiter in [',', ';', '\t', '|']:
+             sniffer = csv.Sniffer()
+             try:
+                 if delimiter in sample:
+                     dialect = sniffer.sniff(sample, delimiters=delimiter)
+                     return dialect.delimiter
+             except csv.Error:
+                 continue
+
+         # Default to comma if detection fails
+         return ','
+
+     @staticmethod
+     def preprocess_csv(file_path: str) -> Tuple[pd.DataFrame, Dict[str, Any]]:
+         """
+         Preprocess a CSV file with automatic encoding and delimiter detection.
+         Returns the DataFrame and metadata about the preprocessing.
+         """
+         # Detect encoding
+         try:
+             encoding = CSVHelpers.detect_encoding(file_path)
+         except Exception:
+             encoding = 'utf-8'  # Default to UTF-8 if detection fails
+
+         # Detect delimiter
+         try:
+             delimiter = CSVHelpers.detect_delimiter(file_path, encoding)
+         except Exception:
+             delimiter = ','  # Default to comma if detection fails
+
+         # Read the CSV with the detected parameters
+         df = pd.read_csv(file_path, encoding=encoding, delimiter=delimiter, low_memory=False)
+
+         # Basic preprocessing
+         metadata = {
+             "original_shape": df.shape,
+             "encoding": encoding,
+             "delimiter": delimiter,
+             "columns": list(df.columns),
+             "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}
+         }
+
+         # Record missing values
+         missing_counts = df.isna().sum()
+         metadata["missing_values"] = {col: int(count) for col, count in missing_counts.items() if count > 0}
+
+         # Record duplicate rows
+         duplicates = df.duplicated().sum()
+         metadata["duplicate_rows"] = int(duplicates)
+
+         return df, metadata
+
+     @staticmethod
+     def infer_column_types(df: pd.DataFrame) -> Dict[str, str]:
+         """
+         Infer semantic types of columns (beyond pandas dtypes).
+         Examples: date, categorical, numeric, text, etc.
+         """
+         column_types = {}
+
+         for column in df.columns:
+             # Skip columns with all missing values
+             if df[column].isna().all():
+                 column_types[column] = "unknown"
+                 continue
+
+             # Get the pandas dtype
+             dtype = df[column].dtype
+
+             # Check if datetime
+             if pd.api.types.is_datetime64_dtype(df[column]):
+                 column_types[column] = "datetime"
+
+             # Try to convert to datetime if string
+             elif dtype == 'object':
+                 try:
+                     # Sample non-null values
+                     sample = df[column].dropna().head(10)
+                     pd.to_datetime(sample)
+                     column_types[column] = "potential_datetime"
+                 except Exception:
+                     # Check if categorical (few unique values)
+                     unique_ratio = df[column].nunique() / len(df)
+                     if unique_ratio < 0.1:  # Less than 10% unique values
+                         column_types[column] = "categorical"
+                     else:
+                         column_types[column] = "text"
+
+             # Numeric types
+             elif pd.api.types.is_numeric_dtype(dtype):
+                 # Check if potential ID column
+                 if df[column].nunique() == len(df) and df[column].min() >= 0:
+                     column_types[column] = "id"
+                 # Check if binary
+                 elif df[column].nunique() <= 2:
+                     column_types[column] = "binary"
+                 # Check if integer
+                 elif pd.api.types.is_integer_dtype(dtype):
+                     column_types[column] = "integer"
+                 else:
+                     column_types[column] = "float"
+
+             # Boolean type
+             elif pd.api.types.is_bool_dtype(dtype):
+                 column_types[column] = "boolean"
+
+             # Fallback
+             else:
+                 column_types[column] = "unknown"
+
+         return column_types
+
+     @staticmethod
+     def suggest_visualizations(df: pd.DataFrame) -> List[Dict[str, Any]]:
+         """
+         Suggest appropriate visualizations based on data types.
+         Returns a list of visualization suggestions.
+         """
+         suggestions = []
+         column_types = CSVHelpers.infer_column_types(df)
+         # ctype avoids shadowing the builtin type in these comprehensions
+         numeric_columns = [col for col, ctype in column_types.items()
+                            if ctype in ["integer", "float"]]
+         categorical_columns = [col for col, ctype in column_types.items()
+                                if ctype in ["categorical", "binary"]]
+         datetime_columns = [col for col, ctype in column_types.items()
+                             if ctype in ["datetime", "potential_datetime"]]
+
+         # Histograms for numeric columns
+         for col in numeric_columns[:3]:  # Limit to the first 3
+             suggestions.append({
+                 "chart_type": "histogram",
+                 "column": col,
+                 "title": f"Distribution of {col}"
+             })
+
+         # Bar charts for categorical columns
+         for col in categorical_columns[:3]:  # Limit to the first 3
+             suggestions.append({
+                 "chart_type": "bar",
+                 "x_column": col,
+                 "y_column": "count",
+                 "title": f"Count by {col}"
+             })
+
+         # Time series for datetime + numeric combinations
+         if datetime_columns and numeric_columns:
+             suggestions.append({
+                 "chart_type": "line",
+                 "x_column": datetime_columns[0],
+                 "y_column": numeric_columns[0],
+                 "title": f"{numeric_columns[0]} over Time"
+             })
+
+         # Scatter plots for numeric pairs
+         if len(numeric_columns) >= 2:
+             suggestions.append({
+                 "chart_type": "scatter",
+                 "x_column": numeric_columns[0],
+                 "y_column": numeric_columns[1],
+                 "title": f"{numeric_columns[1]} vs {numeric_columns[0]}"
+             })
+
+         return suggestions
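
A sketch of the helpers on an arbitrary file (sales.csv is hypothetical):

from utils.csv_helper import CSVHelpers

df, meta = CSVHelpers.preprocess_csv("sales.csv")
print(meta["encoding"], meta["delimiter"], meta["original_shape"])
for s in CSVHelpers.suggest_visualizations(df):
    print(s["chart_type"], s.get("title"))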
utils/prompt_template.py ADDED
File without changes