File size: 1,897 Bytes
321de76
e0cb517
321de76
 
0dfabf7
b7dc427
321de76
 
9e9178e
321de76
 
0dfabf7
df02cd1
 
 
321de76
9e9178e
 
6644c19
e0cb517
9e9178e
e0cb517
 
b7dc427
 
df02cd1
9e9178e
b7dc427
 
e0cb517
df02cd1
9e9178e
b7dc427
e0cb517
 
9e9178e
e0cb517
 
9e9178e
 
 
df02cd1
e0cb517
403610d
9e9178e
 
 
 
 
 
e0cb517
9e9178e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from qdrant_client import QdrantClient
from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from config import DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY,HUGGING_FACE_API_KEY
import os

class NeuralSearcher:
    """Hybrid dense + sparse search over a single Qdrant collection.

    A query is embedded with a dense SentenceTransformer model and a fastembed
    sparse model. Qdrant runs one prefetch per vector type and merges the two
    candidate lists with Reciprocal Rank Fusion (RRF).
    """

    def __init__(self, collection_name):
        """Load the embedding models and create the Qdrant client.

        Args:
            collection_name: Name of the Qdrant collection to query. It is
                assumed to expose named vectors called ``DENSE_MODEL`` and
                ``SPARSE_MODEL`` (those names are passed as ``using=`` in
                queries). TODO confirm against the indexing code.
        """
        self.collection_name = collection_name
        # Dense encoder pinned to CPU; authenticates to Hugging Face with the configured key.
        self.dense_model = SentenceTransformer(DENSE_MODEL, device="cpu", token=HUGGING_FACE_API_KEY)
        self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
        # NOTE(review): loaded but not used by `search` (late-interaction
        # reranking is disabled); kept so existing users of this attribute still work.
        self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
        self.qdrant_client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)

    async def search(self, text: str, limit: int = 10, prefetch_limit: int = 200):
        """Return the payloads of the top hybrid-search hits for ``text``.

        Embedding and the Qdrant request are synchronous and blocking. They run
        in the event loop's default executor so they do not stall the loop.

        NOTE(review): concurrent searches now call the models and the client
        from executor threads; confirm these objects are safe for concurrent use.

        Args:
            text: Query string.
            limit: Number of fused results to return (default 10).
            prefetch_limit: Candidates fetched per vector type before fusion
                (default 200).

        Returns:
            List of point payloads, in the order returned by Qdrant.
        """
        import asyncio  # local import so the module-level import block is unchanged

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._search_blocking, text, limit, prefetch_limit
        )

    def _search_blocking(self, text, limit, prefetch_limit):
        """Embed ``text`` and run the RRF-fused query (blocking helper for ``search``)."""
        dense_query = self.dense_model.encode(text).tolist()
        # query_embed returns an iterable of embeddings; take the one for this query.
        sparse_query = next(iter(self.sparse_model.query_embed(text)))

        prefetch = [
            models.Prefetch(
                query=dense_query,
                using=DENSE_MODEL,
                limit=prefetch_limit,
            ),
            models.Prefetch(
                query=models.SparseVector(**sparse_query.as_object()),
                using=SPARSE_MODEL,
                limit=prefetch_limit,
            ),
        ]

        # Late-interaction reranking with LATE_INTERACTION_MODEL is currently disabled;
        # the two prefetch candidate lists are merged with RRF only.
        points = self.qdrant_client.query_points(
            collection_name=self.collection_name,
            prefetch=prefetch,
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            with_payload=True,
            limit=limit,
        ).points

        return [hit.payload for hit in points]