dzenzzz commited on
Commit
8f5a157
·
1 Parent(s): d1d999e

adds qwen for testing

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. doc_searcher.py +37 -0
app.py CHANGED
@@ -20,7 +20,8 @@ ALLOWED_API_KEY = str(API_KEY)
20
  async def search(q: str, type: int, lt: str | None = None, offset: int = 0):
21
  query = q.lower()
22
  xss = nh3.clean(query)
23
- data = await doc_searcher.search(text=xss,type=type,law_type=lt,offset=offset)
 
24
  return data
25
 
26
  @app.get("/api/suggestions")
 
20
  async def search(q: str, type: int, lt: str | None = None, offset: int = 0):
21
  query = q.lower()
22
  xss = nh3.clean(query)
23
+ # data = await doc_searcher.search(text=xss,type=type,law_type=lt,offset=offset)
24
+ data = await doc_searcher.search_temp(text=xss)
25
  return data
26
 
27
  @app.get("/api/suggestions")
doc_searcher.py CHANGED
@@ -10,6 +10,7 @@ class DocSearcher:
10
  def __init__(self, collection_name):
11
  self.collection_name = collection_name
12
  self.dense_model = SentenceTransformer(DENSE_MODEL,device="cpu",token=HUGGING_FACE_API_KEY)
 
13
  self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
14
  self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
15
  self.qdrant_client = QdrantClient(QDRANT_URL,api_key=QDRANT_API_KEY,timeout=30)
@@ -93,6 +94,42 @@ class DocSearcher:
93
 
94
  data = []
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  for hit in search_result:
97
  data.append(hit.payload)
98
 
 
10
  def __init__(self, collection_name):
11
  self.collection_name = collection_name
12
  self.dense_model = SentenceTransformer(DENSE_MODEL,device="cpu",token=HUGGING_FACE_API_KEY)
13
+ self.model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B",device="cpu")
14
  self.sparse_model = SparseTextEmbedding(SPARSE_MODEL)
15
  self.late_interaction_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
16
  self.qdrant_client = QdrantClient(QDRANT_URL,api_key=QDRANT_API_KEY,timeout=30)
 
94
 
95
  data = []
96
 
97
+ for hit in search_result:
98
+ data.append(hit.payload)
99
+
100
+ return data
101
+
102
+ async def search_temp(self, text: str):
103
+
104
+ dense_query = self.model.encode(text).tolist()
105
+ sparse_query = next(self.sparse_model.query_embed(text))
106
+
107
+ prefetch = [
108
+ models.Prefetch(
109
+ query=dense_query,
110
+ using=DENSE_MODEL,
111
+ limit=100
112
+ ),
113
+ # models.Prefetch(
114
+ # query=models.SparseVector(**sparse_query.as_object()),
115
+ # using=SPARSE_MODEL,
116
+ # limit=100
117
+ # )
118
+ ]
119
+
120
+ search_result = self.qdrant_client.query_points(
121
+ collection_name= self.collection_name,
122
+ query_filter=filter,
123
+ prefetch=prefetch,
124
+ query=models.FusionQuery(
125
+ fusion=models.Fusion.RRF,
126
+ ),
127
+ with_payload=True,
128
+ limit = 10,
129
+ ).points
130
+
131
+ data = []
132
+
133
  for hit in search_result:
134
  data.append(hit.payload)
135