ChatCSV / app.py
Chamin09's picture
Upload 12 files
e13d87a verified
raw
history blame
7.71 kB
import os
import gradio as gr
import tempfile
from pathlib import Path
import base64
from PIL import Image
import io
import time
# Import our components
from models.llm_setup import setup_llm
from indexes.csv_index_builder import EnhancedCSVReader
from indexes.index_manager import CSVIndexManager
from indexes.query_engine import CSVQueryEngine
from tools.data_tools import PandasDataTools
from tools.visualization import VisualizationTools
from tools.export import ExportTools
# Per-process scratch directories, each created fresh by tempfile.mkdtemp():
#   UPLOAD_DIR - receives copies of the CSV files the user uploads
#   EXPORT_DIR - handed to ExportTools and used for saved visualization PNGs
# This file never deletes either directory.
UPLOAD_DIR, EXPORT_DIR = [Path(tempfile.mkdtemp()) for _ in range(2)]
class CSVChatApp:
    """Main application class for CSV chatbot.

    Owns the LLM, the CSV index manager, the data / visualization / export
    tool sets and the query engine, plus the recorded conversation and the
    list of CSV files that were indexed successfully.

    NOTE(review): ``create_interface`` builds a single instance that every
    browser session shares, so ``chat_history`` and ``uploaded_files`` are
    global to the running app rather than per user.
    """

    def __init__(self):
        """Initialize the application components."""
        # Initialize the language model
        self.llm = setup_llm()
        # Initialize the index manager
        self.index_manager = CSVIndexManager()
        # Initialize tools
        self.data_tools = PandasDataTools(str(UPLOAD_DIR))
        self.viz_tools = VisualizationTools(str(UPLOAD_DIR))
        self.export_tools = ExportTools(str(EXPORT_DIR))
        # Initialize query engine with tools
        self.query_engine = self._setup_query_engine()
        # Conversation as {"role": "user"|"assistant", "content": ...} dicts;
        # content is a str, or a (text, image_path) tuple for visualizations.
        self.chat_history = []
        # Paths (as str) of CSV copies that were indexed successfully.
        self.uploaded_files = []

    def _setup_query_engine(self):
        """Create the query engine backed by the index manager and the LLM.

        Returns:
            CSVQueryEngine: the engine used by ``process_query``.
        """
        # TODO(review): this combined tool list is never handed to
        # CSVQueryEngine, so unless the engine obtains the tools some other
        # way they are unused. Pass it in once the constructor accepts it.
        self.tools = (
            self.data_tools.get_tools()
            + self.viz_tools.get_tools()
            + self.export_tools.get_tools()
        )
        return CSVQueryEngine(self.index_manager, self.llm)

    def handle_file_upload(self, files):
        """Copy uploaded CSV files into UPLOAD_DIR and build an index for each.

        Args:
            files: Iterable of uploads as delivered by ``gr.File``. Each item
                is either an object with a ``.name`` path attribute or a
                plain path string. ``None`` (nothing selected) is accepted.
                Non-CSV files and ``None`` items are skipped.

        Returns:
            str: One status line per CSV file processed, or
            ``"No CSV files were uploaded."`` when none were.
        """
        if not files:
            # gr.File yields None when the button is clicked with no files.
            return "No CSV files were uploaded."

        file_info = []
        for file in files:
            if file is None:
                continue
            # Older Gradio passes tempfile wrappers (path in .name); newer
            # versions pass the path string itself.
            file_path = Path(getattr(file, "name", file))
            # Only process CSV files
            if file_path.suffix.lower() != ".csv":
                continue

            dest_path = UPLOAD_DIR / file_path.name
            try:
                # Copy to the upload directory, then index the copy. A
                # failure is reported for this file only, not the batch.
                dest_path.write_bytes(file_path.read_bytes())
                self.index_manager.create_index(str(dest_path))
            except Exception as e:
                file_info.append(f"❌ Failed to index {file_path.name}: {e}")
                continue

            file_info.append(f"✅ Indexed: {file_path.name}")
            # Re-uploading a file with the same name overwrites the copy;
            # keep a single entry for it.
            if str(dest_path) not in self.uploaded_files:
                self.uploaded_files.append(str(dest_path))

        # Return information about processed files
        if file_info:
            return "\n".join(file_info)
        return "No CSV files were uploaded."

    def process_query(self, query, history):
        """Answer *query* with the query engine and record the exchange.

        Args:
            query: The user's question.
            history: Chat history supplied by the UI. It is unused here
                because the app keeps its own ``self.chat_history``.

        Returns:
            The assistant answer. This is the engine's ``"answer"`` value,
            or a ``(text, image_path)`` tuple when that value is a dict
            carrying a base64 ``"image"``; the image is decoded and saved as
            PNG under EXPORT_DIR. On any engine error it is an
            ``"Error processing query: ..."`` string. If no CSV has been
            indexed yet, a prompt to upload files is returned and nothing
            is recorded.
        """
        if not self.uploaded_files:
            return "Please upload CSV files before asking questions."

        # Add user message to history
        self.chat_history.append({"role": "user", "content": query})
        try:
            response = self.query_engine.query(query)
            answer = response["answer"]
            # A dict answer with an "image" key is a visualization.
            if isinstance(answer, dict) and "image" in answer:
                img = Image.open(io.BytesIO(base64.b64decode(answer["image"])))
                # Nanosecond timestamp so charts created within the same
                # second don't overwrite each other; history keeps each path.
                img_path = EXPORT_DIR / f"viz_{time.time_ns()}.png"
                img.save(img_path)
                text_response = answer.get("text", "Generated visualization")
                answer = (text_response, str(img_path))
        except Exception as e:
            error_msg = f"Error processing query: {e}"
            self.chat_history.append({"role": "assistant", "content": error_msg})
            return error_msg

        # Add assistant message to history
        self.chat_history.append({"role": "assistant", "content": answer})
        return answer

    def export_conversation(self):
        """Export the recorded conversation as a report via ExportTools.

        Each message becomes a ``"\\n\\nUser: ..."`` or
        ``"\\n\\nAssistant: ..."`` segment. Images referenced by
        ``(text, image_path)`` messages are read and passed to
        ``generate_report`` as base64 strings.

        Returns:
            str: ``"Report exported to: <path>"`` on success, otherwise a
            failure or "nothing to export" notice.
        """
        if not self.chat_history:
            return "No conversation to export."

        title = "CSV Chat Conversation Report"
        parts = []
        images = []
        for msg in self.chat_history:
            speaker = "User" if msg["role"] == "user" else "Assistant"
            text = msg["content"]
            # Visualization answers are stored as (text, image_path).
            if isinstance(text, tuple) and len(text) == 2:
                text, img_path = text
                try:
                    with open(img_path, "rb") as img_file:
                        images.append(base64.b64encode(img_file.read()).decode("utf-8"))
                except OSError:
                    # Best effort: a missing or unreadable image is simply
                    # left out of the report.
                    pass
            parts.append(f"\n\n{speaker}: {text}")

        # Generate report
        result = self.export_tools.generate_report(title, "".join(parts), images)
        if result.get("success"):
            return f"Report exported to: {result['report_path']}"
        return "Failed to export report."
