""" |
|
Reference: |
|
- [graphrag](https://github.com/microsoft/graphrag) |
|
""" |
|
|
|
import json
from dataclasses import dataclass

from graphrag.extractor import Extractor
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string
|
|
|
SUMMARIZE_PROMPT = """ |
|
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. |
|
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. |
|
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. |
|
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. |
|
Make sure it is written in third person, and include the entity names so we the have full context. |
|
|
|
####### |
|
-Data- |
|
Entities: {entity_name} |
|
Description List: {description_list} |
|
####### |
|
Output: |
|
""" |
|
|
# Default token budget for a single summarization prompt (template plus descriptions).
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Default cap on the length of the generated summary.
DEFAULT_MAX_SUMMARY_LENGTH = 128
|
|
@dataclass
class SummarizationResult:
    """Result of summarizing a set of entity or relationship descriptions."""

    items: str | tuple[str, str]
    description: str


class SummarizeExtractor(Extractor):
    """Extractor that merges multiple descriptions of an entity (or entity pair) into one summary."""

    _entity_name_key: str
    _input_descriptions_key: str
    _summarization_prompt: str
    _on_error: ErrorHandlerFn
    _max_summary_length: int
    _max_input_tokens: int
|
    def __init__(
        self,
        llm_invoker: CompletionLLM,
        entity_name_key: str | None = None,
        input_descriptions_key: str | None = None,
        summarization_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_summary_length: int | None = None,
        max_input_tokens: int | None = None,
    ):
        """Initialize the extractor with an LLM invoker and optional overrides."""
        self._llm = llm_invoker
        self._entity_name_key = entity_name_key or "entity_name"
        self._input_descriptions_key = input_descriptions_key or "description_list"
        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
|
    def __call__(
        self,
        items: str | tuple[str, str],
        descriptions: list[str],
    ) -> SummarizationResult:
        """Summarize the description list for the given entity item(s)."""
        # Nothing to summarize yields an empty result, a single description passes
        # through unchanged, and multiple descriptions are merged by the LLM.
        if len(descriptions) == 0:
            result = ""
        elif len(descriptions) == 1:
            result = descriptions[0]
        else:
            result = self._summarize_descriptions(items, descriptions)

        return SummarizationResult(
            items=items,
            description=result or "",
        )
|
    def _summarize_descriptions(
        self, items: str | tuple[str, str], descriptions: list[str]
    ) -> str:
        """Summarize descriptions into a single description."""
        # Canonicalize the ordering of entity keys so equivalent pairs yield the same prompt.
        sorted_items = sorted(items) if isinstance(items, (list, tuple)) else items

        if not isinstance(descriptions, list):
            descriptions = [descriptions]

        # Token budget left for descriptions once the prompt template is accounted for.
        usable_tokens = self._max_input_tokens - num_tokens_from_string(
            self._summarization_prompt
        )
        descriptions_collected = []
        result = ""

        for i, description in enumerate(descriptions):
            usable_tokens -= num_tokens_from_string(description)
            descriptions_collected.append(description)

            # Summarize once the batch exceeds the token budget or the last description is reached.
            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
                i == len(descriptions) - 1
            ):
                result = self._summarize_descriptions_with_llm(
                    sorted_items, descriptions_collected
                )

                # If more descriptions remain, start the next batch from the interim
                # summary and reset the remaining token budget accordingly.
                if i != len(descriptions) - 1:
                    descriptions_collected = [result]
                    usable_tokens = (
                        self._max_input_tokens
                        - num_tokens_from_string(self._summarization_prompt)
                        - num_tokens_from_string(result)
                    )

        return result
|
    def _summarize_descriptions_with_llm(
        self, items: str | tuple[str, str] | list[str], descriptions: list[str]
    ):
        """Summarize descriptions using the LLM."""
        variables = {
            self._entity_name_key: json.dumps(items),
            self._input_descriptions_key: json.dumps(sorted(descriptions)),
        }
        text = perform_variable_replacements(self._summarization_prompt, variables=variables)
        return self._chat("", [{"role": "user", "content": text}])
|
|