Kevin Hu commited on
Commit
6d597a0
·
1 Parent(s): acd1df1

debug backend API for TAB 'search' (#2389)

Browse files

### What problem does this PR solve?
#2247

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/chunk_app.py CHANGED
@@ -261,7 +261,7 @@ def retrieval_test():
261
  kb_id = req["kb_id"]
262
  if isinstance(kb_id, str): kb_id = [kb_id]
263
  doc_ids = req.get("doc_ids", [])
264
- similarity_threshold = float(req.get("similarity_threshold", 0.2))
265
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
266
  top = int(req.get("top_k", 1024))
267
 
 
261
  kb_id = req["kb_id"]
262
  if isinstance(kb_id, str): kb_id = [kb_id]
263
  doc_ids = req.get("doc_ids", [])
264
+ similarity_threshold = float(req.get("similarity_threshold", 0.0))
265
  vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
266
  top = int(req.get("top_k", 1024))
267
 
api/apps/conversation_app.py CHANGED
@@ -15,8 +15,8 @@
15
  #
16
  import json
17
  import re
 
18
  from copy import deepcopy
19
-
20
  from api.db.services.user_service import UserTenantService
21
  from flask import request, Response
22
  from flask_login import login_required, current_user
@@ -333,6 +333,8 @@ def mindmap():
333
  0.3, 0.3, aggs=False)
334
  mindmap = MindMapExtractor(chat_mdl)
335
  mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
 
 
336
  return get_json_result(data=mind_map)
337
 
338
 
 
15
  #
16
  import json
17
  import re
18
+ import traceback
19
  from copy import deepcopy
 
20
  from api.db.services.user_service import UserTenantService
21
  from flask import request, Response
22
  from flask_login import login_required, current_user
 
333
  0.3, 0.3, aggs=False)
334
  mindmap = MindMapExtractor(chat_mdl)
335
  mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
336
+ if "error" in mind_map:
337
+ return server_error_response(Exception(mind_map["error"]))
338
  return get_json_result(data=mind_map)
339
 
340
 
api/db/services/dialog_service.py CHANGED
@@ -218,7 +218,7 @@ def chat(dialog, messages, stream=True, **kwargs):
218
  for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
219
  answer = ans
220
  delta_ans = ans[len(last_ans):]
221
- if num_tokens_from_string(delta_ans) < 12:
222
  continue
223
  last_ans = answer
224
  yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
@@ -404,7 +404,6 @@ def rewrite(tenant_id, llm_id, question):
404
 
405
 
406
  def tts(tts_mdl, text):
407
- return
408
  if not tts_mdl or not text: return
409
  bin = b""
410
  for chunk in tts_mdl.tts(text):
 
218
  for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
219
  answer = ans
220
  delta_ans = ans[len(last_ans):]
221
+ if num_tokens_from_string(delta_ans) < 16:
222
  continue
223
  last_ans = answer
224
  yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
 
404
 
405
 
406
  def tts(tts_mdl, text):
 
407
  if not tts_mdl or not text: return
408
  bin = b""
409
  for chunk in tts_mdl.tts(text):
graphrag/mind_map_extractor.py CHANGED
@@ -107,7 +107,7 @@ class MindMapExtractor:
107
  res.append(_.result())
108
 
109
  if not res:
110
- return MindMapResult(output={"root":{}})
111
 
112
  merge_json = reduce(self._merge, res)
113
  if len(merge_json.keys()) > 1:
 
107
  res.append(_.result())
108
 
109
  if not res:
110
+ return MindMapResult(output={"id": "root", "children": []})
111
 
112
  merge_json = reduce(self._merge, res)
113
  if len(merge_json.keys()) > 1:
rag/llm/embedding_model.py CHANGED
@@ -15,7 +15,7 @@
15
  #
16
  import re
17
  from typing import Optional
18
- import threading
19
  import requests
20
  from huggingface_hub import snapshot_download
21
  from openai.lib.azure import AzureOpenAI
 
15
  #
16
  import re
17
  from typing import Optional
18
+ import threading
19
  import requests
20
  from huggingface_hub import snapshot_download
21
  from openai.lib.azure import AzureOpenAI
