Kevin Hu committed
Commit 62a5517 · 1 Parent(s): 0e469cf

Rebuild graph when it's out of time. (#4607)


### What problem does this PR solve?

#4543

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

api/db/services/dialog_service.py CHANGED
@@ -17,6 +17,7 @@ import logging
 import binascii
 import os
 import json
+import json_repair
 import re
 from collections import defaultdict
 from copy import deepcopy
@@ -353,7 +354,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
 
         prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
-        return {"answer": answer, "reference": refs, "prompt": prompt}
+        return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt)}
 
     if stream:
         last_ans = ""
@@ -795,5 +796,13 @@ Output:
     if kwd.find("**ERROR**") >= 0:
         raise Exception(kwd)
 
-    kwd = re.sub(r".*?\{", "{", kwd)
-    return json.loads(kwd)
+    try:
+        return json_repair.loads(kwd)
+    except json_repair.JSONDecodeError:
+        try:
+            result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
+            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            return json_repair.loads(result)
+        except Exception as e:
+            logging.exception(f"JSON parsing error: {result} -> {e}")
+            raise e
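
For reference, the parsing change above swaps a regex-plus-`json.loads` combination for `json_repair`, with a brace-slicing fallback when even that fails. A minimal sketch of the same idea, assuming a noisy LLM reply that wraps its JSON object in extra chatter (`parse_llm_json` is a hypothetical helper, not part of this PR):

```python
import json_repair  # tolerant drop-in for json.loads, as used by the PR


def parse_llm_json(reply: str):
    """Hypothetical helper: pull a JSON object out of a noisy LLM reply."""
    try:
        return json_repair.loads(reply)
    except Exception:
        # Fallback mirrors the PR: keep only the text between the first '{' and the next '}'.
        snippet = "{" + reply.split("{", 1)[1].split("}", 1)[0] + "}"
        return json_repair.loads(snippet)


print(parse_llm_json('Sure! {"keywords": ["graph", "rebuild"]} Hope this helps.'))
```
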
graphrag/search.py CHANGED
@@ -251,11 +251,11 @@ class KGSearch(Dealer):
                 break
 
         if ents:
-            ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv())
+            ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv())
         else:
             ents = ""
         if relas:
-            relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv())
+            relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv())
         else:
             relas = ""
 
@@ -296,7 +296,7 @@ class KGSearch(Dealer):
 
         if not txts:
            return ""
-        return "\n-Community Report-\n" + "\n".join(txts)
+        return "\n---- Community Report ----\n" + "\n".join(txts)
 
 
 if __name__ == "__main__":
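
The change in this file is purely cosmetic: the section markers wrapped around the retrieved knowledge-graph context are widened so they stand out in the final prompt. A tiny illustration with a made-up entity row (pandas is what this module already uses to render these tables):

```python
import pandas as pd

# Made-up row purely for illustration; real rows come from the knowledge-graph index.
ents = [{"entity": "RAGFlow", "entity_type": "software", "description": "RAG engine"}]
print("\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv()))
```
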
graphrag/utils.py CHANGED
@@ -23,6 +23,7 @@ from networkx.readwrite import json_graph
 
 from api import settings
 from rag.nlp import search, rag_tokenizer
+from rag.utils.doc_store_conn import OrderByExpr
 from rag.utils.redis_conn import REDIS_CONN
 
 ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
@@ -363,7 +364,7 @@ def get_graph(tenant_id, kb_id):
                 res.field[id]["source_id"]
         except Exception:
             continue
-    return None, None
+    return rebuild_graph(tenant_id, kb_id)
 
 
 def set_graph(tenant_id, kb_id, graph, docids):
@@ -517,3 +518,36 @@ def flat_uniq_list(arr, key):
             res.append(a)
     return list(set(res))
 
+
+def rebuild_graph(tenant_id, kb_id):
+    graph = nx.Graph()
+    src_ids = []
+    flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
+    bs = 256
+    for i in range(0, 10000000, bs):
+        es_res = settings.docStoreConn.search(flds, [],
+                                              {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
+                                              [],
+                                              OrderByExpr(),
+                                              i, bs, search.index_name(tenant_id), [kb_id]
+                                              )
+        tot = settings.docStoreConn.getTotal(es_res)
+        if tot == 0:
+            return None, None
+
+        es_res = settings.docStoreConn.getFields(es_res, flds)
+        for id, d in es_res.items():
+            src_ids.extend(d.get("source_id", []))
+            if d["knowledge_graph_kwd"] == "entity":
+                graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
+            else:
+                graph.add_edge(
+                    d["from_entity_kwd"],
+                    d["to_entity_kwd"],
+                    weight=int(d["weight_int"])
+                )
+
+        if len(es_res.keys()) < 128:
+            return graph, list(set(src_ids))
+
+    return graph, list(set(src_ids))
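
This is the heart of the PR: when the serialized graph chunk cannot be found or read, `get_graph()` no longer gives up with `(None, None)` but pages through the stored `entity` and `relation` chunks and reassembles a `networkx` graph. A minimal caller sketch, with placeholder ids (the id values are assumptions, not real data):

```python
from graphrag.utils import get_graph

tenant_id, kb_id = "tenant-xxx", "kb-xxx"  # placeholders for illustration only

graph, src_ids = get_graph(tenant_id, kb_id)  # may now fall back to rebuild_graph()
if graph is not None:
    print(f"{graph.number_of_nodes()} entities, {graph.number_of_edges()} relations, "
          f"{len(src_ids)} source documents")
```
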
rag/nlp/search.py CHANGED
@@ -483,4 +483,4 @@ class Dealer:
         cnt = np.sum([c for _, c in aggs])
         tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
                          key=lambda x: x[1] * -1)[:topn_tags]
-        return {a: c for a, c in tag_fea if c > 0}
+        return {a: max(1, c) for a, c in tag_fea}
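
The tag-feature change swaps the `if c > 0` filter for a `max(1, c)` floor, so tags whose rounded score lands at 0 are kept with weight 1 instead of being silently dropped. A quick before/after with made-up scores:

```python
# Made-up (tag, rounded score) pairs for illustration.
tag_fea = [("finance", 3), ("weather", 1), ("sports", 0)]

old = {a: c for a, c in tag_fea if c > 0}   # {'finance': 3, 'weather': 1} -- 'sports' vanishes
new = {a: max(1, c) for a, c in tag_fea}    # {'finance': 3, 'weather': 1, 'sports': 1}
print(old, new, sep="\n")
```
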
rag/svr/task_executor.py CHANGED
@@ -327,8 +327,10 @@ def build_chunks(task, progress_callback):
                                          random.choices(examples, k=2) if len(examples)>2 else examples,
                                          topn=topn_tags)
                 if cached:
-                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
-                    d[TAG_FLD] = json.loads(cached)
+                    cached = json.dumps(cached)
+            if cached:
+                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
+                d[TAG_FLD] = json.loads(cached)
 
         progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
 
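
The tagging fix serializes the freshly generated tag mapping with `json.dumps` before caching it, so the subsequent `json.loads(cached)` works whether the value came from the cache or was just produced. A minimal round-trip sketch with made-up values (the cache call itself is omitted):

```python
import json

cached = {"finance": 3, "sports": 1}   # assumed shape of a freshly generated tag mapping
cached = json.dumps(cached)            # string form, as now handed to set_llm_cache
print(json.loads(cached))              # dict form, as attached to the chunk's TAG_FLD
```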