Spaces:
Running
Running
File size: 3,197 Bytes
d46cc41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
from typing import Union
import os
from dotenv import load_dotenv
load_dotenv()
from langchain_qdrant import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain_core.prompts import format_document, PromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.http import models as qdrant_models
from supabase import create_client, Client
# Module-level Supabase client built from the SUPABASE_URL and SUPABASE_KEY
# environment variables (os.getenv yields None when a variable is unset).
supabase_client: Client = create_client(
    os.getenv("SUPABASE_URL"),
    os.getenv("SUPABASE_KEY"),
)

# Vector collections a Retriever can target, selected by its collection_index.
# Each entry holds the embedding model handed to the Qdrant store for that
# collection, the chunk size (used by Retriever to derive a default k) and the
# Qdrant collection name.
COLLECTIONS = [
    dict(
        embedding_model=OpenAIEmbeddings(
            model="text-embedding-3-large",
            dimensions=1024,
        ),
        chunk_size=1024,
        name="1024-openaiLarge-1024",
    ),
]
class Retriever:
    """Similarity / MMR retriever over one of the Qdrant collections in ``COLLECTIONS``.

    Args:
        collection_index: Index into ``COLLECTIONS`` selecting the embedding
            model, chunk size and Qdrant collection name.
        use_doctrines: When False, results are restricted to points whose
            ``metadata.type`` equals ``'Prassi'``; when True, no filter is applied.
        search_type: ``"similarity"`` or ``"mmr"`` (maximal marginal relevance).
        k: Number of results to request. ``None`` derives it from the selected
            collection's chunk size at query time.
        similarity_threshold: Forwarded to the vector store as ``score_threshold``.
    """

    # Budget divided by the collection's chunk size to derive k when k is None.
    # NOTE(review): presumably a character budget for the prompt context — confirm.
    _CONTEXT_BUDGET = 15000

    def __init__(
        self,
        collection_index: int = 0,
        use_doctrines: bool = True,
        search_type: str = "similarity",
        k: Union[int, None] = None,
        similarity_threshold: float = 0.0,
    ):
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        self.search_type = search_type
        self.k = k
        self.similarity_threshold = similarity_threshold

    def _get_vectorstore(self) -> Qdrant:
        """Return a Qdrant vector store bound to the selected collection.

        Connection settings are read from the QDRANT_CLUSTER_URL and
        QDRANT_API_KEY environment variables; a new QdrantClient is built on
        every call.
        """
        # NOTE(review): a fresh client is created per call and never closed;
        # consider reusing one if _retrieve is called frequently.
        collection = COLLECTIONS[self.collection_index]
        client = QdrantClient(
            url=os.environ.get("QDRANT_CLUSTER_URL"),
            api_key=os.environ.get("QDRANT_API_KEY"),
            prefer_grpc=True,
        )
        return Qdrant(
            client=client,
            embeddings=collection["embedding_model"],
            collection_name=collection["name"],
        )

    def _retrieve(self, query: str) -> list:
        """Search the selected collection for documents relevant to ``query``.

        Returns:
            A list of ``(document, score)`` pairs from the vector store.

        Raises:
            ValueError: If ``search_type`` is neither ``"similarity"`` nor ``"mmr"``.
        """
        # Reject unknown search types up front instead of failing later on an
        # unbound result variable (and before opening a client connection).
        if self.search_type not in ("similarity", "mmr"):
            raise ValueError(
                f"Unsupported search_type {self.search_type!r}; "
                "expected 'similarity' or 'mmr'."
            )

        # Derive k locally rather than writing it back to self.k, so that
        # k=None keeps meaning "auto" and tracks the current collection_index.
        k = self.k
        if k is None:
            chunk_size = COLLECTIONS[self.collection_index]["chunk_size"]
            # Clamp to 1 so a chunk size above the budget never yields k=0.
            k = max(1, self._CONTEXT_BUDGET // chunk_size)

        # Without doctrines, keep only documents tagged as 'Prassi'.
        search_filter = None
        if not self.use_doctrines:
            search_filter = qdrant_models.Filter(
                must=[
                    qdrant_models.FieldCondition(
                        key="metadata.type",
                        match=qdrant_models.MatchValue(value="Prassi"),
                    )
                ]
            )

        vectorstore = self._get_vectorstore()
        if self.search_type == "similarity":
            return vectorstore.similarity_search_with_score(
                query,
                k=k,
                filter=search_filter,
                score_threshold=self.similarity_threshold,
            )
        # "mmr": the by-vector variant needs the query embedded first.
        return vectorstore.max_marginal_relevance_search_with_score_by_vector(
            vectorstore._embed_query(query),
            k=k,
            filter=search_filter,
            score_threshold=self.similarity_threshold,
        )
def _combine_documents(
    docs: list,
    document_separator: str = "\n\n----------\n\n",
    document_prompt: Union[PromptTemplate, None] = None,
) -> str:
    """Render retrieved documents into a single context string.

    Args:
        docs: ``(document, score)`` pairs (presumably the output of
            ``Retriever._retrieve``); the scores are ignored.
        document_separator: Text inserted between consecutive rendered documents.
        document_prompt: Template used to render each document. When omitted,
            a template showing the ``supabase_id``, ``title`` and ``type``
            metadata fields followed by the page content is used, so each
            document's metadata is expected to provide those keys.

    Returns:
        The rendered documents joined by ``document_separator``; an empty
        string when ``docs`` is empty.
    """
    if document_prompt is None:
        document_prompt = PromptTemplate.from_template(
            template="UUID: {supabase_id}\nTITOLO: {title}\nTIPO: {type}\nCONTENUTO: {page_content}"
        )
    doc_strings = [format_document(doc, document_prompt) for doc, _score in docs]
    return document_separator.join(doc_strings)