# Create the Gradio interface
def create_interface():
    """Create the Gradio web interface.

    Builds a two-column ``gr.Blocks`` layout (upload/export controls on the
    left, chat on the right) wired to one ``CSVChatApp`` instance.

    Returns:
        gr.Blocks: the interface; call ``.launch()`` to serve it.
    """
    app = CSVChatApp()

    def respond(message, history):
        """Answer *message* and return the updated chat plus a cleared textbox.

        ``gr.Chatbot`` renders a list of ``[user, assistant]`` pairs, while
        ``app.process_query`` returns only the answer: a str, or a
        ``(text, image_path)`` tuple for visualizations. This adapter
        converts one into the other.
        """
        history = list(history or [])
        if not message or not message.strip():
            # Nothing to ask; leave the chat and the textbox as they are.
            return history, message

        answer = app.process_query(message, history)
        if isinstance(answer, tuple) and len(answer) == 2:
            text, img_path = answer
            history.append([message, text])
            # Chatbot treats a (filepath,) tuple as a media message. The raw
            # (text, path) tuple would be misread as (filepath, alt_text).
            history.append([None, (img_path,)])
        else:
            history.append([message, answer])
        # The second output clears the question box.
        return history, ""

    with gr.Blocks(title="CSV Chat Assistant") as interface:
        gr.Markdown("# CSV Chat Assistant")
        gr.Markdown("Upload CSV files and ask questions in natural language.")
        with gr.Row():
            with gr.Column(scale=1):
                # NOTE(review): type="file" is the Gradio 3.x API; Gradio 4+
                # replaced it with "filepath" — revisit when upgrading.
                file_upload = gr.File(
                    label="Upload CSV Files",
                    file_count="multiple",
                    type="file",
                )
                upload_button = gr.Button("Process Files")
                file_status = gr.Textbox(label="File Status")
                export_button = gr.Button("Export Conversation")
                export_status = gr.Textbox(label="Export Status")
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Conversation")
                msg = gr.Textbox(label="Your Question")
                submit_button = gr.Button("Submit")

        # Set up event handlers
        upload_button.click(
            fn=app.handle_file_upload,
            inputs=[file_upload],
            outputs=[file_status],
        )
        # Both the Submit button and pressing Enter in the textbox ask.
        submit_button.click(
            fn=respond,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg],
        )
        msg.submit(
            fn=respond,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg],
        )
        export_button.click(
            fn=app.export_conversation,
            inputs=[],
            outputs=[export_status],
        )
    return interface
# Script entry point: build the UI and start the Gradio server. Importing
# this module does not launch anything.
if __name__ == "__main__":
    create_interface().launch()