# ragflow/graphrag/community_reports_extractor.py -- revision eb6e194, 'Fix graphrag : "role" user (#2273)', 5.06 kB
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import json
import logging
import re
import traceback
from dataclasses import dataclass
from typing import Any, List, Callable
import networkx as nx
import pandas as pd
from graphrag import leiden
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
@dataclass
class CommunityReportsResult:
    """Container for the reports produced by one extractor run.

    Attributes:
        output: One rendered text report per accepted community.
        structured_output: The parsed report dictionaries, in the same
            order as ``output``.
    """

    output: list[str]
    structured_output: list[dict]
class CommunityReportsExtractor:
    """Produce an LLM-written report for each community found in a graph.

    Calling an instance runs ``leiden.run`` on the graph, prompts the LLM
    once per community with CSV dumps of that community's entities and
    relations, validates the JSON answer and collects the accepted reports.
    """

    _llm: CompletionLLM
    _extraction_prompt: str
    # Declared for parity with the reference implementation; never assigned
    # by any method of this class.
    _output_formatter_prompt: str
    _on_error: ErrorHandlerFn
    # Stored by __init__ but not read by any method of this class.
    _max_report_length: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        extraction_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_report_length: int | None = None,
    ):
        """Store the collaborators, falling back to defaults for falsy values.

        Args:
            llm_invoker: Chat model whose ``chat(system, history, gen_conf)``
                method is used to generate each report.
            extraction_prompt: Prompt template; defaults to
                ``COMMUNITY_REPORT_PROMPT``.
            on_error: Called as ``on_error(exc, traceback_text, None)`` when
                a community fails; defaults to a no-op.
            max_report_length: Defaults to 1500.
        """
        self._llm = llm_invoker
        self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_report_length = max_report_length or 1500

    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        """Generate reports for every community of ``graph``.

        Args:
            graph: Graph whose nodes and edges (with their attribute dicts)
                are serialised into the prompt.
            callback: Optional progress hook, invoked with a ``msg`` keyword
                argument after each accepted report.

        Returns:
            A ``CommunityReportsResult`` holding the rendered text reports
            and the parsed report dictionaries. Communities whose LLM answer
            cannot be parsed or fails validation are skipped.

        For each accepted report, ``add_community_info2graph`` is called with
        the graph, the community's nodes and the report title.
        """
        # Each community entry is indexed below with "weight" and "nodes".
        communities = leiden.run(graph, {})
        total = sum(len(comm) for comm in communities.values())

        edge_rows = [{"source": s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]
        # An edgeless graph would otherwise yield a frame without the
        # "source"/"target" columns and break the membership filter below.
        relations_df = pd.DataFrame(edge_rows) if edge_rows else pd.DataFrame(columns=["source", "target"])

        res_str = []
        res_dict = []
        over, token_count = 0, 0
        st = timer()
        for comm in communities.values():
            for community in comm.values():
                weight = community["weight"]
                ents = community["nodes"]
                ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
                # Keep every relation that touches at least one member node.
                touches_community = relations_df["source"].isin(ents) | relations_df["target"].isin(ents)
                rela_df = relations_df[touches_community].reset_index(drop=True)

                prompt_variables = {
                    "entity_df": ent_df.to_csv(index_label="id"),
                    "relation_df": rela_df.to_csv(index_label="id")
                }
                text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
                gen_conf = {"temperature": 0.3}
                # Best-effort per community: any failure is reported through
                # on_error and the loop moves on to the next community.
                try:
                    response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                    token_count += num_tokens_from_string(text + response)
                    response = self._extract_json_object(response)
                    log.debug("Community report JSON: %s", response)
                    response = json.loads(response)
                    # NOTE(review): an integer rating (e.g. 8) may be rejected
                    # by the float check if dict_has_keys_with_types uses
                    # isinstance -- confirm against graphrag.utils.
                    if not dict_has_keys_with_types(response, [
                        ("title", str),
                        ("summary", str),
                        ("findings", list),
                        ("rating", float),
                        ("rating_explanation", str),
                    ]):
                        log.warning("Community report skipped: missing or mistyped fields")
                        continue
                    response["weight"] = weight
                    response["entities"] = ents
                except Exception as e:
                    log.exception("Community report extraction failed")
                    self._on_error(e, traceback.format_exc(), None)
                    continue

                add_community_info2graph(graph, ents, response["title"])
                res_str.append(self._get_text_output(response))
                res_dict.append(response)
                over += 1
                if callback:
                    callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    @staticmethod
    def _extract_json_object(raw: str) -> str:
        """Trim everything before the first ``{`` and after the last ``}``."""
        trimmed = re.sub(r"^[^\{]*", "", raw)
        return re.sub(r"[^\}]*$", "", trimmed)

    def _get_text_output(self, parsed_output: dict) -> str:
        """Render a parsed report as Markdown-style text.

        The title becomes a ``#`` heading followed by the summary; each
        finding becomes a ``##`` heading followed by its explanation. A
        finding may be a plain string (used as the heading, with an empty
        explanation) or a dict with ``summary`` and ``explanation`` keys.
        Missing keys render as empty text, and findings of any other type
        are ignored instead of raising.
        """
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings") or []

        def field(finding: dict, key: str):
            # Missing or null fields must not show up as the text "None".
            value = finding.get(key)
            return "" if value is None else value

        sections = []
        for finding in findings:
            if isinstance(finding, str):
                sections.append(f"## {finding}\n\n")
            elif isinstance(finding, dict):
                sections.append(f"## {field(finding, 'summary')}\n\n{field(finding, 'explanation')}")
            # Anything else (null, numbers, ...) carries no usable text.

        report_sections = "\n\n".join(sections)
        return f"# {title}\n\n{summary}\n\n{report_sections}"