# Source: ragflow/graphrag/mind_map_extractor.py
# Last commit: 758538f by Kevin Hu - "Cache the result from llm for graphrag and raptor (#4051)"
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import collections
import os
import re
import traceback
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from graphrag.extractor import Extractor
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
import markdown_to_json
from functools import reduce
from rag.utils import num_tokens_from_string
@dataclass()
class MindMapResult:
    """Value object wrapping what MindMapExtractor produced.

    ``output`` is the root node dict (``{"id": ..., "children": [...]}``), or
    an ``{"error": message}`` dict when extraction failed.
    """

    output: dict
class MindMapExtractor(Extractor):
    """Build a mind-map tree from document sections with an LLM.

    Sections are packed into token-bounded batches. Each batch is sent to the
    LLM from a thread pool with a prompt asking for markdown. Every markdown
    reply is parsed into nested dicts, the dicts are deep-merged, and the
    result is reshaped into ``{"id": name, "children": [...]}`` nodes.
    """

    _input_text_key: str
    _mind_map_prompt: str
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        input_text_key: str | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Configure the extractor.

        Args:
            llm_invoker: chat model; ``__call__`` reads its ``max_length`` to
                size input batches.
            prompt: prompt template; defaults to ``MIND_MAP_EXTRACTION_PROMPT``.
            input_text_key: template variable that receives the batch text;
                defaults to ``"input_text"``.
            on_error: called as ``on_error(exc, traceback_text, None)`` when
                extraction fails; defaults to a no-op.
        """
        # TODO: streamline construction
        self._llm = llm_invoker
        self._input_text_key = input_text_key or "input_text"
        self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)

    def _key(self, k):
        """Return ``k`` with every run of ``*`` (markdown bold/italic markers) removed."""
        return re.sub(r"\*+", "", k)

    def _be_children(self, obj: dict | list | str, keyset: set) -> list[dict]:
        """Convert one parsed subtree into a list of mind-map child nodes.

        Args:
            obj: a dict (keys become nodes, values are converted recursively),
                a list (items become leaves; nested lists/dicts are expanded
                in place), or a single scalar leaf such as a string. ``None``
                yields no node.
            keyset: names already placed in the tree; updated in place. Dict
                keys already present are skipped so a heading appears only
                once. Leaf names are recorded but not de-duplicated.

        Returns:
            A list of ``{"id": name, "children": [...]}`` dicts. Empty names
            are dropped.
        """
        if isinstance(obj, dict):
            nodes = []
            for k, v in obj.items():
                k = self._key(k)
                # Skip blank headings and names already placed elsewhere.
                if k and k not in keyset:
                    keyset.add(k)
                    nodes.append({"id": k, "children": self._be_children(v, keyset)})
            return nodes

        # A bare scalar is a single leaf; a list is a sequence of leaves.
        items = obj if isinstance(obj, list) else [obj]
        nodes = []
        for item in items:
            if isinstance(item, (dict, list)):
                # Expand nested structures instead of crashing in re.sub.
                nodes.extend(self._be_children(item, keyset))
            elif item is not None:
                name = self._key(str(item))
                # Record the *stripped* name, like every other keyset entry.
                keyset.add(name)
                if name:
                    nodes.append({"id": name, "children": []})
        return nodes

    def __call__(
        self, sections: list[str], prompt_variables: dict[str, Any] | None = None
    ) -> MindMapResult:
        """Generate a mind map for ``sections``.

        Batches are processed concurrently by a thread pool whose size comes
        from the ``MINDMAP_EXTRACTOR_MAX_WORKERS`` environment variable
        (default 12).

        Args:
            sections: text chunks, concatenated in order into LLM-sized batches.
            prompt_variables: extra template variables for the prompt.

        Returns:
            A MindMapResult whose ``output`` is one of the following:

            - the root node;
            - ``{"id": "root", "children": []}`` when nothing could be
              extracted;
            - ``{"error": message}`` if extraction raised. In this case
              ``on_error`` has already been called.
        """
        if prompt_variables is None:
            prompt_variables = {}
        try:
            max_workers = int(os.environ.get("MINDMAP_EXTRACTOR_MAX_WORKERS", 12))
            with ThreadPoolExecutor(max_workers=max_workers) as exe:
                futures = [
                    exe.submit(self._process_document, batch, prompt_variables)
                    for batch in self._batch_sections(sections)
                ]
                # Collect in submission order so the merge order is deterministic.
                res = [future.result() for future in futures]
            if not res:
                return MindMapResult(output={"id": "root", "children": []})
            merge_json = self._to_tree(reduce(self._merge, res))
        except Exception as e:
            logging.exception("error mind graph")
            self._on_error(
                e,
                traceback.format_exc(), None
            )
            merge_json = {"error": str(e)}
        return MindMapResult(output=merge_json)

    def _batch_sections(self, sections: list[str]):
        """Yield concatenations of consecutive ``sections`` kept under the LLM token budget."""
        # The budget keeps min(20%, 512 tokens) of the context window free,
        # presumably for the prompt template and the reply.
        token_budget = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
        texts, cnt = [], 0
        for section in sections:
            section_cnt = num_tokens_from_string(section)
            # Flush before reaching the budget. A single oversized section
            # still forms a batch of its own.
            if texts and cnt + section_cnt >= token_budget:
                yield "".join(texts)
                texts, cnt = [], 0
            texts.append(section)
            cnt += section_cnt
        if texts:
            yield "".join(texts)

    def _to_tree(self, merge_json: dict) -> dict:
        """Reshape the merged heading dict into the root ``{"id", "children"}`` node."""
        if not merge_json:
            # No reply contained any parseable heading structure.
            return {"id": "root", "children": []}
        if len(merge_json) > 1:
            # Several top-level headings: hang the dict-valued ones under a
            # synthetic root, and pre-register their names so deeper levels
            # do not repeat them.
            top = [(self._key(k), v) for k, v in merge_json.items() if isinstance(v, dict)]
            keyset = {name for name, _ in top if name}
            return {
                "id": "root",
                "children": [
                    {"id": name, "children": self._be_children(v, keyset)}
                    for name, v in top
                    if name
                ],
            }
        # Exactly one top-level heading: it becomes the root itself.
        [(k, v)] = merge_json.items()
        k = self._key(k)
        return {"id": k, "children": self._be_children(v, {k})}

    def _merge(self, d1, d2):
        """Deep-merge ``d1`` into ``d2`` in place and return ``d2``.

        When both values under a key are dicts, they are merged recursively.
        When both are lists, ``d1``'s items are appended to ``d2``'s.
        Otherwise ``d1``'s value replaces ``d2``'s.
        """
        for k, v in d1.items():
            if k in d2 and isinstance(v, dict) and isinstance(d2[k], dict):
                self._merge(v, d2[k])
            elif k in d2 and isinstance(v, list) and isinstance(d2[k], list):
                d2[k].extend(v)
            else:
                d2[k] = v
        return d2

    def _list_to_kv(self, data):
        """Replace every list value in ``data`` (recursively, in place) with a dict.

        Within a list, an item immediately followed by a non-empty sub-list
        becomes the pair ``{item: sub_list[0]}``. This assumes nested bullets
        are rendered as a sub-list after their parent item (TODO confirm
        against markdown_to_json). Items with no following sub-list are
        dropped. Pairs whose key is not a string are skipped rather than
        raising. Non-dict input is returned unchanged.
        """
        if not isinstance(data, dict):
            return data
        for key, value in data.items():
            if isinstance(value, dict):
                self._list_to_kv(value)
            elif isinstance(value, list):
                new_value = {}
                for prev, cur in zip(value, value[1:]):
                    if isinstance(cur, list) and cur and isinstance(prev, str):
                        new_value[prev] = cur[0]
                # Convert lists nested inside the pairs too, so one pass
                # leaves no list values anywhere in the tree.
                data[key] = self._list_to_kv(new_value)
        return data

    def _todict(self, layer: collections.OrderedDict):
        """Turn parsed markdown into plain nested dicts with lists folded by ``_list_to_kv``."""
        return self._list_to_kv(self._plain_dicts(layer))

    def _plain_dicts(self, layer):
        """Return ``layer`` with every nested mapping (e.g. OrderedDict) copied into a plain dict."""
        if isinstance(layer, dict):
            return {k: self._plain_dicts(v) for k, v in layer.items()}
        return layer

    def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> dict:
        """Ask the LLM for a markdown mind map of ``text`` and parse it into a nested dict.

        Returns ``{}`` when the reply does not parse into a mapping, so one
        bad reply does not fail the whole map.
        """
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        prompt = perform_variable_replacements(self._mind_map_prompt, variables=variables)
        gen_conf = {"temperature": 0.5}
        # self._chat is not defined here; presumably provided by Extractor.
        response = self._chat(prompt, [{"role": "user", "content": "Output:"}], gen_conf)
        # Remove code-fence markers (``` plus any info string on that line).
        response = re.sub(r"```[^\n]*", "", response)
        logging.debug(response)
        # Parse once; the result is reused for both the log and the return value.
        parsed = self._todict(markdown_to_json.dictify(response))
        logging.debug(parsed)
        if not isinstance(parsed, dict):
            logging.warning("mind map reply did not parse into headings; ignored")
            return {}
        return parsed