Spaces:
Runtime error
Runtime error
from qdrant_client import QdrantClient | |
from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchText | |
from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding | |
from qdrant_client import QdrantClient, models | |
from reranker import Reranker | |
from sentence_transformers import SentenceTransformer | |
from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY,HUGGING_FACE_API_KEY | |
class DocSearcher: | |
def __init__(self, collection_name): | |
self.collection_name = collection_name | |
self.reranker = Reranker() | |
self.dense_model = SentenceTransformer(DENSE_MODEL,device="cpu",token=HUGGING_FACE_API_KEY) | |
self.model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B",device="cpu") | |
self.sparse_model = SparseTextEmbedding(SPARSE_MODEL) | |
self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL) | |
self.qdrant_client = QdrantClient(QDRANT_URL,api_key=QDRANT_API_KEY,timeout=30) | |
async def search(self, text: str,type:int, law_type: str | None = None, offset: int = 0): | |
dense_query = self.dense_model.encode(text).tolist() | |
sparse_query = next(self.sparse_model.query_embed(text)) | |
if type == 0: | |
prefetch = [ | |
models.Prefetch( | |
query=dense_query, | |
using=DENSE_MODEL, | |
limit=100 | |
), | |
models.Prefetch( | |
query=models.SparseVector(**sparse_query.as_object()), | |
using=SPARSE_MODEL, | |
limit=100 | |
) | |
] | |
else: | |
prefetch = [ | |
models.Prefetch( | |
query=dense_query, | |
using=DENSE_MODEL, | |
limit=100 | |
) | |
] | |
if type == 0 and law_type is not None: | |
filter = Filter( | |
must=[ | |
FieldCondition( | |
key="tip_dokumenta", | |
match=MatchValue(value=type) | |
), | |
FieldCondition( | |
key="naziv_suda", | |
match=MatchText(text=law_type) | |
), | |
] | |
) | |
elif type == 0: | |
filter = Filter( | |
must=[ | |
FieldCondition( | |
key="tip_dokumenta", | |
match=MatchValue(value=type) | |
) | |
] | |
) | |
elif type == 1 and law_type is not None: | |
filter = Filter( | |
must=[ | |
FieldCondition( | |
key="tip_dokumenta", | |
match=MatchValue(value=type) | |
), | |
FieldCondition( | |
key="vrsta_akta", | |
match=MatchValue(value=law_type) | |
), | |
] | |
) | |
else: | |
return [] | |
search_result = self.qdrant_client.query_points( | |
collection_name= self.collection_name, | |
query_filter=filter, | |
prefetch=prefetch, | |
query=models.FusionQuery( | |
fusion=models.Fusion.RRF, | |
), | |
with_payload=True, | |
limit = 10, | |
offset = offset | |
).points | |
data = [] | |
for hit in search_result: | |
data.append(hit.payload) | |
return data | |
async def search_temp(self, text: str): | |
queries = [text] | |
dense_query = self.model.encode(text).tolist() | |
# sparse_query = next(self.sparse_model.query_embed(text)) | |
prefetch = [ | |
models.Prefetch( | |
query=dense_query, | |
using="Qwen/Qwen3-Embedding-0.6B", | |
limit=100 | |
), | |
] | |
search_result = self.qdrant_client.query_points( | |
collection_name= "sl-list", | |
prefetch=prefetch, | |
query=models.FusionQuery( | |
fusion=models.Fusion.RRF, | |
), | |
with_payload=True, | |
limit = 100, | |
).points | |
data = [] | |
for hit in search_result: | |
data.append(hit.payload["tekst"]) | |
scores = self.reranker.compute_logits(queries,data) | |
return scores |