# Shreyas094's picture
# Upload 528 files
# 372531f verified
import os
import asyncio
from typing import Optional
from .retriever import SearchAPIRetriever, SectionRetriever
from langchain.retrievers import (
ContextualCompressionRetriever,
)
from langchain.retrievers.document_compressors import (
DocumentCompressorPipeline,
EmbeddingsFilter,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from ..vector_store import VectorStoreWrapper
from ..utils.costs import estimate_embedding_cost
from ..memory.embeddings import OPENAI_EMBEDDING_MODEL
class VectorstoreCompressor:
    """Fetch query-relevant documents from a vector store and render them as text.

    Args:
        vector_store: Store that is queried through ``asimilarity_search``.
        max_results: Kept on the instance as ``max_results``. Note that
            ``async_get_context`` takes its own ``max_results`` argument and
            does not read this attribute.
        filter: Optional filter forwarded to every similarity search.
        **kwargs: Extra options, stored unchanged on ``self.kwargs``.
    """

    def __init__(
        self,
        vector_store: VectorStoreWrapper,
        max_results: int = 7,
        filter: Optional[dict] = None,
        **kwargs,
    ):
        self.kwargs = kwargs
        self.filter = filter
        self.max_results = max_results
        self.vector_store = vector_store

    def __pretty_print_docs(self, docs):
        """Render every document as Source/Title/Content lines joined by newlines."""
        rendered = []
        for doc in docs:
            rendered.append(
                f"Source: {doc.metadata.get('source')}\n"
                f"Title: {doc.metadata.get('title')}\n"
                f"Content: {doc.page_content}\n"
            )
        return "\n".join(rendered)

    async def async_get_context(self, query, max_results=5):
        """Get relevant context from vector store.

        Args:
            query: Text passed to the store's similarity search.
            max_results: Passed to the search as ``k``.

        Returns:
            A single string with one Source/Title/Content entry per result.
        """
        matches = await self.vector_store.asimilarity_search(
            query=query,
            k=max_results,
            filter=self.filter,
        )
        return self.__pretty_print_docs(matches)
class ContextCompressor:
    """Pick the parts of a set of pages that are most similar to a query.

    The pages are served by a ``SearchAPIRetriever`` and post-processed by a
    pipeline made of a ``RecursiveCharacterTextSplitter`` followed by an
    ``EmbeddingsFilter`` configured with ``similarity_threshold``.

    Args:
        documents: Pages handed to ``SearchAPIRetriever(pages=...)``.
        embeddings: Embeddings object given to the ``EmbeddingsFilter``.
        max_results: Kept on the instance as ``max_results``. Note that
            ``async_get_context`` takes its own ``max_results`` argument and
            does not read this attribute.
        **kwargs: Extra options, stored unchanged on ``self.kwargs``.

    Raises:
        ValueError: If the ``SIMILARITY_THRESHOLD`` environment variable is
            set to something that cannot be parsed as a number.
    """

    def __init__(self, documents, embeddings, max_results=5, **kwargs):
        self.max_results = max_results
        self.documents = documents
        self.kwargs = kwargs
        self.embeddings = embeddings
        # Environment values are always strings, while the fallback is a
        # float; convert here so the attribute has one type in both cases.
        raw_threshold = os.environ.get("SIMILARITY_THRESHOLD", 0.35)
        try:
            self.similarity_threshold = float(raw_threshold)
        except ValueError as err:
            raise ValueError(
                f"SIMILARITY_THRESHOLD must be a number, got {raw_threshold!r}"
            ) from err

    def __get_contextual_retriever(self):
        """Build the retriever: page source + split-then-filter compressor."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        relevance_filter = EmbeddingsFilter(
            embeddings=self.embeddings,
            similarity_threshold=self.similarity_threshold,
        )
        # Order matters: documents are split first, then the chunks are filtered.
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, relevance_filter]
        )
        base_retriever = SearchAPIRetriever(pages=self.documents)
        return ContextualCompressionRetriever(
            base_compressor=pipeline_compressor,
            base_retriever=base_retriever,
        )

    def __pretty_print_docs(self, docs, top_n):
        """Render the first ``top_n`` documents as Source/Title/Content text."""
        entries = []
        for index, doc in enumerate(docs):
            if index >= top_n:
                break
            entries.append(
                f"Source: {doc.metadata.get('source')}\n"
                f"Title: {doc.metadata.get('title')}\n"
                f"Content: {doc.page_content}\n"
            )
        return "\n".join(entries)

    async def async_get_context(self, query, max_results=5, cost_callback=None):
        """Return the text of the documents the retriever yields for ``query``.

        Args:
            query: Query passed to the retriever's ``invoke``.
            max_results: Maximum number of documents included in the output.
            cost_callback: Optional callable; when truthy it is called once,
                before retrieval, with the result of
                ``estimate_embedding_cost`` for ``self.documents``.

        Returns:
            A single string with one Source/Title/Content entry per document.
        """
        compressed_docs = self.__get_contextual_retriever()
        if cost_callback:
            cost_callback(
                estimate_embedding_cost(
                    model=OPENAI_EMBEDDING_MODEL, docs=self.documents
                )
            )
        # invoke() is synchronous; run it in a worker thread so the event
        # loop is not blocked while it executes.
        relevant_docs = await asyncio.to_thread(compressed_docs.invoke, query)
        return self.__pretty_print_docs(relevant_docs, max_results)
class WrittenContentCompressor:
    """Pick previously written sections that are most similar to a query.

    Sections are served by a ``SectionRetriever`` and post-processed by a
    pipeline made of a ``RecursiveCharacterTextSplitter`` followed by an
    ``EmbeddingsFilter`` configured with ``similarity_threshold``.

    Args:
        documents: Sections handed to ``SectionRetriever(sections=...)``.
        embeddings: Embeddings object given to the ``EmbeddingsFilter``.
        similarity_threshold: Threshold given to the ``EmbeddingsFilter``.
        **kwargs: Extra options, stored unchanged on ``self.kwargs``.
    """

    def __init__(self, documents, embeddings, similarity_threshold, **kwargs):
        self.similarity_threshold = similarity_threshold
        self.embeddings = embeddings
        self.kwargs = kwargs
        self.documents = documents

    def __get_contextual_retriever(self):
        """Build the retriever: section source + split-then-filter compressor."""
        chunker = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        similarity_gate = EmbeddingsFilter(
            embeddings=self.embeddings,
            similarity_threshold=self.similarity_threshold,
        )
        # Order matters: sections are split first, then the chunks are filtered.
        compressor = DocumentCompressorPipeline(
            transformers=[chunker, similarity_gate]
        )
        section_source = SectionRetriever(sections=self.documents)
        return ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=section_source,
        )

    def __pretty_docs_list(self, docs, top_n):
        """Return one "Title/Content" string for each of the first ``top_n`` docs."""
        formatted = []
        for position, doc in enumerate(docs):
            if position >= top_n:
                break
            formatted.append(
                f"Title: {doc.metadata.get('section_title')}\nContent: {doc.page_content}\n"
            )
        return formatted

    async def async_get_context(self, query, max_results=5, cost_callback=None):
        """Return formatted entries for the documents the retriever yields.

        Args:
            query: Query passed to the retriever's ``invoke``.
            max_results: Maximum number of entries in the returned list.
            cost_callback: Optional callable; when truthy it is called once,
                before retrieval, with the result of
                ``estimate_embedding_cost`` for ``self.documents``.

        Returns:
            A list of "Title: ...\\nContent: ...\\n" strings.
        """
        retriever = self.__get_contextual_retriever()
        if cost_callback:
            estimated = estimate_embedding_cost(
                model=OPENAI_EMBEDDING_MODEL, docs=self.documents
            )
            cost_callback(estimated)
        # invoke() is synchronous; run it in a worker thread so the event
        # loop is not blocked while it executes.
        hits = await asyncio.to_thread(retriever.invoke, query)
        return self.__pretty_docs_list(hits, max_results)