Spaces:
Running
Running
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchText | |
| from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding | |
| from qdrant_client import QdrantClient, models | |
| from reranker import Reranker | |
| from sentence_transformers import SentenceTransformer | |
| from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY,HUGGING_FACE_API_KEY | |
class DocSearcher:
    """Hybrid (dense + sparse) retrieval over a Qdrant collection.

    `search` runs RRF-fused dense+sparse retrieval with payload filters;
    `search_temp` runs dense-only retrieval followed by cross-encoder
    reranking.
    """

    def __init__(self, collection_name):
        """Load embedding models and open the Qdrant connection.

        Args:
            collection_name: Qdrant collection used by `search`.
        """
        self.collection_name = collection_name
        self.reranker = Reranker()
        # Primary dense encoder, used by `search`.
        self.dense_model = SentenceTransformer(
            DENSE_MODEL, device="cpu", token=HUGGING_FACE_API_KEY
        )
        # Secondary dense encoder, used only by `search_temp`.
        self.model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B", device="cpu")
        self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
        # NOTE(review): loaded but never used in this file — confirm callers
        # elsewhere depend on it before removing.
        self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
        self.qdrant_client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY, timeout=30)

    @staticmethod
    def _build_filter(type: int, law_type: str | None):
        """Translate a (type, law_type) pair into a Qdrant payload filter.

        Returns None for unsupported combinations (anything other than
        type 0, or type 1 with a law_type), which the caller maps to an
        empty result set — preserving the original fallthrough behavior.
        """
        type_condition = FieldCondition(
            key="tip_dokumenta", match=MatchValue(value=type)
        )
        if type == 0 and law_type is not None:
            # type 0: law_type is matched as full-text against the court name.
            return Filter(
                must=[
                    type_condition,
                    FieldCondition(key="naziv_suda", match=MatchText(text=law_type)),
                ]
            )
        if type == 0:
            return Filter(must=[type_condition])
        if type == 1 and law_type is not None:
            # type 1: law_type is matched exactly against the act kind.
            return Filter(
                must=[
                    type_condition,
                    FieldCondition(key="vrsta_akta", match=MatchValue(value=law_type)),
                ]
            )
        return None

    async def search(self, text: str, type: int, law_type: str | None = None, offset: int = 0):
        """Run an RRF-fused hybrid search and return matching payloads.

        Args:
            text: Free-text query.
            type: Document type stored in "tip_dokumenta". 0 retrieves with
                dense+sparse prefetch, 1 with dense only. (The name shadows
                the builtin but is kept for caller compatibility.)
            law_type: Optional secondary filter — court name for type 0,
                act kind for type 1.
            offset: Pagination offset into the fused result list.

        Returns:
            List of payload dicts for the top 10 hits, or an empty list
            when the (type, law_type) combination is unsupported.
        """
        dense_query = self.dense_model.encode(text).tolist()
        sparse_query = next(self.sparse_model.query_embed(text))

        prefetch = [
            models.Prefetch(query=dense_query, using=DENSE_MODEL, limit=100),
        ]
        if type == 0:
            # Hybrid retrieval: add the sparse branch for RRF fusion.
            prefetch.append(
                models.Prefetch(
                    query=models.SparseVector(**sparse_query.as_object()),
                    using=SPARSE_MODEL,
                    limit=100,
                )
            )

        query_filter = self._build_filter(type, law_type)
        if query_filter is None:
            # Unsupported combination (e.g. type == 1 without law_type).
            return []

        hits = self.qdrant_client.query_points(
            collection_name=self.collection_name,
            query_filter=query_filter,
            prefetch=prefetch,
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            with_payload=True,
            limit=10,
            offset=offset,
        ).points
        return [hit.payload for hit in hits]

    async def search_temp(self, text: str, collection_name: str = "sl-list", limit: int = 100):
        """Dense-only retrieval followed by cross-encoder reranking.

        Args:
            text: Free-text query.
            collection_name: Collection to search; defaults to the
                previously hard-coded "sl-list" for backward compatibility.
            limit: Number of candidate documents fetched for reranking;
                defaults to the previously hard-coded 100.

        Returns:
            Reranker logits for the retrieved "tekst" payload fields
            (whatever `Reranker.compute_logits` produces).
        """
        dense_query = self.model.encode(text).tolist()
        hits = self.qdrant_client.query_points(
            collection_name=collection_name,
            prefetch=[
                models.Prefetch(
                    query=dense_query,
                    using="Qwen/Qwen3-Embedding-0.6B",
                    limit=100,
                ),
            ],
            with_payload=True,
            limit=limit,
        ).points
        candidates = [hit.payload["tekst"] for hit in hits]
        return self.reranker.compute_logits([text], candidates)