Spaces:
Runtime error

dzenzzz committed on
Commit
e0cb517
·
1 Parent(s): 0f3ebb5

resolve conflicts

Browse files
Files changed (2) hide show
  1. app.py +2 -7
  2. neural_searcher.py +21 -17
app.py CHANGED
@@ -11,12 +11,7 @@ neural_searcher = NeuralSearcher(collection_name=os.getenv('COLLECTION_NAME'))
11
 
12
  @app.get("/api/search")
13
  async def search(q: str):
14
- # if not q:
15
- # raise HTTPException(status_code=400, detail="Bad request.")
16
 
17
- # try:
18
- data = await neural_searcher.search(text=q)
19
- return data
20
- # except:
21
- # raise HTTPException(status_code=500, detail="Internal server error.")
22
 
 
11
 
12
  @app.get("/api/search")
13
  async def search(q: str):
14
+ data = await neural_searcher.search(text=q)
15
+ return data
16
 
 
 
 
 
 
17
 
neural_searcher.py CHANGED
@@ -1,5 +1,5 @@
1
  from qdrant_client import QdrantClient
2
- from fastembed import SparseTextEmbedding
3
  from qdrant_client import QdrantClient, models
4
  from sentence_transformers import SentenceTransformer
5
  import os
@@ -10,34 +10,38 @@ class NeuralSearcher:
10
  self.collection_name = collection_name
11
  self.dense_model = SentenceTransformer(os.getenv('DENSE_MODEL'),device="cpu")
12
  self.sparse_model = SparseTextEmbedding(os.getenv('SPARSE_MODEL'))
 
13
  self.qdrant_client = QdrantClient(os.getenv('QDRANT_URL'),api_key=os.getenv('QDRANT_API_KEY'))
14
 
15
  async def search(self, text: str):
16
 
17
- dense_query = self.dense_model.encode(text).tolist()
18
- sparse_query = self.sparse_model.query_embed(text)
19
-
20
- search_result = self.qdrant_client.query_points_groups(
21
- collection_name= self.collection_name,
22
- group_by="dbid",
23
- prefetch=[
24
  models.Prefetch(
25
  query=dense_query,
26
  using=os.getenv('DENSE_MODEL'),
27
  limit=100
28
  ),
29
  models.Prefetch(
30
- query=next(sparse_query).as_object(),
31
  using=os.getenv('SPARSE_MODEL'),
32
  limit=100
33
  )
34
- ],
35
- query=models.FusionQuery(
36
- fusion=models.Fusion.RRF
37
- ),
 
 
 
 
 
 
38
  score_threshold=0.8,
39
  limit = 10
40
- ).points
41
-
42
- payloads = [hit.payload for hit in search_result]
43
- return payloads
 
1
  from qdrant_client import QdrantClient
2
+ from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
3
  from qdrant_client import QdrantClient, models
4
  from sentence_transformers import SentenceTransformer
5
  import os
 
10
  self.collection_name = collection_name
11
  self.dense_model = SentenceTransformer(os.getenv('DENSE_MODEL'),device="cpu")
12
  self.sparse_model = SparseTextEmbedding(os.getenv('SPARSE_MODEL'))
13
+ self.late_interaction_model = LateInteractionTextEmbedding(os.getenv('LATE_INTERACTION_MODEL'))
14
  self.qdrant_client = QdrantClient(os.getenv('QDRANT_URL'),api_key=os.getenv('QDRANT_API_KEY'))
15
 
16
  async def search(self, text: str):
17
 
18
+ dense_query = next(self.dense_model.encode(text))
19
+ sparse_query = next(self.sparse_model.query_embed(text))
20
+ late_query = next(self.late_interaction_embedding_model.query_embed(text))
21
+
22
+ prefetch = [
 
 
23
  models.Prefetch(
24
  query=dense_query,
25
  using=os.getenv('DENSE_MODEL'),
26
  limit=100
27
  ),
28
  models.Prefetch(
29
+ query=models.SparseVector(**sparse_query.as_object()),
30
  using=os.getenv('SPARSE_MODEL'),
31
  limit=100
32
  )
33
+ ]
34
+
35
+ search_result = self.qdrant_client.query_points_groups(
36
+ collection_name= self.collection_name,
37
+ group_by="dbid",
38
+ prefetch=prefetch,
39
+ group_size=3,
40
+ query=late_query,
41
+ using=os.getenv('LATE_INTERACTION_MODEL'),
42
+ with_payload=True,
43
  score_threshold=0.8,
44
  limit = 10
45
+ ).groups
46
+
47
+ return search_result