Kevin Hu
commited on
Commit
·
6a49fcd
1
Parent(s):
78eb735
Let ThreadPool exit gracefully. (#3653)
Browse files### What problem does this PR solve?
#3646
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- agent/component/crawler.py +0 -4
- graphrag/index.py +18 -18
- graphrag/mind_map_extractor.py +18 -18
- rag/llm/chat_model.py +1 -1
- rag/svr/task_executor.py +2 -0
agent/component/crawler.py
CHANGED
@@ -65,7 +65,3 @@ class Crawler(ComponentBase, ABC):
|
|
65 |
elif self._param.extract_type == 'content':
|
66 |
result.extracted_content
|
67 |
return result.markdown
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
65 |
elif self._param.extract_type == 'content':
|
66 |
result.extracted_content
|
67 |
return result.markdown
|
|
|
|
|
|
|
|
graphrag/index.py
CHANGED
@@ -64,27 +64,27 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
|
|
64 |
BATCH_SIZE=4
|
65 |
texts, graphs = [], []
|
66 |
cnt = 0
|
67 |
-
threads = []
|
68 |
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
for b in range(0, len(texts), BATCH_SIZE):
|
74 |
threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
|
82 |
-
|
83 |
-
callback(0.5, "Extracting entities.")
|
84 |
-
graphs = []
|
85 |
-
for i, _ in enumerate(threads):
|
86 |
-
graphs.append(_.result().output)
|
87 |
-
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
|
88 |
|
89 |
graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
|
90 |
er = EntityResolution(llm_bdl)
|
|
|
64 |
BATCH_SIZE=4
|
65 |
texts, graphs = [], []
|
66 |
cnt = 0
|
|
|
67 |
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
|
68 |
+
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
69 |
+
threads = []
|
70 |
+
for i in range(len(chunks)):
|
71 |
+
tkn_cnt = num_tokens_from_string(chunks[i])
|
72 |
+
if cnt+tkn_cnt >= left_token_count and texts:
|
73 |
+
for b in range(0, len(texts), BATCH_SIZE):
|
74 |
+
threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
|
75 |
+
texts = []
|
76 |
+
cnt = 0
|
77 |
+
texts.append(chunks[i])
|
78 |
+
cnt += tkn_cnt
|
79 |
+
if texts:
|
80 |
for b in range(0, len(texts), BATCH_SIZE):
|
81 |
threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
|
82 |
+
|
83 |
+
callback(0.5, "Extracting entities.")
|
84 |
+
graphs = []
|
85 |
+
for i, _ in enumerate(threads):
|
86 |
+
graphs.append(_.result().output)
|
87 |
+
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
|
90 |
er = EntityResolution(llm_bdl)
|
graphrag/mind_map_extractor.py
CHANGED
@@ -88,26 +88,26 @@ class MindMapExtractor:
|
|
88 |
prompt_variables = {}
|
89 |
|
90 |
try:
|
91 |
-
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
|
92 |
-
exe = ThreadPoolExecutor(max_workers=max_workers)
|
93 |
-
threads = []
|
94 |
-
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
95 |
-
texts = []
|
96 |
res = []
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
cnt += section_cnt
|
106 |
-
if texts:
|
107 |
-
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
108 |
-
|
109 |
-
for i, _ in enumerate(threads):
|
110 |
-
res.append(_.result())
|
111 |
|
112 |
if not res:
|
113 |
return MindMapResult(output={"id": "root", "children": []})
|
|
|
88 |
prompt_variables = {}
|
89 |
|
90 |
try:
|
|
|
|
|
|
|
|
|
|
|
91 |
res = []
|
92 |
+
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
|
93 |
+
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
94 |
+
threads = []
|
95 |
+
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
96 |
+
texts = []
|
97 |
+
cnt = 0
|
98 |
+
for i in range(len(sections)):
|
99 |
+
section_cnt = num_tokens_from_string(sections[i])
|
100 |
+
if cnt + section_cnt >= token_count and texts:
|
101 |
+
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
102 |
+
texts = []
|
103 |
+
cnt = 0
|
104 |
+
texts.append(sections[i])
|
105 |
+
cnt += section_cnt
|
106 |
+
if texts:
|
107 |
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
108 |
+
|
109 |
+
for i, _ in enumerate(threads):
|
110 |
+
res.append(_.result())
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
if not res:
|
113 |
return MindMapResult(output={"id": "root", "children": []})
|
rag/llm/chat_model.py
CHANGED
@@ -366,7 +366,7 @@ class OllamaChat(Base):
|
|
366 |
keep_alive=-1
|
367 |
)
|
368 |
ans = response["message"]["content"].strip()
|
369 |
-
return ans, response
|
370 |
except Exception as e:
|
371 |
return "**ERROR**: " + str(e), 0
|
372 |
|
|
|
366 |
keep_alive=-1
|
367 |
)
|
368 |
ans = response["message"]["content"].strip()
|
369 |
+
return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
|
370 |
except Exception as e:
|
371 |
return "**ERROR**: " + str(e), 0
|
372 |
|
rag/svr/task_executor.py
CHANGED
@@ -492,6 +492,7 @@ def report_status():
|
|
492 |
logging.exception("report_status got exception")
|
493 |
time.sleep(30)
|
494 |
|
|
|
495 |
def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
|
496 |
msg = ""
|
497 |
if dump_full:
|
@@ -508,6 +509,7 @@ def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapsho
|
|
508 |
msg += '\n'.join(stat.traceback.format())
|
509 |
logging.info(msg)
|
510 |
|
|
|
511 |
def main():
|
512 |
settings.init_settings()
|
513 |
background_thread = threading.Thread(target=report_status)
|
|
|
492 |
logging.exception("report_status got exception")
|
493 |
time.sleep(30)
|
494 |
|
495 |
+
|
496 |
def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
|
497 |
msg = ""
|
498 |
if dump_full:
|
|
|
509 |
msg += '\n'.join(stat.traceback.format())
|
510 |
logging.info(msg)
|
511 |
|
512 |
+
|
513 |
def main():
|
514 |
settings.init_settings()
|
515 |
background_thread = threading.Thread(target=report_status)
|