File size: 6,320 Bytes
47ec63e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import re
from typing import Any, Callable
from dataclasses import dataclass
import tiktoken
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
output: nx.Graph
source_docs: dict[Any, Any]
class GraphExtractor(Extractor):
"""Unipartite graph extractor class definition."""
_join_descriptions: bool
_tuple_delimiter_key: str
_record_delimiter_key: str
_entity_types_key: str
_input_text_key: str
_completion_delimiter_key: str
_entity_name_key: str
_input_descriptions_key: str
_extraction_prompt: str
_summarization_prompt: str
_loop_args: dict[str, Any]
_max_gleanings: int
_on_error: ErrorHandlerFn
def __init__(
self,
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
input_text_key: str | None = None,
entity_types_key: str | None = None,
completion_delimiter_key: str | None = None,
join_descriptions=True,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._join_descriptions = join_descriptions
self._input_text_key = input_text_key or "input_text"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._entity_types_key = entity_types_key or "entity_types"
self._extraction_prompt = GRAPH_EXTRACTION_PROMPT
self._max_gleanings = (
max_gleanings
if max_gleanings is not None
else ENTITY_EXTRACTION_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
# Construct the looping arguments
encoding = tiktoken.get_encoding("cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
# Wire defaults into the prompt variables
self._prompt_variables = {
"entity_types": entity_types,
self._tuple_delimiter_key: DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: DEFAULT_COMPLETION_DELIMITER,
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
}
def _process_single_content(self,
chunk_key_dp: tuple[str, str]
):
token_count = 0
chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
variables = {
**self._prompt_variables,
self._input_text_key: content,
}
try:
gen_conf = {"temperature": 0.3}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
history = [{"role": "system", "content": hint_prompt}, {"role": "assistant", "content": response}]
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._chat("", history, gen_conf)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._chat("", history, self._loop_args)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "YES":
break
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
records = [r for r in records if r.strip()]
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
return maybe_nodes, maybe_edges, token_count
except Exception as e:
logging.exception("error extracting graph")
return e, None, None
|