# indexes/index_manager_bk.py
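"""Vector-index management for CSV files.

Builds one Chroma-backed VectorStoreIndex per CSV file (via llama_index)
and ranks the indexed files against a user query by cosine similarity
between the query embedding and an embedding of each file's metadata
summary.
"""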
import os
from pathlib import Path
from typing import Dict, List

import chromadb
# NOTE: these imports follow the pre-0.10 llama_index package layout;
# releases >= 0.10 move these under llama_index.core and separate
# integration packages
from llama_index import VectorStoreIndex, StorageContext
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.vector_stores import ChromaVectorStore

from indexes.csv_index_builder import EnhancedCSVReader


class CSVIndexManager:
"""Manages creation and retrieval of indexes for CSV files."""

    def __init__(self, embedding_model_name: str = "all-MiniLM-L6-v2"):
        self.csv_reader = EnhancedCSVReader()
        self.embed_model = HuggingFaceEmbedding(model_name=embedding_model_name)
        self.chroma_client = chromadb.Client()
        # Registry mapping file_id -> {"index": VectorStoreIndex, "metadata": dict}
        self.indexes = {}

    def create_index(self, file_path: str) -> VectorStoreIndex:
        """Create a vector index for a single CSV file."""
# Extract filename as identifier
file_id = Path(file_path).stem
# Load documents with metadata
documents = self.csv_reader.load_data(file_path)
        # Create (or reuse) the Chroma collection for this CSV, so that
        # re-indexing the same file does not fail with a duplicate-collection error
        chroma_collection = self.chroma_client.get_or_create_collection(file_id)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# Create vector index with our embedding model
index = VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
embed_model=self.embed_model
)
# Store in our registry
self.indexes[file_id] = {
"index": index,
"metadata": documents[0].metadata if documents else {}
}
return index

    def index_directory(self, directory_path: str) -> Dict[str, VectorStoreIndex]:
        """Index all CSV files in a directory."""
indexed_files = {}
# Get all CSV files in directory
csv_files = [f for f in os.listdir(directory_path)
if f.lower().endswith('.csv')]
# Create index for each CSV file
for csv_file in csv_files:
file_path = os.path.join(directory_path, csv_file)
file_id = Path(file_path).stem
index = self.create_index(file_path)
indexed_files[file_id] = index
return indexed_files

    def find_relevant_csvs(self, query: str, top_k: int = 3) -> List[str]:
        """Return the file_ids of the indexed CSV files most relevant to a query."""
if not self.indexes:
return []
        # Embed the query once, then compare it against each CSV's description
        query_embedding = self.embed_model.get_text_embedding(query)
# Calculate similarity with each CSV's metadata
similarities = {}
        for file_id, index_info in self.indexes.items():
            metadata = index_info["metadata"]
            # Skip files that yielded no documents (and therefore no metadata),
            # otherwise the description below would raise a KeyError
            if not metadata:
                continue
# Create a rich description of the CSV
csv_description = f"CSV file {metadata['filename']} with columns: {', '.join(metadata['columns'])}. "
csv_description += f"Contains {metadata['row_count']} rows. "
csv_description += "Sample data: "
for col, samples in metadata['samples'].items():
if samples:
csv_description += f"{col}: {', '.join(str(s) for s in samples[:2])}; "
# Get embedding for this description
csv_embedding = self.embed_model.get_text_embedding(csv_description)
# Calculate cosine similarity
similarity = self._cosine_similarity(query_embedding, csv_embedding)
similarities[file_id] = similarity
# Sort by similarity and return top_k
sorted_files = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
return [file_id for file_id, _ in sorted_files[:top_k]]

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate the cosine similarity between two vectors."""
dot_product = sum(a * b for a, b in zip(vec1, vec2))
norm_a = sum(a * a for a in vec1) ** 0.5
norm_b = sum(b * b for b in vec2) ** 0.5
return dot_product / (norm_a * norm_b) if norm_a * norm_b != 0 else 0
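

# A minimal usage sketch (not part of the original module). The "data/"
# directory and the query string are illustrative placeholders; running this
# also downloads the all-MiniLM-L6-v2 embedding model on first use.
if __name__ == "__main__":
    manager = CSVIndexManager()

    # Build one index per CSV file found in the (hypothetical) data directory.
    indexes = manager.index_directory("data")
    print(f"Indexed {len(indexes)} CSV file(s): {list(indexes)}")

    # Rank the indexed files against a natural-language question.
    relevant = manager.find_relevant_csvs("total sales by region", top_k=2)
    print(f"Most relevant CSVs: {relevant}")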