from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer

from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY, HUGGING_FACE_API_KEY

class DocSearcher:
    """Hybrid (dense + sparse) document search over a Qdrant collection."""

    def __init__(self, collection_name):
        self.collection_name = collection_name
        self.dense_model = SentenceTransformer(DENSE_MODEL, device="cpu", token=HUGGING_FACE_API_KEY)
        self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
        # Loaded for potential late-interaction reranking; not used in search() below.
        self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
        self.qdrant_client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY, timeout=30)
    
    async def search(self, text: str):
        # Embed the query once per representation.
        dense_query = self.dense_model.encode(text).tolist()
        sparse_query = next(self.sparse_model.query_embed(text))

        # Prefetch candidates from both named vectors, then fuse the two ranked
        # lists with Reciprocal Rank Fusion (RRF).
        prefetch = [
            models.Prefetch(
                query=dense_query,
                using=DENSE_MODEL,
                limit=10000,
            ),
            models.Prefetch(
                query=models.SparseVector(**sparse_query.as_object()),
                using=SPARSE_MODEL,
                limit=10000,
            ),
        ]

        search_result = self.qdrant_client.query_points(
            collection_name=self.collection_name,
            prefetch=prefetch,
            query=models.FusionQuery(
                fusion=models.Fusion.RRF,
            ),
            # Payloads are what search() returns, so they must be requested.
            with_payload=True,
            # Note: this threshold is applied to the fused RRF score, not to a
            # cosine similarity, so 0.8 may be far too strict; tune or drop it.
            score_threshold=0.8,
            limit=5,
        ).points

        # Return only the stored payloads for the top fused hits.
        return [hit.payload for hit in search_result]
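

# Minimal usage sketch (illustrative, not part of the original module): assumes
# the Qdrant collection was already indexed with named vectors matching
# DENSE_MODEL and SPARSE_MODEL, and that the config values point at reachable
# services. The collection name and query text below are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        searcher = DocSearcher("docs")  # hypothetical collection name
        results = await searcher.search("how does hybrid search work?")
        for payload in results:
            print(payload)

    asyncio.run(_demo())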