"""Text splitter implementations.""" | |
from dataclasses import dataclass | |
from typing import Callable, List, Optional | |
from langchain.text_splitter import TextSplitter | |
from gpt_index.utils import globals_helper | |
@dataclass
class TextSplit:
    """Text split with overlap.

    Attributes:
        text_chunk: The text string.
        num_char_overlap: The number of overlapping characters with the previous chunk.

    """

    text_chunk: str
    num_char_overlap: Optional[int] = None
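
# Example (illustrative): TextSplit("foo bar baz", num_char_overlap=4) describes a
# chunk whose first four characters also appeared at the end of the previous chunk.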

class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at word tokens."""

    def __init__(
        self,
        separator: str = " ",
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        tokenizer: Optional[Callable] = None,
        backup_separators: Optional[List[str]] = ["\n"],
    ):
        """Initialize with parameters."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._separator = separator
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self.tokenizer = tokenizer or globals_helper.tokenizer
        self._backup_separators = backup_separators

    def _reduce_chunk_size(
        self, start_idx: int, cur_idx: int, splits: List[str]
    ) -> int:
        """Reduce the chunk size by reducing cur_idx.

        Return the new cur_idx.
        """
        current_doc_total = len(
            self.tokenizer(self._separator.join(splits[start_idx:cur_idx]))
        )
        while current_doc_total > self._chunk_size:
            percent_to_reduce = (
                current_doc_total - self._chunk_size
            ) / current_doc_total
            num_to_reduce = int(percent_to_reduce * (cur_idx - start_idx)) + 1
            cur_idx -= num_to_reduce
            current_doc_total = len(
                self.tokenizer(self._separator.join(splits[start_idx:cur_idx]))
            )
        return cur_idx

    def _preprocess_splits(self, splits: List[str], chunk_size: int) -> List[str]:
        """Preprocess splits.

        Specifically, search for tokens that are too large for the chunk size,
        and see if we can separate those tokens further
        (via backup separators if specified, or by force-chunking).
        """
        new_splits = []
        for split in splits:
            num_cur_tokens = len(self.tokenizer(split))
            if num_cur_tokens <= chunk_size:
                new_splits.append(split)
            else:
                cur_splits = [split]
                if self._backup_separators:
                    for sep in self._backup_separators:
                        if sep in split:
                            cur_splits = split.split(sep)
                            break
                else:
                    cur_splits = [split]
                cur_splits2 = []
                for cur_split in cur_splits:
                    num_cur_tokens = len(self.tokenizer(cur_split))
                    if num_cur_tokens <= chunk_size:
                        cur_splits2.extend([cur_split])
                    else:
                        cur_split_chunks = [
                            cur_split[i : i + chunk_size]
                            for i in range(0, len(cur_split), chunk_size)
                        ]
                        cur_splits2.extend(cur_split_chunks)
                new_splits.extend(cur_splits2)
        return new_splits

    def _postprocess_splits(self, docs: List[TextSplit]) -> List[TextSplit]:
        """Post-process splits."""
        # TODO: prune text splits, remove empty spaces
        new_docs = []
        for doc in docs:
            if doc.text_chunk.replace(" ", "") == "":
                continue
            new_docs.append(doc)
        return new_docs

    def split_text(self, text: str, extra_info_str: Optional[str] = None) -> List[str]:
        """Split incoming text and return chunks."""
        text_splits = self.split_text_with_overlaps(text, extra_info_str=extra_info_str)
        return [text_split.text_chunk for text_split in text_splits]

    def split_text_with_overlaps(
        self, text: str, extra_info_str: Optional[str] = None
    ) -> List[TextSplit]:
        """Split incoming text and return chunks with overlap size."""
if text == "": | |
return [] | |
# NOTE: Consider extra info str that will be added to the chunk at query time | |
# This reduces the effective chunk size that we can have | |
if extra_info_str is not None: | |
# NOTE: extra 2 newline chars for formatting when prepending in query | |
num_extra_tokens = len(self.tokenizer(f"{extra_info_str}\n\n")) + 1 | |
effective_chunk_size = self._chunk_size - num_extra_tokens | |
if effective_chunk_size <= 0: | |
raise ValueError( | |
"Effective chunk size is non positive after considering extra_info" | |
) | |
else: | |
effective_chunk_size = self._chunk_size | |
# First we naively split the large input into a bunch of smaller ones. | |
splits = text.split(self._separator) | |
splits = self._preprocess_splits(splits, effective_chunk_size) | |
# We now want to combine these smaller pieces into medium size | |
# chunks to send to the LLM. | |
docs: List[TextSplit] = [] | |
start_idx = 0 | |
cur_idx = 0 | |
cur_total = 0 | |
prev_idx = 0 # store the previous end index | |
while cur_idx < len(splits): | |
cur_token = splits[cur_idx] | |
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1) | |
if num_cur_tokens > effective_chunk_size: | |
raise ValueError( | |
"A single term is larger than the allowed chunk size.\n" | |
f"Term size: {num_cur_tokens}\n" | |
f"Chunk size: {self._chunk_size}" | |
f"Effective chunk size: {effective_chunk_size}" | |
) | |
# If adding token to current_doc would exceed the chunk size: | |
# 1. First verify with tokenizer that current_doc | |
# 1. Update the docs list | |
if cur_total + num_cur_tokens > effective_chunk_size: | |
# NOTE: since we use a proxy for counting tokens, we want to | |
# run tokenizer across all of current_doc first. If | |
# the chunk is too big, then we will reduce text in pieces | |
cur_idx = self._reduce_chunk_size(start_idx, cur_idx, splits) | |
overlap = 0 | |
# after first round, check if last chunk ended after this chunk begins | |
if prev_idx > 0 and prev_idx > start_idx: | |
overlap = sum([len(splits[i]) for i in range(start_idx, prev_idx)]) | |
docs.append( | |
TextSplit(self._separator.join(splits[start_idx:cur_idx]), overlap) | |
) | |
prev_idx = cur_idx | |
# 2. Shrink the current_doc (from the front) until it is gets smaller | |
# than the overlap size | |
# NOTE: because counting tokens individually is an imperfect | |
# proxy (but much faster proxy) for the total number of tokens consumed, | |
# we need to enforce that start_idx <= cur_idx, otherwise | |
# start_idx has a chance of going out of bounds. | |
while cur_total > self._chunk_overlap and start_idx < cur_idx: | |
# # call tokenizer on entire overlap | |
# cur_total = self.tokenizer() | |
cur_num_tokens = max(len(self.tokenizer(splits[start_idx])), 1) | |
cur_total -= cur_num_tokens | |
start_idx += 1 | |
# NOTE: This is a hack, make more general | |
if start_idx == cur_idx: | |
cur_total = 0 | |
# Build up the current_doc with term d, and update the total counter with | |
# the number of the number of tokens in d, wrt self.tokenizer | |
# we reassign cur_token and num_cur_tokens, because cur_idx | |
# may have changed | |
cur_token = splits[cur_idx] | |
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1) | |
cur_total += num_cur_tokens | |
cur_idx += 1 | |
overlap = 0 | |
# after first round, check if last chunk ended after this chunk begins | |
if prev_idx > start_idx: | |
overlap = sum([len(splits[i]) for i in range(start_idx, prev_idx)]) + len( | |
range(start_idx, prev_idx) | |
) | |
docs.append(TextSplit(self._separator.join(splits[start_idx:cur_idx]), overlap)) | |
# run postprocessing to remove blank spaces | |
docs = self._postprocess_splits(docs) | |
return docs | |

    def truncate_text(self, text: str) -> str:
        """Truncate text in order to fit the underlying chunk size."""
if text == "": | |
return "" | |
# First we naively split the large input into a bunch of smaller ones. | |
splits = text.split(self._separator) | |
splits = self._preprocess_splits(splits, self._chunk_size) | |
start_idx = 0 | |
cur_idx = 0 | |
cur_total = 0 | |
while cur_idx < len(splits): | |
cur_token = splits[cur_idx] | |
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1) | |
if cur_total + num_cur_tokens > self._chunk_size: | |
cur_idx = self._reduce_chunk_size(start_idx, cur_idx, splits) | |
break | |
cur_total += num_cur_tokens | |
cur_idx += 1 | |
return self._separator.join(splits[start_idx:cur_idx]) | |

__all__ = ["TextSplitter", "TokenTextSplitter"]
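
# Minimal usage sketch (illustrative, not part of the library): the whitespace
# `tokenizer` below is only a stand-in for the default tokenizer supplied by
# `globals_helper`, so "tokens" here are simply whitespace-separated words.
if __name__ == "__main__":
    splitter = TokenTextSplitter(
        chunk_size=16,
        chunk_overlap=4,
        tokenizer=lambda text: text.split(),
    )
    sample_text = " ".join(f"word{i}" for i in range(50))
    for split in splitter.split_text_with_overlaps(sample_text):
        print(len(split.text_chunk), split.num_char_overlap)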