"""Text splitter implementations."""
from dataclasses import dataclass
from typing import Callable, List, Optional
from langchain.text_splitter import TextSplitter
from gpt_index.utils import globals_helper
@dataclass
class TextSplit:
    """Text split with overlap.

    Attributes:
        text_chunk: The text string.
        num_char_overlap: The number of overlapping characters with the
            previous chunk.

    """

    text_chunk: str
    num_char_overlap: Optional[int] = None


class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at word tokens."""

def __init__(
self,
separator: str = " ",
chunk_size: int = 4000,
chunk_overlap: int = 200,
tokenizer: Optional[Callable] = None,
backup_separators: Optional[List[str]] = ["\n"],
):
"""Initialize with parameters."""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
)
self._separator = separator
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
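        # NOTE: the tokenizer may be any callable returning a sequence of tokens;
        # only the length of its output is used for counting.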
self.tokenizer = tokenizer or globals_helper.tokenizer
self._backup_separators = backup_separators

    def _reduce_chunk_size(
        self, start_idx: int, cur_idx: int, splits: List[str]
    ) -> int:
        """Reduce the chunk size by reducing cur_idx.

        Return the new cur_idx.
        """
current_doc_total = len(
self.tokenizer(self._separator.join(splits[start_idx:cur_idx]))
)
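        # Trim splits from the tail in proportion to how far over budget we are,
        # re-measuring with the real tokenizer on each pass. Illustrative
        # arithmetic: at 4400 tokens against a 4000-token budget, ~9% of the
        # splits in [start_idx, cur_idx) (plus one) are dropped per iteration.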
while current_doc_total > self._chunk_size:
percent_to_reduce = (
current_doc_total - self._chunk_size
) / current_doc_total
num_to_reduce = int(percent_to_reduce * (cur_idx - start_idx)) + 1
cur_idx -= num_to_reduce
current_doc_total = len(
self.tokenizer(self._separator.join(splits[start_idx:cur_idx]))
)
return cur_idx

    def _preprocess_splits(self, splits: List[str], chunk_size: int) -> List[str]:
        """Process splits.

        Specifically search for tokens that are too large for chunk size,
        and see if we can separate those tokens more
        (via backup separators if specified, or force chunking).
        """
new_splits = []
for split in splits:
num_cur_tokens = len(self.tokenizer(split))
if num_cur_tokens <= chunk_size:
new_splits.append(split)
else:
cur_splits = [split]
if self._backup_separators:
for sep in self._backup_separators:
if sep in split:
cur_splits = split.split(sep)
break
else:
cur_splits = [split]
cur_splits2 = []
for cur_split in cur_splits:
num_cur_tokens = len(self.tokenizer(cur_split))
if num_cur_tokens <= chunk_size:
cur_splits2.extend([cur_split])
else:
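                        # NOTE: hard fallback when no separator applies: slice by
                        # characters, treating chunk_size as a character budget
                        # (a rough stand-in for the token budget).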
cur_split_chunks = [
cur_split[i : i + chunk_size]
for i in range(0, len(cur_split), chunk_size)
]
cur_splits2.extend(cur_split_chunks)
new_splits.extend(cur_splits2)
return new_splits

    def _postprocess_splits(self, docs: List[TextSplit]) -> List[TextSplit]:
"""Post-process splits."""
# TODO: prune text splits, remove empty spaces
new_docs = []
for doc in docs:
if doc.text_chunk.replace(" ", "") == "":
continue
new_docs.append(doc)
return new_docs

    def split_text(self, text: str, extra_info_str: Optional[str] = None) -> List[str]:
"""Split incoming text and return chunks."""
text_splits = self.split_text_with_overlaps(text, extra_info_str=extra_info_str)
return [text_split.text_chunk for text_split in text_splits]

    def split_text_with_overlaps(
self, text: str, extra_info_str: Optional[str] = None
) -> List[TextSplit]:
"""Split incoming text and return chunks with overlap size."""
if text == "":
return []
# NOTE: Consider extra info str that will be added to the chunk at query time
# This reduces the effective chunk size that we can have
if extra_info_str is not None:
# NOTE: extra 2 newline chars for formatting when prepending in query
num_extra_tokens = len(self.tokenizer(f"{extra_info_str}\n\n")) + 1
effective_chunk_size = self._chunk_size - num_extra_tokens
if effective_chunk_size <= 0:
raise ValueError(
"Effective chunk size is non positive after considering extra_info"
)
else:
effective_chunk_size = self._chunk_size
# First we naively split the large input into a bunch of smaller ones.
splits = text.split(self._separator)
splits = self._preprocess_splits(splits, effective_chunk_size)
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
docs: List[TextSplit] = []
start_idx = 0
cur_idx = 0
cur_total = 0
prev_idx = 0 # store the previous end index
while cur_idx < len(splits):
cur_token = splits[cur_idx]
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)
if num_cur_tokens > effective_chunk_size:
raise ValueError(
"A single term is larger than the allowed chunk size.\n"
f"Term size: {num_cur_tokens}\n"
f"Chunk size: {self._chunk_size}"
f"Effective chunk size: {effective_chunk_size}"
)
            # If adding this token to current_doc would exceed the chunk size:
            # 1. Verify with the tokenizer that current_doc actually fits
            #    (reducing cur_idx if it does not), then update the docs list
if cur_total + num_cur_tokens > effective_chunk_size:
# NOTE: since we use a proxy for counting tokens, we want to
# run tokenizer across all of current_doc first. If
# the chunk is too big, then we will reduce text in pieces
cur_idx = self._reduce_chunk_size(start_idx, cur_idx, splits)
overlap = 0
# after first round, check if last chunk ended after this chunk begins
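                # NOTE: the overlap is counted in characters of the overlapping
                # splits (cf. TextSplit.num_char_overlap), not in tokens.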
if prev_idx > 0 and prev_idx > start_idx:
overlap = sum([len(splits[i]) for i in range(start_idx, prev_idx)])
docs.append(
TextSplit(self._separator.join(splits[start_idx:cur_idx]), overlap)
)
prev_idx = cur_idx
                # 2. Shrink the current_doc (from the front) until it gets smaller
                #    than the overlap size
# NOTE: because counting tokens individually is an imperfect
# proxy (but much faster proxy) for the total number of tokens consumed,
# we need to enforce that start_idx <= cur_idx, otherwise
# start_idx has a chance of going out of bounds.
while cur_total > self._chunk_overlap and start_idx < cur_idx:
# # call tokenizer on entire overlap
# cur_total = self.tokenizer()
cur_num_tokens = max(len(self.tokenizer(splits[start_idx])), 1)
cur_total -= cur_num_tokens
start_idx += 1
# NOTE: This is a hack, make more general
if start_idx == cur_idx:
cur_total = 0
            # Build up the current_doc with term d, and update the total counter
            # with the number of tokens in d, w.r.t. self.tokenizer
# we reassign cur_token and num_cur_tokens, because cur_idx
# may have changed
cur_token = splits[cur_idx]
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)
cur_total += num_cur_tokens
cur_idx += 1
overlap = 0
# after first round, check if last chunk ended after this chunk begins
if prev_idx > start_idx:
overlap = sum([len(splits[i]) for i in range(start_idx, prev_idx)]) + len(
range(start_idx, prev_idx)
)
docs.append(TextSplit(self._separator.join(splits[start_idx:cur_idx]), overlap))
# run postprocessing to remove blank spaces
docs = self._postprocess_splits(docs)
return docs

    def truncate_text(self, text: str) -> str:
"""Truncate text in order to fit the underlying chunk size."""
if text == "":
return ""
# First we naively split the large input into a bunch of smaller ones.
splits = text.split(self._separator)
splits = self._preprocess_splits(splits, self._chunk_size)
start_idx = 0
cur_idx = 0
cur_total = 0
while cur_idx < len(splits):
cur_token = splits[cur_idx]
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)
if cur_total + num_cur_tokens > self._chunk_size:
cur_idx = self._reduce_chunk_size(start_idx, cur_idx, splits)
break
cur_total += num_cur_tokens
cur_idx += 1
return self._separator.join(splits[start_idx:cur_idx])


__all__ = ["TextSplitter", "TokenTextSplitter"]
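

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API):
    # a whitespace "tokenizer" stands in for the default tiktoken-based one
    # from globals_helper, and the parameter values below are arbitrary.
    splitter = TokenTextSplitter(
        separator=" ",
        chunk_size=16,
        chunk_overlap=4,
        tokenizer=lambda t: t.split(),
    )
    sample_text = " ".join(f"token{i}" for i in range(50))
    for i, chunk in enumerate(splitter.split_text(sample_text)):
        print(i, len(splitter.tokenizer(chunk)), chunk[:40])
    # split_text_with_overlaps additionally reports per-chunk character overlap.
    for text_split in splitter.split_text_with_overlaps(sample_text):
        print(text_split.num_char_overlap)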