binhnase04854's picture
first deploy
b699122
"""General node utils."""
import logging
from typing import List
from gpt_index.data_structs.node_v2 import DocumentRelationship, ImageNode, Node
from gpt_index.langchain_helpers.text_splitter import (
TextSplit,
TextSplitter,
TokenTextSplitter,
)
from gpt_index.readers.schema.base import ImageDocument
from gpt_index.schema import BaseDocument
from gpt_index.utils import truncate_text
logger = logging.getLogger(__name__)
def get_text_splits_from_document(
    document: BaseDocument,
    text_splitter: TextSplitter,
    include_extra_info: bool = True,
) -> List[TextSplit]:
    """Split a document's text into ``TextSplit`` chunks.

    Args:
        document: Document whose text (``document.get_text()``) is split.
        text_splitter: Splitter to apply. A ``TokenTextSplitter`` is asked for
            splits with overlap information; any other splitter only yields
            plain text chunks, which are wrapped in ``TextSplit``.
        include_extra_info: When true and a ``TokenTextSplitter`` is used,
            ``document.extra_info_str`` is forwarded to the splitter;
            otherwise ``None`` is forwarded. Ignored for other splitters.

    Returns:
        The list of text splits produced for the document.
    """
    # TODO: clean up since this only exists due to the diff w LangChain's TextSplitter
    text = document.get_text()

    if not isinstance(text_splitter, TokenTextSplitter):
        # Generic splitters return bare strings, so wrap each one.
        return [TextSplit(text_chunk=chunk) for chunk in text_splitter.split_text(text)]

    # The token splitter can report extra information about each chunk.
    extra_info_str = document.extra_info_str if include_extra_info else None
    return text_splitter.split_text_with_overlaps(
        text,
        extra_info_str=extra_info_str,
    )
def get_nodes_from_document(
    document: BaseDocument,
    text_splitter: TextSplitter,
    include_extra_info: bool = True,
    include_prev_next_rel: bool = False,
) -> List[Node]:
    """Build one node per text chunk of a document.

    Args:
        document: Document to convert. An ``ImageDocument`` yields
            ``ImageNode`` objects (carrying ``document.image``); any other
            document yields plain ``Node`` objects.
        text_splitter: Splitter used to break the document text into chunks.
        include_extra_info: When true, ``document.extra_info`` is copied onto
            every node (and passed on to the splitting step); otherwise the
            nodes get ``extra_info=None``.
        include_prev_next_rel: When true, each node is linked to its
            neighbours through ``PREVIOUS`` / ``NEXT`` relationships.

    Returns:
        The nodes in chunk order. Every node has a ``SOURCE`` relationship
        holding the document's id.
    """
    splits = get_text_splits_from_document(
        document=document,
        text_splitter=text_splitter,
        include_extra_info=include_extra_info,
    )

    nodes: List[Node] = []
    char_cursor = 0  # running character position across the chunks
    for split in splits:
        chunk = split.text_chunk
        logger.debug(f"> Adding chunk: {truncate_text(chunk, 50)}")

        # Position info is only recorded when the split reports its overlap.
        overlap = split.num_char_overlap
        if overlap is None:
            position = None
        else:
            # NOTE: start is inclusive, end is exclusive
            chunk_start = char_cursor - overlap
            position = {"start": chunk_start, "end": chunk_start + len(chunk)}
        # The extra 1 presumably accounts for a separator between chunks --
        # TODO confirm.
        # NOTE(review): the cursor advances by the full chunk length even when
        # the chunk starts `overlap` characters earlier; verify this matches
        # how the splitter reports overlaps.
        char_cursor += len(chunk) + 1

        extra_info = document.extra_info if include_extra_info else None
        source_rel = {DocumentRelationship.SOURCE: document.get_doc_id()}

        if isinstance(document, ImageDocument):
            nodes.append(  # type: ignore
                ImageNode(
                    text=chunk,
                    embedding=document.embedding,
                    extra_info=extra_info,
                    node_info=position,
                    image=document.image,
                    relationships=source_rel,
                )
            )
        else:
            nodes.append(
                Node(
                    text=chunk,
                    embedding=document.embedding,
                    extra_info=extra_info,
                    node_info=position,
                    relationships=source_rel,
                )
            )

    # if include_prev_next_rel, then add prev/next relationships
    if include_prev_next_rel:
        # Walk adjacent pairs: the earlier node points forward, the later back.
        for earlier, later in zip(nodes, nodes[1:]):
            earlier.relationships[DocumentRelationship.NEXT] = later.get_doc_id()
            later.relationships[DocumentRelationship.PREVIOUS] = earlier.get_doc_id()

    return nodes