|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import collections |
|
import os |
|
import re |
|
import traceback |
|
from typing import Any |
|
from concurrent.futures import ThreadPoolExecutor |
|
from dataclasses import dataclass |
|
|
|
from graphrag.extractor import Extractor |
|
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT |
|
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements |
|
from rag.llm.chat_model import Base as CompletionLLM |
|
import markdown_to_json |
|
from functools import reduce |
|
from rag.utils import num_tokens_from_string |
|
|
|
|
|
@dataclass |
|
class MindMapResult: |
|
"""Unipartite Mind Graph result class definition.""" |
|
output: dict |
|
|
|
|
|
class MindMapExtractor(Extractor): |
|
_input_text_key: str |
|
_mind_map_prompt: str |
|
_on_error: ErrorHandlerFn |
|
|
|
def __init__( |
|
self, |
|
llm_invoker: CompletionLLM, |
|
prompt: str | None = None, |
|
input_text_key: str | None = None, |
|
on_error: ErrorHandlerFn | None = None, |
|
): |
|
"""Init method definition.""" |
|
|
|
self._llm = llm_invoker |
|
self._input_text_key = input_text_key or "input_text" |
|
self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT |
|
self._on_error = on_error or (lambda _e, _s, _d: None) |
|
|
|
def _key(self, k): |
|
return re.sub(r"\*+", "", k) |
|
|
|
def _be_children(self, obj: dict, keyset: set): |
|
if isinstance(obj, str): |
|
obj = [obj] |
|
if isinstance(obj, list): |
|
keyset.update(obj) |
|
obj = [re.sub(r"\*+", "", i) for i in obj] |
|
return [{"id": i, "children": []} for i in obj if i] |
|
arr = [] |
|
for k, v in obj.items(): |
|
k = self._key(k) |
|
if k and k not in keyset: |
|
keyset.add(k) |
|
arr.append( |
|
{ |
|
"id": k, |
|
"children": self._be_children(v, keyset) |
|
} |
|
) |
|
return arr |
|
|
|
def __call__( |
|
self, sections: list[str], prompt_variables: dict[str, Any] | None = None |
|
) -> MindMapResult: |
|
"""Call method definition.""" |
|
if prompt_variables is None: |
|
prompt_variables = {} |
|
|
|
try: |
|
res = [] |
|
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12)) |
|
with ThreadPoolExecutor(max_workers=max_workers) as exe: |
|
threads = [] |
|
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512) |
|
texts = [] |
|
cnt = 0 |
|
for i in range(len(sections)): |
|
section_cnt = num_tokens_from_string(sections[i]) |
|
if cnt + section_cnt >= token_count and texts: |
|
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) |
|
texts = [] |
|
cnt = 0 |
|
texts.append(sections[i]) |
|
cnt += section_cnt |
|
if texts: |
|
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) |
|
|
|
for i, _ in enumerate(threads): |
|
res.append(_.result()) |
|
|
|
if not res: |
|
return MindMapResult(output={"id": "root", "children": []}) |
|
|
|
merge_json = reduce(self._merge, res) |
|
if len(merge_json) > 1: |
|
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)] |
|
keyset = set(i for i in keys if i) |
|
merge_json = { |
|
"id": "root", |
|
"children": [ |
|
{ |
|
"id": self._key(k), |
|
"children": self._be_children(v, keyset) |
|
} |
|
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k) |
|
] |
|
} |
|
else: |
|
k = self._key(list(merge_json.keys())[0]) |
|
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})} |
|
|
|
except Exception as e: |
|
logging.exception("error mind graph") |
|
self._on_error( |
|
e, |
|
traceback.format_exc(), None |
|
) |
|
merge_json = {"error": str(e)} |
|
|
|
return MindMapResult(output=merge_json) |
|
|
|
def _merge(self, d1, d2): |
|
for k in d1: |
|
if k in d2: |
|
if isinstance(d1[k], dict) and isinstance(d2[k], dict): |
|
self._merge(d1[k], d2[k]) |
|
elif isinstance(d1[k], list) and isinstance(d2[k], list): |
|
d2[k].extend(d1[k]) |
|
else: |
|
d2[k] = d1[k] |
|
else: |
|
d2[k] = d1[k] |
|
|
|
return d2 |
|
|
|
def _list_to_kv(self, data): |
|
for key, value in data.items(): |
|
if isinstance(value, dict): |
|
self._list_to_kv(value) |
|
elif isinstance(value, list): |
|
new_value = {} |
|
for i in range(len(value)): |
|
if isinstance(value[i], list) and i > 0: |
|
new_value[value[i - 1]] = value[i][0] |
|
data[key] = new_value |
|
else: |
|
continue |
|
return data |
|
|
|
def _todict(self, layer: collections.OrderedDict): |
|
to_ret = layer |
|
if isinstance(layer, collections.OrderedDict): |
|
to_ret = dict(layer) |
|
|
|
try: |
|
for key, value in to_ret.items(): |
|
to_ret[key] = self._todict(value) |
|
except AttributeError: |
|
pass |
|
|
|
return self._list_to_kv(to_ret) |
|
|
|
def _process_document( |
|
self, text: str, prompt_variables: dict[str, str] |
|
) -> str: |
|
variables = { |
|
**prompt_variables, |
|
self._input_text_key: text, |
|
} |
|
text = perform_variable_replacements(self._mind_map_prompt, variables=variables) |
|
gen_conf = {"temperature": 0.5} |
|
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) |
|
response = re.sub(r"```[^\n]*", "", response) |
|
logging.debug(response) |
|
logging.debug(self._todict(markdown_to_json.dictify(response))) |
|
return self._todict(markdown_to_json.dictify(response)) |
|
|