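"""Build knowledge-graph chunks from document text chunks.

An LLM extracts an entity graph from batches of text; the per-batch
graphs are merged, duplicate entities are resolved, and the result is
emitted as retrievable chunks: one per entity, one per community
report, one for the serialized graph, and one for a mind map.
"""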
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from functools import reduce

import networkx as nx

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.graph_extractor import GraphExtractor, DEFAULT_ENTITY_TYPES
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import rag_tokenizer
from rag.utils import num_tokens_from_string


def graph_merge(g1, g2):
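    """Merge graph ``g1`` into a copy of ``g2`` and return the result.

    Weights of entities and edges present in both graphs are bumped,
    differing descriptions are concatenated, and each node's ``rank``
    is set to its degree in the merged graph.
    """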
    g = g2.copy()
    for n, attr in g1.nodes(data=True):
        if n not in g2.nodes():
            g.add_node(n, **attr)
            continue

        # Entity present in both graphs: bump its weight and append the
        # incoming description unless the old one already contains it.
        g.nodes[n]["weight"] += 1
        if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
            g.nodes[n]["description"] += "\n" + attr["description"]

    for source, target, attr in g1.edges(data=True):
        if g.has_edge(source, target):
            g[source][target].update({"weight": attr["weight"] + 1})
            continue
        g.add_edge(source, target, **attr)

    # Rank every node by its degree in the merged graph.
    for node, degree in g.degree:
        g.nodes[str(node)]["rank"] = int(degree)
    return g


def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
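    """Turn raw text ``chunks`` into knowledge-graph chunks for indexing.

    Batches the chunks under the LLM token budget, extracts a partial
    entity graph per batch in a thread pool, merges the graphs and
    resolves duplicate entities, then emits entity chunks, community
    report chunks, the serialized graph, and a mind map of the source
    text. ``callback(progress, message)`` is used to report progress.
    """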
    _, tenant = TenantService.get_by_id(tenant_id)
    llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
    ext = GraphExtractor(llm_bdl)
    # Token budget for input text: the context window minus the prompt,
    # with 1024 tokens reserved for the completion, but never less than
    # 60% of the context window.
    left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
    left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)

    assert left_token_count > 0, f"The LLM context length ({llm_bdl.max_length}) is smaller than the prompt ({ext.prompt_token_count})"

    BATCH_SIZE = 4
    texts = []
    cnt = 0
    max_workers = int(os.environ.get("GRAPH_EXTRACTOR_MAX_WORKERS", 50))
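    # Accumulate chunks up to the token budget, then submit them to the
    # thread pool in joined groups of BATCH_SIZE texts; each job returns
    # a partial entity graph.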
    with ThreadPoolExecutor(max_workers=max_workers) as exe:
        threads = []
        for text in chunks:
            tkn_cnt = num_tokens_from_string(text)
            # Flush the accumulated texts once the token budget is reached.
            if cnt + tkn_cnt >= left_token_count and texts:
                for b in range(0, len(texts), BATCH_SIZE):
                    threads.append(exe.submit(ext, ["\n".join(texts[b:b + BATCH_SIZE])], {"entity_types": entity_types}, callback))
                texts = []
                cnt = 0
            texts.append(text)
            cnt += tkn_cnt
        if texts:
            for b in range(0, len(texts), BATCH_SIZE):
                threads.append(exe.submit(ext, ["\n".join(texts[b:b + BATCH_SIZE])], {"entity_types": entity_types}, callback))

        callback(0.5, "Extracting entities.")
        graphs = []
        for i, thread in enumerate(threads):
            graphs.append(thread.result().output)
            callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)}")
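
    # Merge the per-batch graphs into one, then resolve duplicate
    # entities (aliases, spelling variants) with the LLM.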
    graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
    er = EntityResolution(llm_bdl)
    graph = er(graph).output

    # Keep the original text chunks for the mind map below; from here on
    # ``chunks`` collects the output documents.
    _chunks = chunks
    chunks = []
    # One retrievable chunk per entity that is connected to anything.
    for n, attr in graph.nodes(data=True):
        if attr.get("rank", 0) == 0:
            logging.debug(f"Ignore entity: {n}")
            continue
        chunk = {
            "name_kwd": n,
            "important_kwd": [n],
            "title_tks": rag_tokenizer.tokenize(n),
            "content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(attr["description"]),
            "knowledge_graph_kwd": "entity",
            "rank_int": attr["rank"],
            "weight_int": attr["weight"]
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        chunks.append(chunk)

    callback(0.6, "Extracting community reports.")
    cr = CommunityReportsExtractor(llm_bdl)
    cr = cr(graph, callback=callback)
    for community, desc in zip(cr.structured_output, cr.output):
        chunk = {
            "title_tks": rag_tokenizer.tokenize(community["title"]),
            "content_with_weight": desc,
            "content_ltks": rag_tokenizer.tokenize(desc),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": community["weight"],
            "entities_kwd": community["entities"],
            "important_kwd": community["entities"]
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        chunks.append(chunk)

    # Store the whole graph as a single JSON chunk so it can be reloaded.
    chunks.append(
        {
            "content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
            "knowledge_graph_kwd": "graph"
        })

    callback(0.75, "Extracting mind map.")
    mindmap = MindMapExtractor(llm_bdl)
    mg = mindmap(_chunks).output
    if not mg:
        return chunks

    logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
    chunks.append(
        {
            "content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2),
            "knowledge_graph_kwd": "mind_map"
        })

    return chunks
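

# Minimal usage sketch (hypothetical tenant id and inputs; assumes the
# services imported above are configured):
#
#     def progress(prog, msg=""):
#         print(f"{prog:.0%} {msg}")
#
#     kg_chunks = build_knowledge_graph_chunks("tenant-id", ["chunk one", "chunk two"], progress)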