Spaces:
Runtime error

sq / doc_searcher.py
dzenzzz's picture
Update doc_searcher.py
815e080 verified
raw
history blame
2.85 kB
import asyncio

from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient
from qdrant_client import QdrantClient, models
from qdrant_client.models import Filter, FieldCondition, MatchValue
from sentence_transformers import SentenceTransformer

from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY,HUGGING_FACE_API_KEY
class DocSearcher:
def __init__(self, collection_name):
self.collection_name = collection_name
self.dense_model = SentenceTransformer(DENSE_MODEL,device="cpu",token=HUGGING_FACE_API_KEY)
self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
self.qdrant_client = QdrantClient(QDRANT_URL,api_key=QDRANT_API_KEY,timeout=30)
async def search(self, text: str,type:int, law_type: str | None = None, offset: int = 0):
dense_query = self.dense_model.encode(text).tolist()
sparse_query = next(self.sparse_model.query_embed(text))
prefetch = [
models.Prefetch(
query=dense_query,
using=DENSE_MODEL,
limit=100
),
models.Prefetch(
query=models.SparseVector(**sparse_query.as_object()),
using=SPARSE_MODEL,
limit=100
)
]
if type == 2:
filter = None
elif type == 1 and law_type is not None:
filter = Filter(
must=[
FieldCondition(
key="tip_dokumenta",
match=MatchValue(value=type)
),
FieldCondition(
key="vrsta_akta",
match=MatchValue(value=law_type)
),
],
must_not=[
FieldCondition(key="status", match=MatchValue(value="Nevažeći")),
]
)
else:
filter = Filter(
must=[
FieldCondition(
key="tip_dokumenta",
match=MatchValue(value=type)
),
]
)
search_result = self.qdrant_client.query_points(
collection_name= self.collection_name,
query_filter=filter,
prefetch=prefetch,
query=models.FusionQuery(
fusion=models.Fusion.RRF,
),
with_payload=True,
limit = 10,
offset = offset
).points
data = []
for hit in search_result:
data.append(hit.payload)
return data