"""Context compression helpers built on LangChain's contextual compression retrievers."""

import os
import asyncio
from typing import Optional

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter

from .retriever import SearchAPIRetriever, SectionRetriever
from ..vector_store import VectorStoreWrapper
from ..utils.costs import estimate_embedding_cost
from ..memory.embeddings import OPENAI_EMBEDDING_MODEL


class VectorstoreCompressor:
    """Retrieves relevant context directly from an existing vector store."""

    def __init__(self, vector_store: VectorStoreWrapper, max_results: int = 7,
                 filter: Optional[dict] = None, **kwargs):
        self.vector_store = vector_store
        self.max_results = max_results
        self.filter = filter
        self.kwargs = kwargs

    def __pretty_print_docs(self, docs):
        return "\n".join(
            f"Source: {d.metadata.get('source')}\n"
            f"Title: {d.metadata.get('title')}\n"
            f"Content: {d.page_content}\n"
            for d in docs
        )

    async def async_get_context(self, query, max_results=5):
        """Get relevant context from the vector store."""
        results = await self.vector_store.asimilarity_search(
            query=query, k=max_results, filter=self.filter
        )
        return self.__pretty_print_docs(results)


class ContextCompressor:
    """Splits scraped pages into chunks and keeps only those relevant to the query."""

    def __init__(self, documents, embeddings, max_results=5, **kwargs):
        self.max_results = max_results
        self.documents = documents
        self.kwargs = kwargs
        self.embeddings = embeddings
        # os.environ.get returns a string when the variable is set, so cast to float.
        self.similarity_threshold = float(os.environ.get("SIMILARITY_THRESHOLD", 0.35))

    def __get_contextual_retriever(self):
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        relevance_filter = EmbeddingsFilter(
            embeddings=self.embeddings,
            similarity_threshold=self.similarity_threshold,
        )
        # Split pages into chunks, then drop chunks below the similarity threshold.
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, relevance_filter]
        )
        base_retriever = SearchAPIRetriever(pages=self.documents)
        return ContextualCompressionRetriever(
            base_compressor=pipeline_compressor,
            base_retriever=base_retriever,
        )

    def __pretty_print_docs(self, docs, top_n):
        return "\n".join(
            f"Source: {d.metadata.get('source')}\n"
            f"Title: {d.metadata.get('title')}\n"
            f"Content: {d.page_content}\n"
            for d in docs[:top_n]
        )

    async def async_get_context(self, query, max_results=5, cost_callback=None):
        compressed_docs = self.__get_contextual_retriever()
        if cost_callback:
            cost_callback(estimate_embedding_cost(model=OPENAI_EMBEDDING_MODEL,
                                                  docs=self.documents))
        # invoke() is synchronous; run it in a worker thread so the event loop isn't blocked.
        relevant_docs = await asyncio.to_thread(compressed_docs.invoke, query)
        return self.__pretty_print_docs(relevant_docs, max_results)


class WrittenContentCompressor:
    """Like ContextCompressor, but retrieves from previously written report sections."""

    def __init__(self, documents, embeddings, similarity_threshold, **kwargs):
        self.documents = documents
        self.kwargs = kwargs
        self.embeddings = embeddings
        self.similarity_threshold = similarity_threshold

    def __get_contextual_retriever(self):
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        relevance_filter = EmbeddingsFilter(
            embeddings=self.embeddings,
            similarity_threshold=self.similarity_threshold,
        )
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, relevance_filter]
        )
        base_retriever = SectionRetriever(sections=self.documents)
        return ContextualCompressionRetriever(
            base_compressor=pipeline_compressor,
            base_retriever=base_retriever,
        )

    def __pretty_docs_list(self, docs, top_n):
        return [
            f"Title: {d.metadata.get('section_title')}\nContent: {d.page_content}\n"
            for d in docs[:top_n]
        ]

    async def async_get_context(self, query, max_results=5, cost_callback=None):
        compressed_docs = self.__get_contextual_retriever()
        if cost_callback:
            cost_callback(estimate_embedding_cost(model=OPENAI_EMBEDDING_MODEL,
                                                  docs=self.documents))
        relevant_docs = await asyncio.to_thread(compressed_docs.invoke, query)
        return self.__pretty_docs_list(relevant_docs, max_results)
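

# --- Usage sketch (illustrative only) ---
# A minimal sketch of driving ContextCompressor end to end. The page dict keys
# ("raw_content", "url", "title") are an assumption about what the local
# SearchAPIRetriever expects, inferred from the metadata fields printed above;
# OpenAIEmbeddings is likewise just one LangChain Embeddings implementation that fits.
if __name__ == "__main__":
    from langchain_openai import OpenAIEmbeddings  # assumes langchain-openai is installed

    pages = [
        {
            "raw_content": "EmbeddingsFilter discards chunks whose similarity to the "
                           "query falls below the configured threshold.",
            "url": "https://example.com/embeddings-filter",
            "title": "Filtering retrieved context",
        },
    ]
    compressor = ContextCompressor(documents=pages, embeddings=OpenAIEmbeddings())
    print(asyncio.run(compressor.async_get_context("How are irrelevant chunks removed?")))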