Spaces:
Runtime error
Runtime error
"""General node utils.""" | |
import logging | |
from typing import List | |
from gpt_index.data_structs.node_v2 import DocumentRelationship, ImageNode, Node | |
from gpt_index.langchain_helpers.text_splitter import ( | |
TextSplit, | |
TextSplitter, | |
TokenTextSplitter, | |
) | |
from gpt_index.readers.schema.base import ImageDocument | |
from gpt_index.schema import BaseDocument | |
from gpt_index.utils import truncate_text | |
logger = logging.getLogger(__name__) | |
def get_text_splits_from_document( | |
document: BaseDocument, | |
text_splitter: TextSplitter, | |
include_extra_info: bool = True, | |
) -> List[TextSplit]: | |
"""Break the document into chunks with additional info.""" | |
# TODO: clean up since this only exists due to the diff w LangChain's TextSplitter | |
text_splits = [] | |
if isinstance(text_splitter, TokenTextSplitter): | |
# use this to extract extra information about the chunks | |
text_splits = text_splitter.split_text_with_overlaps( | |
document.get_text(), | |
extra_info_str=document.extra_info_str if include_extra_info else None, | |
) | |
else: | |
text_chunks = text_splitter.split_text( | |
document.get_text(), | |
) | |
text_splits = [TextSplit(text_chunk=text_chunk) for text_chunk in text_chunks] | |
return text_splits | |
def get_nodes_from_document( | |
document: BaseDocument, | |
text_splitter: TextSplitter, | |
include_extra_info: bool = True, | |
include_prev_next_rel: bool = False, | |
) -> List[Node]: | |
"""Get nodes from document.""" | |
text_splits = get_text_splits_from_document( | |
document=document, | |
text_splitter=text_splitter, | |
include_extra_info=include_extra_info, | |
) | |
nodes: List[Node] = [] | |
index_counter = 0 | |
for i, text_split in enumerate(text_splits): | |
text_chunk = text_split.text_chunk | |
logger.debug(f"> Adding chunk: {truncate_text(text_chunk, 50)}") | |
index_pos_info = None | |
if text_split.num_char_overlap is not None: | |
index_pos_info = { | |
# NOTE: start is inclusive, end is exclusive | |
"start": index_counter - text_split.num_char_overlap, | |
"end": index_counter - text_split.num_char_overlap + len(text_chunk), | |
} | |
index_counter += len(text_chunk) + 1 | |
if isinstance(document, ImageDocument): | |
image_node = ImageNode( | |
text=text_chunk, | |
embedding=document.embedding, | |
extra_info=document.extra_info if include_extra_info else None, | |
node_info=index_pos_info, | |
image=document.image, | |
relationships={DocumentRelationship.SOURCE: document.get_doc_id()}, | |
) | |
nodes.append(image_node) # type: ignore | |
else: | |
node = Node( | |
text=text_chunk, | |
embedding=document.embedding, | |
extra_info=document.extra_info if include_extra_info else None, | |
node_info=index_pos_info, | |
relationships={DocumentRelationship.SOURCE: document.get_doc_id()}, | |
) | |
nodes.append(node) | |
# if include_prev_next_rel, then add prev/next relationships | |
if include_prev_next_rel: | |
for i, node in enumerate(nodes): | |
if i > 0: | |
node.relationships[DocumentRelationship.PREVIOUS] = nodes[ | |
i - 1 | |
].get_doc_id() | |
if i < len(nodes) - 1: | |
node.relationships[DocumentRelationship.NEXT] = nodes[ | |
i + 1 | |
].get_doc_id() | |
return nodes | |