Kevin Hu committed on
Commit
6a49fcd
·
1 Parent(s): 78eb735

Let ThreadPool exit gracefully. (#3653)

Browse files

### What problem does this PR solve?

#3646

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

agent/component/crawler.py CHANGED
@@ -65,7 +65,3 @@ class Crawler(ComponentBase, ABC):
65
  elif self._param.extract_type == 'content':
66
  result.extracted_content
67
  return result.markdown
68
-
69
-
70
-
71
-
 
65
  elif self._param.extract_type == 'content':
66
  result.extracted_content
67
  return result.markdown
 
 
 
 
graphrag/index.py CHANGED
@@ -64,27 +64,27 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
64
  BATCH_SIZE=4
65
  texts, graphs = [], []
66
  cnt = 0
67
- threads = []
68
  max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
69
- exe = ThreadPoolExecutor(max_workers=max_workers)
70
- for i in range(len(chunks)):
71
- tkn_cnt = num_tokens_from_string(chunks[i])
72
- if cnt+tkn_cnt >= left_token_count and texts:
 
 
 
 
 
 
 
 
73
  for b in range(0, len(texts), BATCH_SIZE):
74
  threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
75
- texts = []
76
- cnt = 0
77
- texts.append(chunks[i])
78
- cnt += tkn_cnt
79
- if texts:
80
- for b in range(0, len(texts), BATCH_SIZE):
81
- threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
82
-
83
- callback(0.5, "Extracting entities.")
84
- graphs = []
85
- for i, _ in enumerate(threads):
86
- graphs.append(_.result().output)
87
- callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
88
 
89
  graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
90
  er = EntityResolution(llm_bdl)
 
64
  BATCH_SIZE=4
65
  texts, graphs = [], []
66
  cnt = 0
 
67
  max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
68
+ with ThreadPoolExecutor(max_workers=max_workers) as exe:
69
+ threads = []
70
+ for i in range(len(chunks)):
71
+ tkn_cnt = num_tokens_from_string(chunks[i])
72
+ if cnt+tkn_cnt >= left_token_count and texts:
73
+ for b in range(0, len(texts), BATCH_SIZE):
74
+ threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
75
+ texts = []
76
+ cnt = 0
77
+ texts.append(chunks[i])
78
+ cnt += tkn_cnt
79
+ if texts:
80
  for b in range(0, len(texts), BATCH_SIZE):
81
  threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
82
+
83
+ callback(0.5, "Extracting entities.")
84
+ graphs = []
85
+ for i, _ in enumerate(threads):
86
+ graphs.append(_.result().output)
87
+ callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
 
 
 
 
 
 
 
88
 
89
  graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
90
  er = EntityResolution(llm_bdl)
graphrag/mind_map_extractor.py CHANGED
@@ -88,26 +88,26 @@ class MindMapExtractor:
88
  prompt_variables = {}
89
 
90
  try:
91
- max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
92
- exe = ThreadPoolExecutor(max_workers=max_workers)
93
- threads = []
94
- token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
95
- texts = []
96
  res = []
97
- cnt = 0
98
- for i in range(len(sections)):
99
- section_cnt = num_tokens_from_string(sections[i])
100
- if cnt + section_cnt >= token_count and texts:
 
 
 
 
 
 
 
 
 
 
 
101
  threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
102
- texts = []
103
- cnt = 0
104
- texts.append(sections[i])
105
- cnt += section_cnt
106
- if texts:
107
- threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
108
-
109
- for i, _ in enumerate(threads):
110
- res.append(_.result())
111
 
112
  if not res:
113
  return MindMapResult(output={"id": "root", "children": []})
 
88
  prompt_variables = {}
89
 
90
  try:
 
 
 
 
 
91
  res = []
92
+ max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
93
+ with ThreadPoolExecutor(max_workers=max_workers) as exe:
94
+ threads = []
95
+ token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
96
+ texts = []
97
+ cnt = 0
98
+ for i in range(len(sections)):
99
+ section_cnt = num_tokens_from_string(sections[i])
100
+ if cnt + section_cnt >= token_count and texts:
101
+ threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
102
+ texts = []
103
+ cnt = 0
104
+ texts.append(sections[i])
105
+ cnt += section_cnt
106
+ if texts:
107
  threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
108
+
109
+ for i, _ in enumerate(threads):
110
+ res.append(_.result())
 
 
 
 
 
 
111
 
112
  if not res:
113
  return MindMapResult(output={"id": "root", "children": []})
rag/llm/chat_model.py CHANGED
@@ -366,7 +366,7 @@ class OllamaChat(Base):
366
  keep_alive=-1
367
  )
368
  ans = response["message"]["content"].strip()
369
- return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
370
  except Exception as e:
371
  return "**ERROR**: " + str(e), 0
372
 
 
366
  keep_alive=-1
367
  )
368
  ans = response["message"]["content"].strip()
369
+ return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
370
  except Exception as e:
371
  return "**ERROR**: " + str(e), 0
372
 
rag/svr/task_executor.py CHANGED
@@ -492,6 +492,7 @@ def report_status():
492
  logging.exception("report_status got exception")
493
  time.sleep(30)
494
 
 
495
  def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
496
  msg = ""
497
  if dump_full:
@@ -508,6 +509,7 @@ def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapsho
508
  msg += '\n'.join(stat.traceback.format())
509
  logging.info(msg)
510
 
 
511
  def main():
512
  settings.init_settings()
513
  background_thread = threading.Thread(target=report_status)
 
492
  logging.exception("report_status got exception")
493
  time.sleep(30)
494
 
495
+
496
  def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
497
  msg = ""
498
  if dump_full:
 
509
  msg += '\n'.join(stat.traceback.format())
510
  logging.info(msg)
511
 
512
+
513
  def main():
514
  settings.init_settings()
515
  background_thread = threading.Thread(target=report_status)