dzenzzz committed on
Commit
403610d
·
1 Parent(s): b7dc427

Changes to the FastAPI app and neural searcher

Browse files
Files changed (3) hide show
  1. app.py +11 -3
  2. neural_searcher.py +6 -4
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI
2
  from neural_searcher import NeuralSearcher
3
  from huggingface_hub import login
4
  import os
@@ -10,5 +10,13 @@ app = FastAPI()
10
  neural_searcher = NeuralSearcher(collection_name=os.getenv('COLLECTION_NAME'))
11
 
12
  @app.get("/api/search")
13
- def search_startup(q: str):
14
- return {"result": neural_searcher.search(text=q)}
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
  from neural_searcher import NeuralSearcher
3
  from huggingface_hub import login
4
  import os
 
10
  neural_searcher = NeuralSearcher(collection_name=os.getenv('COLLECTION_NAME'))
11
 
12
  @app.get("/api/search")
13
+ async def search(q: str):
14
+ if not q:
15
+ raise HTTPException(status_code=400, detail="Bad request.")
16
+
17
+ try:
18
+ data = await neural_searcher.search(text=q)
19
+ return data
20
+ except:
21
+ raise HTTPException(status_code=500, detail="Internal server error.")
22
+
neural_searcher.py CHANGED
@@ -17,24 +17,26 @@ class NeuralSearcher:
17
  dense_query = self.dense_model.encode(text).tolist()
18
  sparse_query = self.sparse_model.query_embed(text)
19
 
20
- search_result = self.qdrant_client.query_points(
21
  collection_name= self.collection_name,
 
22
  prefetch=[
23
  models.Prefetch(
24
  query=dense_query,
25
  using=os.getenv('DENSE_MODEL'),
26
- limit=5
27
  ),
28
  models.Prefetch(
29
  query=next(sparse_query).as_object(),
30
  using=os.getenv('SPARSE_MODEL'),
31
- limit=5
32
  )
33
  ],
34
  query=models.FusionQuery(
35
  fusion=models.Fusion.RRF
36
  ),
37
- limit = 9
 
38
  ).points
39
 
40
  payloads = [hit.payload for hit in search_result]
 
17
  dense_query = self.dense_model.encode(text).tolist()
18
  sparse_query = self.sparse_model.query_embed(text)
19
 
20
+ search_result = self.qdrant_client.query_points_groups(
21
  collection_name= self.collection_name,
22
+ group_by="dbid",
23
  prefetch=[
24
  models.Prefetch(
25
  query=dense_query,
26
  using=os.getenv('DENSE_MODEL'),
27
+ limit=100
28
  ),
29
  models.Prefetch(
30
  query=next(sparse_query).as_object(),
31
  using=os.getenv('SPARSE_MODEL'),
32
+ limit=100
33
  )
34
  ],
35
  query=models.FusionQuery(
36
  fusion=models.Fusion.RRF
37
  ),
38
+ score_threshold=0.8,
39
+ limit = 10
40
  ).points
41
 
42
  payloads = [hit.payload for hit in search_result]
requirements.txt CHANGED
@@ -8,3 +8,4 @@ python-dotenv
8
  qdrant-client
9
  qdrant-client[fastembed]>=1.8.2
10
  sentence-transformers
 
 
8
  qdrant-client
9
  qdrant-client[fastembed]>=1.8.2
10
  sentence-transformers
11
+ firebase