Spaces:
Runtime error

File size: 2,846 Bytes
321de76
815e080
e0cb517
321de76
 
0dfabf7
321de76
f61c6a5
9e9178e
321de76
 
0dfabf7
df02cd1
 
a2e724b
321de76
815e080
9e9178e
6644c19
e0cb517
 
 
b7dc427
 
f61c6a5
815e080
b7dc427
 
e0cb517
f61c6a5
815e080
b7dc427
e0cb517
 
815e080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e9178e
e0cb517
815e080
e0cb517
9e9178e
 
 
5c7babc
815e080
 
9e9178e
 
 
 
 
 
e0cb517
9e9178e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY,HUGGING_FACE_API_KEY

class DocSearcher:
    """Hybrid (dense + sparse) document search over a Qdrant collection.

    The query text is embedded with a SentenceTransformer dense model and a
    fastembed sparse model. Candidates from both named vectors are prefetched
    and fused server-side with Reciprocal Rank Fusion (RRF).
    """

    # Minimum number of candidates fetched from each vector before fusion.
    # It is raised per call when the requested page lies deeper than this.
    _MIN_PREFETCH_LIMIT = 100

    def __init__(self, collection_name):
        """Load the embedding models and create the Qdrant client.

        Args:
            collection_name: Name of the Qdrant collection to query.
        """
        self.collection_name = collection_name
        self.dense_model = SentenceTransformer(DENSE_MODEL, device="cpu", token=HUGGING_FACE_API_KEY)
        self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
        # NOTE(review): not used by ``search``; kept because callers may read
        # this attribute. Loading it costs startup time and memory, so confirm
        # whether it is still needed.
        self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
        self.qdrant_client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY, timeout=30)

    @staticmethod
    def _build_filter(doc_type, law_type):
        """Return the payload filter for a document-type code, or ``None``.

        - ``doc_type == 2``: no filter (search everything).
        - ``doc_type == 1`` with a ``law_type``: require ``tip_dokumenta == 1``
          and ``vrsta_akta == law_type``, and exclude ``status == "Nevažeći"``.
        - anything else: require only ``tip_dokumenta == doc_type``
          (``law_type`` is ignored).
        """
        if doc_type == 2:
            return None

        same_type = FieldCondition(key="tip_dokumenta", match=MatchValue(value=doc_type))

        if doc_type == 1 and law_type is not None:
            return Filter(
                must=[
                    same_type,
                    FieldCondition(key="vrsta_akta", match=MatchValue(value=law_type)),
                ],
                # "Nevažeći" (Serbian/Croatian: "invalid") — presumably marks
                # acts no longer in force; TODO confirm against the data.
                must_not=[
                    FieldCondition(key="status", match=MatchValue(value="Nevažeći")),
                ],
            )

        # NOTE(review): when doc_type == 1 but law_type is None, "Nevažeći"
        # documents are NOT excluded. Confirm this asymmetry is intended.
        return Filter(must=[same_type])

    async def search(self, text: str, type: int, law_type: str | None = None, offset: int = 0, limit: int = 10):
        """Run a hybrid dense+sparse search and return one page of payloads.

        Args:
            text: Free-text query.
            type: Document-type code matched against payload ``tip_dokumenta``;
                ``2`` disables filtering entirely.
            law_type: Optional ``vrsta_akta`` value. It is applied only when
                ``type == 1``.
            offset: Number of fused results to skip (for pagination).
            limit: Page size; defaults to 10, matching the previous fixed value.

        Returns:
            List of point payloads, in fused (RRF) rank order.

        NOTE(review): despite being ``async``, the encoding calls and the
        synchronous ``QdrantClient.query_points`` call block the event loop
        while they run.
        """
        dense_query = self.dense_model.encode(text).tolist()
        sparse_query = next(self.sparse_model.query_embed(text))

        # RRF can only rank points that some prefetch returned. Each candidate
        # pool must therefore reach at least offset + limit, otherwise deeper
        # pages come back short or empty.
        prefetch_limit = max(self._MIN_PREFETCH_LIMIT, offset + limit)

        prefetch = [
            models.Prefetch(
                query=dense_query,
                using=DENSE_MODEL,
                limit=prefetch_limit,
            ),
            models.Prefetch(
                query=models.SparseVector(**sparse_query.as_object()),
                using=SPARSE_MODEL,
                limit=prefetch_limit,
            ),
        ]

        response = self.qdrant_client.query_points(
            collection_name=self.collection_name,
            query_filter=self._build_filter(type, law_type),
            prefetch=prefetch,
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            with_payload=True,
            limit=limit,
            offset=offset,
        )

        return [point.payload for point in response.points]