import os
import pickle
from typing import List, Dict, Any, Optional

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder


class VectorStore:
    def __init__(self,
                 embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.embedding_dir = embedding_dir
        self.index = None
        self.chunk_ids = []
        self.chunks = {}
        self.model = SentenceTransformer(model_name)
        self.reranker = CrossEncoder(reranker_name)
        self.load_or_create_index()

    def load_or_create_index(self) -> None:
        """Load a persisted FAISS index if present, otherwise build one from embeddings."""
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
        if os.path.exists(index_path):
            with open(index_path, 'rb') as f:
                data = pickle.load(f)
            self.index = data['index']
            self.chunk_ids = data['chunk_ids']
            self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
        else:
            embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
            if os.path.exists(embeddings_path):
                self.create_index()
            else:
                print("No embeddings found. Please run the chunker first.")

    def create_index(self) -> None:
        """Create a FAISS index from precomputed embeddings and persist it."""
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)

        chunk_ids = list(embedding_map.keys())
        embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}

        # Exact L2 index; fine for small-to-medium corpora.
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype(np.float32))

        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks

        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                'index': index,
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)
        print(f"Created index with {len(chunk_ids)} chunks")

    def search(self, query: str, k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Vector search with optional category filtering and cross-encoder reranking."""
        if self.index is None:
            print("No index available. Please create an index first.")
            return []

        query_embedding = self.model.encode([query])[0]
        # Over-fetch (2x k) so filtering and reranking still leave enough candidates.
        D, I = self.index.search(np.array([query_embedding]).astype(np.float32),
                                 min(k * 2, len(self.chunk_ids)))

        results = []
        for i, idx in enumerate(I[0]):
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            if filter_categories and not any(cat in chunk.get('categories', [])
                                             for cat in filter_categories):
                continue
            results.append({
                'chunk_id': chunk_id,
                'score': float(D[0][i]),  # L2 distance: lower means closer
                'chunk': chunk
            })

        if rerank and results:
            pairs = [(query, result['chunk']['content']) for result in results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                results[i]['rerank_score'] = float(score)
            results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)

        return results[:k]

    def hybrid_search(self, query: str, k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Combine vector search with simple keyword matching, then rerank the union."""
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)

        # Keyword pass: raw term-frequency count over title + content.
        keywords = query.lower().split()
        keyword_scores = {}
        for chunk_id, chunk in self.chunks.items():
            if filter_categories and not any(cat in chunk.get('categories', [])
                                             for cat in filter_categories):
                continue
            content = (chunk['title'] + " " + chunk['content']).lower()
            keyword_scores[chunk_id] = sum(content.count(keyword) for keyword in keywords)

        keyword_results = sorted(
            [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
             for chunk_id, score in keyword_scores.items() if score > 0],
            key=lambda x: x['score'],
            reverse=True
        )[:k]

        # Merge, preferring vector hits and deduplicating by chunk ID.
        seen_ids = set()
        combined_results = []
        for result in vector_results:
            combined_results.append(result)
            seen_ids.add(result['chunk_id'])
        for result in keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])
        combined_results = combined_results[:k]

        if combined_results:
            pairs = [(query, result['chunk']['content']) for result in combined_results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)
            combined_results = sorted(combined_results,
                                      key=lambda x: x['rerank_score'], reverse=True)

        return combined_results
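

# A minimal usage sketch. It assumes embeddings.pkl has already been produced
# by the chunking step; the query strings and the "security" category below
# are hypothetical examples, not values taken from this codebase.
if __name__ == "__main__":
    store = VectorStore(embedding_dir="data/embeddings")

    # Pure vector search with cross-encoder reranking.
    for hit in store.search("how do I rotate an API key?", k=3):
        print(hit['chunk_id'], hit.get('rerank_score', hit['score']))

    # Hybrid search: union of vector and keyword hits, reranked together.
    for hit in store.hybrid_search("API key rotation", k=3, filter_categories=["security"]):
        print(hit['chunk_id'], hit.get('rerank_score', hit['score']))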