""" |
|
Reference: |
|
- [graphrag](https://github.com/microsoft/graphrag) |
|
""" |
|
|
|
import logging |
|
import json |
|
import re |
|
import traceback |
|
from typing import Callable |
|
from dataclasses import dataclass |
|
import networkx as nx |
|
import pandas as pd |
|
from graphrag import leiden |
|
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT |
|
from graphrag.extractor import Extractor |
|
from graphrag.leiden import add_community_info2graph |
|
from rag.llm.chat_model import Base as CompletionLLM |
|
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types |
|
from rag.utils import num_tokens_from_string |
|
from timeit import default_timer as timer |
|
|
|
|
|
@dataclass |
|
class CommunityReportsResult: |
|
"""Community reports result class definition.""" |
|
|
|
output: list[str] |
|
structured_output: list[dict] |
|
|
|
|
|
class CommunityReportsExtractor(Extractor): |
|
"""Community reports extractor class definition.""" |
|
|
|
_extraction_prompt: str |
|
_output_formatter_prompt: str |
|
_on_error: ErrorHandlerFn |
|
_max_report_length: int |
|
|
|
def __init__( |
|
self, |
|
llm_invoker: CompletionLLM, |
|
extraction_prompt: str | None = None, |
|
on_error: ErrorHandlerFn | None = None, |
|
max_report_length: int | None = None, |
|
): |
|
"""Init method definition.""" |
|
self._llm = llm_invoker |
|
self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT |
|
self._on_error = on_error or (lambda _e, _s, _d: None) |
|
self._max_report_length = max_report_length or 1500 |
|
|
|
def __call__(self, graph: nx.Graph, callback: Callable | None = None): |
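        """Detect communities with Leiden, then have the LLM write one structured report per community."""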
|
        communities: dict[str, dict[str, dict]] = leiden.run(graph, {})
|
        # Total number of communities across all levels; used only for progress reporting.
        total = sum(len(comm) for comm in communities.values())
|
        # Flatten the graph's edges into a DataFrame so each community can select its own relations.
        relations_df = pd.DataFrame([{"source": s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
|
res_str = [] |
|
res_dict = [] |
|
over, token_count = 0, 0 |
|
st = timer() |
|
for level, comm in communities.items(): |
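            # comm maps each community id to {"weight": ..., "nodes": [...]} as produced by leiden.run.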
|
for cm_id, ents in comm.items(): |
|
weight = ents["weight"] |
|
ents = ents["nodes"] |
|
ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents]) |
|
rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True) |
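                # Render the community's entities and relations as CSV tables for the report prompt.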
|
|
|
prompt_variables = { |
|
"entity_df": ent_df.to_csv(index_label="id"), |
|
"relation_df": rela_df.to_csv(index_label="id") |
|
} |
|
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) |
|
                gen_conf = {"temperature": 0.3}  # a low temperature keeps the structured JSON output more stable
|
try: |
|
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) |
|
token_count += num_tokens_from_string(text + response) |
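                    # Strip any text before the first "{" and after the last "}", then collapse doubled braces.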
|
response = re.sub(r"^[^\{]*", "", response) |
|
response = re.sub(r"[^\}]*$", "", response) |
|
response = re.sub(r"\{\{", "{", response) |
|
response = re.sub(r"\}\}", "}", response) |
|
logging.debug(response) |
|
response = json.loads(response) |
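                    # Skip reports that are missing required fields or use the wrong types.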
|
if not dict_has_keys_with_types(response, [ |
|
("title", str), |
|
("summary", str), |
|
("findings", list), |
|
("rating", float), |
|
("rating_explanation", str), |
|
]): |
|
continue |
|
response["weight"] = weight |
|
response["entities"] = ents |
|
except Exception as e: |
|
logging.exception("CommunityReportsExtractor got exception") |
|
self._on_error(e, traceback.format_exc(), None) |
|
continue |
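                # Attach the community title to its member nodes in the graph.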
|
|
|
add_community_info2graph(graph, ents, response["title"]) |
|
res_str.append(self._get_text_output(response)) |
|
res_dict.append(response) |
|
over += 1 |
|
if callback: |
|
callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}") |
|
|
|
return CommunityReportsResult( |
|
structured_output=res_dict, |
|
output=res_str, |
|
) |
|
|
|
def _get_text_output(self, parsed_output: dict) -> str: |
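        """Render the structured report as markdown: a title heading, the summary, then one section per finding."""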
|
title = parsed_output.get("title", "Report") |
|
summary = parsed_output.get("summary", "") |
|
findings = parsed_output.get("findings", []) |
|
|
|
        def finding_summary(finding: dict | str):
|
if isinstance(finding, str): |
|
return finding |
|
return finding.get("summary") |
|
|
|
        def finding_explanation(finding: dict | str):
|
if isinstance(finding, str): |
|
return "" |
|
return finding.get("explanation") |
|
|
|
report_sections = "\n\n".join( |
|
f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings |
|
) |
|
return f"# {title}\n\n{summary}\n\n{report_sections}" |
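

# Minimal usage sketch (illustrative only; `my_chat_model` and `entity_graph` are assumptions,
# not defined in this module):
#
#   extractor = CommunityReportsExtractor(llm_invoker=my_chat_model)
#   result = extractor(entity_graph, callback=lambda msg: logging.info(msg))
#   for report in result.structured_output:
#       print(report["title"], report["rating"])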
|
|