Spaces:
Runtime error
Runtime error
File size: 2,846 Bytes
321de76 815e080 e0cb517 321de76 0dfabf7 321de76 f61c6a5 9e9178e 321de76 0dfabf7 df02cd1 a2e724b 321de76 815e080 9e9178e 6644c19 e0cb517 b7dc427 f61c6a5 815e080 b7dc427 e0cb517 f61c6a5 815e080 b7dc427 e0cb517 815e080 9e9178e e0cb517 815e080 e0cb517 9e9178e 5c7babc 815e080 9e9178e e0cb517 9e9178e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY,HUGGING_FACE_API_KEY
class DocSearcher:
    """Hybrid (dense + sparse) search over a single Qdrant collection.

    Each query is embedded with a dense sentence-transformer model and a
    sparse fastembed model. Both embeddings are sent to Qdrant as prefetch
    branches and the two candidate lists are merged with Reciprocal Rank
    Fusion (RRF).
    """

    # Number of payloads returned by one call to ``search``.
    _PAGE_SIZE = 10
    # Lower bound on the number of candidates each prefetch branch retrieves.
    _MIN_PREFETCH_LIMIT = 100

    def __init__(self, collection_name):
        """Load the embedding models and open the Qdrant connection.

        Args:
            collection_name: Name of the Qdrant collection to query.
        """
        self.collection_name = collection_name
        self.dense_model = SentenceTransformer(
            DENSE_MODEL,
            device="cpu",
            token=HUGGING_FACE_API_KEY,
        )
        self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
        # Loaded here but not referenced by ``search`` below.
        self.late_interaction_model = LateInteractionTextEmbedding(
            LATE_INTERACTION_MODEL
        )
        self.qdrant_client = QdrantClient(
            QDRANT_URL,
            api_key=QDRANT_API_KEY,
            timeout=30,
        )

    @staticmethod
    def _build_filter(doc_type, law_type):
        """Return the payload filter for a query, or ``None`` for no filter.

        * ``doc_type == 2``: no filtering at all.
        * ``doc_type == 1`` with a ``law_type``: match ``tip_dokumenta`` and
          ``vrsta_akta``, and exclude points whose ``status`` is "Nevažeći".
        * anything else: match ``tip_dokumenta`` only.
        """
        if doc_type == 2:
            return None

        type_condition = FieldCondition(
            key="tip_dokumenta",
            match=MatchValue(value=doc_type),
        )
        if doc_type == 1 and law_type is not None:
            return Filter(
                must=[
                    type_condition,
                    FieldCondition(
                        key="vrsta_akta",
                        match=MatchValue(value=law_type),
                    ),
                ],
                must_not=[
                    FieldCondition(key="status", match=MatchValue(value="Nevažeći")),
                ],
            )
        return Filter(must=[type_condition])

    async def search(
        self,
        text: str,
        type: int,
        law_type: str | None = None,
        offset: int = 0,
    ) -> list:
        """Run a hybrid search and return the payloads of the matching points.

        Args:
            text: Free-text query to embed and search for.
            type: Value matched against the ``tip_dokumenta`` payload field;
                the value ``2`` disables filtering entirely. (The name
                shadows the builtin but is kept because callers may pass it
                by keyword.)
            law_type: Optional value matched against the ``vrsta_akta``
                payload field; only applied when ``type == 1``.
            offset: Number of fused results to skip before the returned page.

        Returns:
            A list with up to ten payloads, in the order Qdrant returned them.
        """
        # NOTE(review): this coroutine contains no ``await``; the embedding
        # and Qdrant calls below are synchronous and run on the caller's
        # event loop thread.
        dense_query = self.dense_model.encode(text).tolist()
        sparse_query = next(self.sparse_model.query_embed(text))

        # ``offset`` is applied to the fused candidate set only, so every
        # prefetch branch must supply at least ``offset + page size``
        # candidates. With a fixed limit of 100, pages beyond the first
        # hundred candidates came back truncated or empty.
        prefetch_limit = max(self._MIN_PREFETCH_LIMIT, offset + self._PAGE_SIZE)

        prefetch = [
            models.Prefetch(
                query=dense_query,
                using=DENSE_MODEL,
                limit=prefetch_limit,
            ),
            models.Prefetch(
                query=models.SparseVector(**sparse_query.as_object()),
                using=SPARSE_MODEL,
                limit=prefetch_limit,
            ),
        ]

        response = self.qdrant_client.query_points(
            collection_name=self.collection_name,
            query_filter=self._build_filter(type, law_type),
            prefetch=prefetch,
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            with_payload=True,
            limit=self._PAGE_SIZE,
            offset=offset,
        )
        return [hit.payload for hit in response.points]