|
from typing import Dict, List, Optional
|
|
from pathlib import Path
|
|
import os
|
|
|
|
from llama_index import VectorStoreIndex, StorageContext
|
|
from llama_index.vector_stores import ChromaVectorStore
|
|
from llama_index.embeddings import HuggingFaceEmbedding
|
|
import chromadb
|
|
|
|
from indexes.csv_index_builder import EnhancedCSVReader
|
|
|
|
class CSVIndexManager:
    """Manages creation and retrieval of vector indexes for CSV files.

    Keeps an in-memory registry (``self.indexes``) mapping each file's
    stem to its vector index plus the metadata of its first document,
    which is later used for lightweight query-relevance ranking.
    """

    def __init__(self, embedding_model_name: str = "all-MiniLM-L6-v2"):
        """Set up the CSV reader, embedding model, and Chroma client.

        Args:
            embedding_model_name: Name of the HuggingFace sentence-embedding
                model, passed straight to ``HuggingFaceEmbedding``.
        """
        self.csv_reader = EnhancedCSVReader()
        self.embed_model = HuggingFaceEmbedding(model_name=embedding_model_name)
        # In-process ephemeral Chroma client; nothing is persisted to disk.
        self.chroma_client = chromadb.Client()
        # file_id (file stem) -> {"index": VectorStoreIndex, "metadata": dict}
        self.indexes: Dict[str, Dict] = {}

    def create_index(self, file_path: str) -> "VectorStoreIndex":
        """Create (or refresh) a vector index for a single CSV file.

        Args:
            file_path: Path to the CSV file to index.

        Returns:
            The newly built ``VectorStoreIndex``.
        """
        file_id = Path(file_path).stem

        documents = self.csv_reader.load_data(file_path)

        # get_or_create avoids a "collection already exists" error when the
        # same file is indexed more than once in a session (create_collection
        # raises on a duplicate name).
        chroma_collection = self.chroma_client.get_or_create_collection(file_id)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            embed_model=self.embed_model,
        )

        self.indexes[file_id] = {
            "index": index,
            # The first document's metadata stands in for the whole file
            # (EnhancedCSVReader presumably attaches file-level metadata to
            # every document -- TODO confirm against its implementation).
            "metadata": documents[0].metadata if documents else {},
        }

        return index

    def index_directory(self, directory_path: str) -> "Dict[str, VectorStoreIndex]":
        """Index every ``*.csv`` file directly inside ``directory_path``.

        Args:
            directory_path: Directory to scan (non-recursive).

        Returns:
            Mapping of file stem -> ``VectorStoreIndex`` for each CSV found.
        """
        indexed_files: Dict[str, "VectorStoreIndex"] = {}

        # Sort for a deterministic indexing order; os.listdir order is
        # platform-dependent.
        csv_files = sorted(
            f for f in os.listdir(directory_path) if f.lower().endswith(".csv")
        )

        for csv_file in csv_files:
            file_path = os.path.join(directory_path, csv_file)
            indexed_files[Path(file_path).stem] = self.create_index(file_path)

        return indexed_files

    def find_relevant_csvs(self, query: str, top_k: int = 3) -> List[str]:
        """Rank indexed CSV files by embedding similarity to ``query``.

        Builds a short natural-language description of each indexed file
        from its stored metadata, embeds it, and compares it against the
        query embedding via cosine similarity.

        Args:
            query: Free-text user query.
            top_k: Maximum number of file ids to return.

        Returns:
            File ids (stems) of the ``top_k`` most relevant CSVs, best
            first; empty list when nothing has been indexed yet.
        """
        if not self.indexes:
            return []

        query_embedding = self.embed_model.get_text_embedding(query)

        similarities: Dict[str, float] = {}
        for file_id, index_info in self.indexes.items():
            metadata = index_info["metadata"]

            # Use .get() defaults so a file indexed with sparse metadata
            # (e.g. an empty CSV, stored as {}) cannot raise KeyError here.
            columns = metadata.get("columns", [])
            csv_description = (
                f"CSV file {metadata.get('filename', file_id)} "
                f"with columns: {', '.join(columns)}. "
                f"Contains {metadata.get('row_count', 0)} rows. "
                "Sample data: "
            )
            for col, samples in metadata.get("samples", {}).items():
                if samples:
                    # Only the first two sample values per column, to keep the
                    # description (and its embedding) short.
                    csv_description += (
                        f"{col}: {', '.join(str(s) for s in samples[:2])}; "
                    )

            csv_embedding = self.embed_model.get_text_embedding(csv_description)

            similarities[file_id] = self._cosine_similarity(
                query_embedding, csv_embedding
            )

        sorted_files = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        return [file_id for file_id, _ in sorted_files[:top_k]]

    def _cosine_similarity(self, vec1, vec2) -> float:
        """Cosine similarity of two equal-length numeric vectors.

        Returns 0 when either vector has zero magnitude, avoiding a
        ZeroDivisionError.
        """
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm_a = sum(a * a for a in vec1) ** 0.5
        norm_b = sum(b * b for b in vec2) ** 0.5
        denom = norm_a * norm_b
        return dot_product / denom if denom != 0 else 0