from typing import Union
import os

from dotenv import load_dotenv

# Populate the environment from .env before the remaining imports run and
# before any os.getenv()/os.environ lookup below (original ordering preserved).
load_dotenv()

from langchain_qdrant import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain_core.prompts import format_document, PromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.http import models as qdrant_models
from supabase import create_client, Client

supabase_client: Client = create_client(
    os.getenv("SUPABASE_URL"),
    os.getenv("SUPABASE_KEY"),
)

# Each entry describes one Qdrant collection: the embedding model used for
# queries, the chunk size, and the collection name on the cluster.
COLLECTIONS = [
    {
        "embedding_model": OpenAIEmbeddings(model="text-embedding-3-large", dimensions=1024),
        "chunk_size": 1024,
        "name": "1024-openaiLarge-1024",
    },
]

# Numerator used to derive a default ``k`` when none is given:
# k = int(_DEFAULT_K_BUDGET / chunk_size).
# NOTE(review): presumably an overall context-size budget -- confirm the unit.
_DEFAULT_K_BUDGET = 15000

# Values of ``Retriever.search_type`` that ``Retriever._retrieve`` can handle.
_SEARCH_TYPES = ("similarity", "mmr")

# Template applied to every retrieved document by ``_combine_documents``.
_DOCUMENT_PROMPT = PromptTemplate.from_template(
    template="UUID: {supabase_id}\nTITOLO: {title}\nTIPO: {type}\nCONTENUTO: {page_content}"
)


def _combine_documents(
    docs: list,
    document_separator: str = "\n\n----------\n\n",
) -> str:
    """Render scored documents into a single string.

    Args:
        docs: Iterable of ``(document, score)`` pairs; the score is ignored.
            Each document is formatted with ``format_document`` using the
            ``supabase_id``, ``title``, ``type`` and ``page_content`` fields
            referenced by the template.
        document_separator: Text inserted between consecutive documents.

    Returns:
        The formatted documents joined by ``document_separator`` (an empty
        string when ``docs`` is empty).
    """
    return document_separator.join(
        format_document(doc, _DOCUMENT_PROMPT) for doc, _ in docs
    )


class Retriever():
    """Fetch scored documents from one of the Qdrant collections in ``COLLECTIONS``."""

    # NOTE(review): the helper takes no ``self`` and the original layout does
    # not show whether it belonged to the class or to the module, so it is
    # reachable both as ``_combine_documents(...)`` and as
    # ``Retriever._combine_documents(...)``.
    _combine_documents = staticmethod(_combine_documents)

    def __init__(
        self,
        collection_index: int = 0,
        use_doctrines: bool = True,
        search_type: str = "similarity",
        k: Union[int, None] = None,
        similarity_threshold: float = 0.0,
    ):
        """Store the retrieval settings.

        Args:
            collection_index: Index into ``COLLECTIONS`` selecting the collection.
            use_doctrines: When ``False``, results are restricted to points
                whose ``metadata.type`` equals ``'Prassi'``.
            search_type: ``"similarity"`` or ``"mmr"``.
            k: Number of documents to request; ``None`` derives it from the
                collection's chunk size at query time.
            similarity_threshold: Passed to the vector store as ``score_threshold``.
        """
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        self.search_type = search_type
        self.k = k
        self.similarity_threshold = similarity_threshold

    def _get_vectorstore(
        self,
    ) -> Qdrant:
        """Build a ``Qdrant`` vector store bound to the selected collection.

        Connection settings are read from the ``QDRANT_CLUSTER_URL`` and
        ``QDRANT_API_KEY`` environment variables.
        """
        collection = COLLECTIONS[self.collection_index]
        # NOTE(review): a new client is created on every call and never
        # closed explicitly; consider reusing one if this is called often.
        client = QdrantClient(
            url=os.environ.get("QDRANT_CLUSTER_URL"),
            api_key=os.environ.get("QDRANT_API_KEY"),
            prefer_grpc=True,
        )
        return Qdrant(
            client=client,
            embeddings=collection["embedding_model"],
            collection_name=collection["name"],
        )

    def _retrieve(
        self,
        query: str,
    ) -> list:
        """Run the configured search for ``query``.

        Args:
            query: Text to search for.

        Returns:
            Whatever the vector store search returns; ``_combine_documents``
            consumes it as ``(document, score)`` pairs.

        Raises:
            ValueError: If ``self.search_type`` is not ``"similarity"`` or ``"mmr"``.
        """
        # Fail fast with a clear message instead of hitting an unbound local
        # at the end of the method (and before opening a Qdrant connection).
        if self.search_type not in _SEARCH_TYPES:
            raise ValueError(
                f"Unsupported search_type {self.search_type!r}; "
                f"expected one of {_SEARCH_TYPES}"
            )

        # Resolve k per call without overwriting self.k, so a later change of
        # collection_index is not stuck with a value derived from the old one.
        if self.k is None:
            k = int(_DEFAULT_K_BUDGET / COLLECTIONS[self.collection_index]["chunk_size"])
        else:
            k = self.k

        vectorstore = self._get_vectorstore()

        # Without doctrines, keep only points whose metadata.type is 'Prassi'.
        search_filter = None
        if not self.use_doctrines:
            search_filter = qdrant_models.Filter(
                must=[
                    qdrant_models.FieldCondition(
                        key="metadata.type",
                        match=qdrant_models.MatchValue(value="Prassi"),
                    )
                ]
            )

        if self.search_type == "similarity":
            return vectorstore.similarity_search_with_score(
                query,
                k=k,
                filter=search_filter,
                score_threshold=self.similarity_threshold,
            )

        # Only "mmr" remains after the validation above.
        return vectorstore.max_marginal_relevance_search_with_score_by_vector(
            vectorstore._embed_query(query),
            k=k,
            filter=search_filter,
            score_threshold=self.similarity_threshold,
        )