# Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License """ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ import logging import json import re import traceback from typing import Callable from dataclasses import dataclass import networkx as nx import pandas as pd from graphrag import leiden from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT from graphrag.extractor import Extractor from graphrag.leiden import add_community_info2graph from rag.llm.chat_model import Base as CompletionLLM from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types from rag.utils import num_tokens_from_string from timeit import default_timer as timer @dataclass class CommunityReportsResult: """Community reports result class definition.""" output: list[str] structured_output: list[dict] class CommunityReportsExtractor(Extractor): """Community reports extractor class definition.""" _extraction_prompt: str _output_formatter_prompt: str _on_error: ErrorHandlerFn _max_report_length: int def __init__( self, llm_invoker: CompletionLLM, extraction_prompt: str | None = None, on_error: ErrorHandlerFn | None = None, max_report_length: int | None = None, ): """Init method definition.""" self._llm = llm_invoker self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT self._on_error = on_error or (lambda _e, _s, _d: None) self._max_report_length = max_report_length or 1500 def __call__(self, graph: nx.Graph, callback: Callable | None = None): communities: dict[str, dict[str, list]] = leiden.run(graph, {}) total = sum([len(comm.items()) for _, comm in communities.items()]) relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]) res_str = [] res_dict = [] over, token_count = 0, 0 st = timer() for level, comm in communities.items(): for cm_id, ents in comm.items(): weight = ents["weight"] ents = ents["nodes"] ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents]) rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True) prompt_variables = { "entity_df": ent_df.to_csv(index_label="id"), "relation_df": rela_df.to_csv(index_label="id") } text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) gen_conf = {"temperature": 0.3} try: response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) token_count += num_tokens_from_string(text + response) response = re.sub(r"^[^\{]*", "", response) response = re.sub(r"[^\}]*$", "", response) response = re.sub(r"\{\{", "{", response) response = re.sub(r"\}\}", "}", response) logging.debug(response) response = json.loads(response) if not dict_has_keys_with_types(response, [ ("title", str), ("summary", str), ("findings", list), ("rating", float), ("rating_explanation", str), ]): continue response["weight"] = weight response["entities"] = ents except Exception as e: logging.exception("CommunityReportsExtractor got exception") self._on_error(e, traceback.format_exc(), None) continue add_community_info2graph(graph, ents, response["title"]) res_str.append(self._get_text_output(response)) res_dict.append(response) over += 1 if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}") return CommunityReportsResult( structured_output=res_dict, output=res_str, ) def _get_text_output(self, parsed_output: dict) -> str: title = parsed_output.get("title", "Report") summary = parsed_output.get("summary", "") findings = parsed_output.get("findings", []) def finding_summary(finding: dict): if isinstance(finding, str): return finding return finding.get("summary") def finding_explanation(finding: dict): if isinstance(finding, str): return "" return finding.get("explanation") report_sections = "\n\n".join( f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings ) return f"# {title}\n\n{summary}\n\n{report_sections}"