rag/nlp/search.py CHANGED
@@ -224,6 +224,8 @@ class Dealer:
224
  def insert_citations(self, answer, chunks, chunk_v,
225
  embd_mdl, tkweight=0.1, vtweight=0.9):
226
  assert len(chunks) == len(chunk_v)
 
 
227
  pieces = re.split(r"(```)", answer)
228
  if len(pieces) >= 3:
229
  i = 0
@@ -263,7 +265,7 @@ class Dealer:
263
 
264
  ans_v, _ = embd_mdl.encode(pieces_)
265
  assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
266
- len(ans_v[0]), len(chunk_v[0]))
267
 
268
  chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split(" ")
269
  for ck in chunks]
@@ -360,29 +362,33 @@ class Dealer:
360
  ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
361
  if not question:
362
  return ranks
363
- req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
 
364
  "question": question, "vector": True, "topk": top,
365
  "similarity": similarity_threshold,
366
  "available_int": 1}
 
 
 
367
  sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
 
368
 
369
- if rerank_mdl:
370
- sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
371
- sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
 
 
 
 
 
372
  else:
373
- sim, tsim, vsim = self.rerank(
374
- sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
375
- idx = np.argsort(sim * -1)
376
 
377
  dim = len(sres.query_vector)
378
- start_idx = (page - 1) * page_size
379
  for i in idx:
380
  if sim[i] < similarity_threshold:
381
  break
382
- ranks["total"] += 1
383
- start_idx -= 1
384
- if start_idx >= 0:
385
- continue
386
  if len(ranks["chunks"]) >= page_size:
387
  if aggs:
388
  continue
@@ -406,7 +412,10 @@ class Dealer:
406
  "positions": sres.field[id].get("position_int", "").split("\t")
407
  }
408
  if highlight:
409
- d["highlight"] = rmSpace(sres.highlight[id])
 
 
 
410
  if len(d["positions"]) % 5 == 0:
411
  poss = []
412
  for i in range(0, len(d["positions"]), 5):
 
224
  def insert_citations(self, answer, chunks, chunk_v,
225
  embd_mdl, tkweight=0.1, vtweight=0.9):
226
  assert len(chunks) == len(chunk_v)
227
+ if not chunks:
228
+ return answer, set([])
229
  pieces = re.split(r"(```)", answer)
230
  if len(pieces) >= 3:
231
  i = 0
 
265
 
266
  ans_v, _ = embd_mdl.encode(pieces_)
267
  assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
268
+ len(ans_v[0]), len(chunk_v[0]))
269
 
270
  chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split(" ")
271
  for ck in chunks]
 
362
  ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
363
  if not question:
364
  return ranks
365
+ RERANK_PAGE_LIMIT = 3
366
+ req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size*RERANK_PAGE_LIMIT,
367
  "question": question, "vector": True, "topk": top,
368
  "similarity": similarity_threshold,
369
  "available_int": 1}
370
+ if page > RERANK_PAGE_LIMIT:
371
+ req["page"] = page
372
+ req["size"] = page_size
373
  sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
374
+ ranks["total"] = sres.total
375
 
376
+ if page <= RERANK_PAGE_LIMIT:
377
+ if rerank_mdl:
378
+ sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
379
+ sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
380
+ else:
381
+ sim, tsim, vsim = self.rerank(
382
+ sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
383
+ idx = np.argsort(sim * -1)[(page-1)*page_size:page*page_size]
384
  else:
385
+ sim = tsim = vsim = [1]*len(sres.ids)
386
+ idx = list(range(len(sres.ids)))
 
387
 
388
  dim = len(sres.query_vector)
 
389
  for i in idx:
390
  if sim[i] < similarity_threshold:
391
  break
 
 
 
 
392
  if len(ranks["chunks"]) >= page_size:
393
  if aggs:
394
  continue
 
412
  "positions": sres.field[id].get("position_int", "").split("\t")
413
  }
414
  if highlight:
415
+ if id in sres.highlight:
416
+ d["highlight"] = rmSpace(sres.highlight[id])
417
+ else:
418
+ d["highlight"] = d["content_with_weight"]
419
  if len(d["positions"]) % 5 == 0:
420
  poss = []
421
  for i in range(0, len(d["positions"]), 5):