# NOTE(review): removed non-Python viewer residue ("Spaces:" / "Running" / "Running")
# that leaked into the file during export; it was not part of the program.
from typing import Union
import os

from dotenv import load_dotenv

# Load .env before constructing any clients below, so SUPABASE_* / QDRANT_*
# variables are visible via os.getenv / os.environ.
load_dotenv()

from langchain_qdrant import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain_core.prompts import format_document, PromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.http import models as qdrant_models
from supabase import create_client, Client

# Module-level Supabase client; requires SUPABASE_URL and SUPABASE_KEY to be
# set (create_client raises if they are missing/None).
supabase_client: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_KEY"))
# Registry of available Qdrant collections.  Each entry pairs a collection
# name with the embedding model used to index it and the chunk size (in
# characters) used when the documents were split.  Indexed by position
# (see Retriever.collection_index).
COLLECTIONS = [
    {
        "embedding_model": OpenAIEmbeddings(model="text-embedding-3-large", dimensions=1024),
        "chunk_size": 1024,
        "name": "1024-openaiLarge-1024",
    },
]
class Retriever():
    """Retrieve document chunks from a Qdrant collection via vector search.

    The collection and its embedding model are selected from the
    module-level COLLECTIONS registry by index.
    """

    def __init__(
        self,
        collection_index: int = 0,
        use_doctrines: bool = True,
        search_type: str = "similarity",
        k: Union[int, None] = None,
        similarity_threshold: float = 0.0,
    ):
        """
        Args:
            collection_index: Index into COLLECTIONS selecting the
                collection and its embedding model.
            use_doctrines: When False, results are restricted to documents
                whose metadata.type == 'Prassi' (doctrine docs excluded).
            search_type: Either "similarity" or "mmr".
            k: Number of chunks to retrieve; when None it is derived from
                the collection's chunk size on first retrieval
                (~15000 characters of total context).
            similarity_threshold: Minimum score for returned documents.
        """
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        self.search_type = search_type
        self.k = k
        self.similarity_threshold = similarity_threshold

    def _get_vectorstore(self) -> Qdrant:
        """Build a Qdrant vector store bound to the configured collection.

        Reads QDRANT_CLUSTER_URL and QDRANT_API_KEY from the environment.
        """
        client = QdrantClient(
            url=os.environ.get("QDRANT_CLUSTER_URL"),
            api_key=os.environ.get("QDRANT_API_KEY"),
            prefer_grpc=True,
        )
        store = Qdrant(
            client=client,
            embeddings=COLLECTIONS[self.collection_index]["embedding_model"],
            collection_name=COLLECTIONS[self.collection_index]["name"],
        )
        return store

    def _retrieve(self, query: str) -> list:
        """Run a vector search for ``query``.

        Returns:
            A list of (Document, score) pairs from the vector store.

        Raises:
            ValueError: If self.search_type is neither "similarity"
                nor "mmr" (previously this fell through and raised
                UnboundLocalError on ``docs``).
        """
        # Derive k so that roughly 15000 characters of context come back.
        # NOTE(review): this intentionally caches the derived value on
        # self.k (original behavior); if collection_index changes after
        # the first call, the cached k goes stale — consider a local.
        if self.k is None:
            self.k = int(15000 / COLLECTIONS[self.collection_index]["chunk_size"])
        vectorstore = self._get_vectorstore()

        # When doctrines are excluded, keep only 'Prassi'-typed documents.
        # (Renamed from `filter`, which shadowed the builtin.)
        if not self.use_doctrines:
            search_filter = qdrant_models.Filter(
                must=[
                    qdrant_models.FieldCondition(
                        key="metadata.type",
                        match=qdrant_models.MatchValue(value='Prassi'),
                    )
                ]
            )
        else:
            search_filter = None

        if self.search_type == "similarity":
            return vectorstore.similarity_search_with_score(
                query,
                k=self.k,
                filter=search_filter,
                score_threshold=self.similarity_threshold,
            )
        if self.search_type == "mmr":
            return vectorstore.max_marginal_relevance_search_with_score_by_vector(
                vectorstore._embed_query(query),
                k=self.k,
                filter=search_filter,
                score_threshold=self.similarity_threshold,
            )
        raise ValueError(f"Unknown search_type: {self.search_type!r}")
def _combine_documents(
    docs: list,
    document_separator: str = "\n\n----------\n\n",
) -> str:
    """Format (Document, score) pairs into a single prompt-ready string.

    Each document is rendered with its Supabase UUID, title, type and page
    content, then the rendered strings are joined with
    ``document_separator``.  Scores are discarded.

    Args:
        docs: List of (Document, score) pairs, as returned by
            Retriever._retrieve.
        document_separator: String placed between rendered documents.

    Returns:
        The concatenated, formatted document text.
    """
    # Runtime template text kept byte-identical: the field labels are part
    # of the downstream prompt contract.
    document_prompt = PromptTemplate.from_template(
        template="UUID: {supabase_id}\nTITOLO: {title}\nTIPO: {type}\nCONTENUTO: {page_content}"
    )
    doc_strings = [format_document(doc, document_prompt) for doc, _ in docs]
    return document_separator.join(doc_strings)