Spaces:
Runtime error

File size: 4,553 Bytes
321de76
10d2c79
e0cb517
321de76
bac4585
321de76
0dfabf7
321de76
f61c6a5
9e9178e
321de76
 
bac4585
0dfabf7
bac4585
df02cd1
 
a2e724b
321de76
815e080
9e9178e
6644c19
e0cb517
 
d1d999e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0cb517
7e03879
815e080
 
 
 
 
 
 
10d2c79
 
815e080
 
 
7e03879
815e080
 
 
 
 
10d2c79
815e080
 
7e03879
a41c542
 
 
 
 
 
 
 
d17d194
a41c542
 
 
10d2c79
 
 
9e9178e
e0cb517
815e080
e0cb517
9e9178e
 
 
5c7babc
10d2c79
815e080
9e9178e
 
 
 
8f5a157
 
 
 
 
 
 
bac4585
8f5a157
ff16f91
8f5a157
 
 
 
bac4585
 
8f5a157
 
 
 
9db0c3e
8f5a157
c6b5790
 
 
8f5a157
bac4585
8f5a157
 
 
 
9e9178e
bac4585
8d0bbf2
bac4585
e0cb517
8d0bbf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchText
from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient, models
from reranker import Reranker
from sentence_transformers import SentenceTransformer
from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY,HUGGING_FACE_API_KEY

class DocSearcher:
    
    def __init__(self, collection_name):
        self.collection_name = collection_name
        self.reranker = Reranker()
        self.dense_model = SentenceTransformer(DENSE_MODEL,device="cpu",token=HUGGING_FACE_API_KEY)
        self.model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B",device="cpu")
        self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
        self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
        self.qdrant_client = QdrantClient(QDRANT_URL,api_key=QDRANT_API_KEY,timeout=30)
    
    async def search(self, text: str,type:int, law_type: str | None = None, offset: int = 0):
        
        dense_query = self.dense_model.encode(text).tolist()
        sparse_query = next(self.sparse_model.query_embed(text))

        if type == 0:
            prefetch = [
                    models.Prefetch(
                        query=dense_query,
                        using=DENSE_MODEL,
                        limit=100
                    ),
                    models.Prefetch(
                        query=models.SparseVector(**sparse_query.as_object()),
                        using=SPARSE_MODEL,
                        limit=100
                    )
            ]
        else:
            prefetch = [
                    models.Prefetch(
                        query=dense_query,
                        using=DENSE_MODEL,
                        limit=100
                    )
            ]

        if type == 0 and law_type is not None:
            filter = Filter(
                must=[
                    FieldCondition(
                        key="tip_dokumenta",
                        match=MatchValue(value=type)
                    ),
                    FieldCondition(
                        key="naziv_suda",
                        match=MatchText(text=law_type)
                    ),
                ]
            )
        elif type == 0:
            filter = Filter(
                must=[
                    FieldCondition(
                        key="tip_dokumenta",
                        match=MatchValue(value=type)
                    )
                ]
            )
        elif type == 1 and law_type is not None:
            filter = Filter(
                must=[
                    FieldCondition(
                        key="tip_dokumenta",
                        match=MatchValue(value=type)
                    ),
                    FieldCondition(
                        key="vrsta_akta",
                        match=MatchValue(value=law_type)
                    ),
                ]
            )
        else:
            return []
            
        search_result = self.qdrant_client.query_points(
            collection_name= self.collection_name,
            query_filter=filter,
            prefetch=prefetch,
            query=models.FusionQuery(
                fusion=models.Fusion.RRF,
            ),
            with_payload=True,
            limit = 10,
            offset = offset
        ).points

        data = []

        for hit in search_result:
            data.append(hit.payload)
        
        return data
    
    async def search_temp(self, text: str):
        
        queries = [text]
        dense_query = self.model.encode(text).tolist()
        # sparse_query = next(self.sparse_model.query_embed(text))

        prefetch = [
                models.Prefetch(
                    query=dense_query,
                    using="Qwen/Qwen3-Embedding-0.6B",
                    limit=100
                ),
        ]
            
        search_result = self.qdrant_client.query_points(
            collection_name= "sl-list",
            prefetch=prefetch,
            query=models.FusionQuery(
                fusion=models.Fusion.RRF,
            ),
            with_payload=True,
            limit = 100,
        ).points

        data = []

        for hit in search_result:
            data.append(hit.payload["tekst"])
            
        scores = self.reranker.compute_logits(queries,data)
        
        return scores