"""Text splitter implementations."""
from dataclasses import dataclass
from typing import Callable, List, Optional
from langchain.text_splitter import TextSplitter
from gpt_index.utils import globals_helper
@dataclass
class TextSplit:
"""Text split with overlap.
Attributes:
text_chunk: The text string.
num_char_overlap: The number of overlapping characters with the previous chunk.
"""
text_chunk: str
num_char_overlap: Optional[int] = None
class TokenTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at word tokens."""
def __init__(
self,
separator: str = " ",
chunk_size: int = 4000,
chunk_overlap: int = 200,
tokenizer: Optional[Callable] = None,
backup_separators: Optional[List[str]] = ["\n"],
):
"""Initialize with parameters."""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
)
self._separator = separator
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self.tokenizer = tokenizer or globals_helper.tokenizer
self._backup_separators = backup_separators
def _reduce_chunk_size(
self, start_idx: int, cur_idx: int, splits: List[str]
) -> int:
"""Reduce the chunk size by reducing cur_idx.
Return the new cur_idx.
"""
current_doc_total = len(
self.tokenizer(self._separator.join(splits[start_idx:cur_idx]))
)
while current_doc_total > self._chunk_size:
percent_to_reduce = (
current_doc_total - self._chunk_size
) / current_doc_total
num_to_reduce = int(percent_to_reduce * (cur_idx - start_idx)) + 1
cur_idx -= num_to_reduce
current_doc_total = len(
self.tokenizer(self._separator.join(splits[start_idx:cur_idx]))
)
return cur_idx
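    # Illustrative walk-through of _reduce_chunk_size above (hypothetical
    # numbers, not from the original source): if splits[start_idx:cur_idx]
    # spans 100 words that tokenize to 5000 tokens against a chunk size of
    # 4000, then percent_to_reduce = 1000 / 5000 = 0.2 and num_to_reduce =
    # int(0.2 * 100) + 1 = 21, so 21 words are dropped from the end of the
    # window; the loop repeats until the re-tokenized window fits within
    # self._chunk_size.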
def _preprocess_splits(self, splits: List[str], chunk_size: int) -> List[str]:
"""Process splits.
Specifically search for tokens that are too large for chunk size,
and see if we can separate those tokens more
(via backup separators if specified, or force chunking).
"""
new_splits = []
for split in splits:
num_cur_tokens = len(self.tokenizer(split))
if num_cur_tokens <= chunk_size:
new_splits.append(split)
else:
cur_splits = [split]
if self._backup_separators:
for sep in self._backup_separators:
if sep in split:
cur_splits = split.split(sep)
break
else:
cur_splits = [split]
cur_splits2 = []
for cur_split in cur_splits:
num_cur_tokens = len(self.tokenizer(cur_split))
if num_cur_tokens <= chunk_size:
                        cur_splits2.append(cur_split)
else:
cur_split_chunks = [
cur_split[i : i + chunk_size]
for i in range(0, len(cur_split), chunk_size)
]
cur_splits2.extend(cur_split_chunks)
new_splits.extend(cur_splits2)
return new_splits
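    # Illustrative note on _preprocess_splits above (assumes a subword
    # tokenizer such as the default): a single whitespace-free split that
    # tokenizes to more than chunk_size tokens is first broken on any backup
    # separator it contains (e.g. "\n"); any piece that still tokenizes to
    # more than chunk_size tokens is force-chunked into consecutive
    # chunk_size-character slices.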
def _postprocess_splits(self, docs: List[TextSplit]) -> List[TextSplit]:
"""Post-process splits."""
# TODO: prune text splits, remove empty spaces
new_docs = []
for doc in docs:
if doc.text_chunk.replace(" ", "") == "":
continue
new_docs.append(doc)
return new_docs
def split_text(self, text: str, extra_info_str: Optional[str] = None) -> List[str]:
"""Split incoming text and return chunks."""
text_splits = self.split_text_with_overlaps(text, extra_info_str=extra_info_str)
return [text_split.text_chunk for text_split in text_splits]
def split_text_with_overlaps(
self, text: str, extra_info_str: Optional[str] = None
) -> List[TextSplit]:
"""Split incoming text and return chunks with overlap size."""
if text == "":
return []
        # NOTE: Account for the extra_info_str that will be prepended to the
        # chunk at query time; this reduces the effective chunk size available.
if extra_info_str is not None:
# NOTE: extra 2 newline chars for formatting when prepending in query
num_extra_tokens = len(self.tokenizer(f"{extra_info_str}\n\n")) + 1
effective_chunk_size = self._chunk_size - num_extra_tokens
if effective_chunk_size <= 0:
raise ValueError(
"Effective chunk size is non positive after considering extra_info"
)
else:
effective_chunk_size = self._chunk_size
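        # Illustrative numbers (hypothetical): with chunk_size=4000 and an
        # extra_info_str for which f"{extra_info_str}\n\n" tokenizes to 12
        # tokens, num_extra_tokens is 13 and effective_chunk_size is 3987.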
# First we naively split the large input into a bunch of smaller ones.
splits = text.split(self._separator)
splits = self._preprocess_splits(splits, effective_chunk_size)
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
docs: List[TextSplit] = []
start_idx = 0
cur_idx = 0
cur_total = 0
prev_idx = 0 # store the previous end index
while cur_idx < len(splits):
cur_token = splits[cur_idx]
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)
if num_cur_tokens > effective_chunk_size:
raise ValueError(
"A single term is larger than the allowed chunk size.\n"
f"Term size: {num_cur_tokens}\n"
f"Chunk size: {self._chunk_size}"
f"Effective chunk size: {effective_chunk_size}"
)
            # If adding the current token to current_doc would exceed the
            # chunk size:
            # 1. First verify with the tokenizer that current_doc actually
            #    fits within the chunk size, shrinking it if necessary.
            # 2. Update the docs list.
if cur_total + num_cur_tokens > effective_chunk_size:
# NOTE: since we use a proxy for counting tokens, we want to
# run tokenizer across all of current_doc first. If
# the chunk is too big, then we will reduce text in pieces
cur_idx = self._reduce_chunk_size(start_idx, cur_idx, splits)
overlap = 0
# after first round, check if last chunk ended after this chunk begins
if prev_idx > 0 and prev_idx > start_idx:
overlap = sum([len(splits[i]) for i in range(start_idx, prev_idx)])
docs.append(
TextSplit(self._separator.join(splits[start_idx:cur_idx]), overlap)
)
prev_idx = cur_idx
                # 2. Shrink the current_doc (from the front) until it gets
                #    smaller than the overlap size
# NOTE: because counting tokens individually is an imperfect
# proxy (but much faster proxy) for the total number of tokens consumed,
# we need to enforce that start_idx <= cur_idx, otherwise
# start_idx has a chance of going out of bounds.
while cur_total > self._chunk_overlap and start_idx < cur_idx:
# # call tokenizer on entire overlap
# cur_total = self.tokenizer()
cur_num_tokens = max(len(self.tokenizer(splits[start_idx])), 1)
cur_total -= cur_num_tokens
start_idx += 1
# NOTE: This is a hack, make more general
if start_idx == cur_idx:
cur_total = 0
            # Build up the current_doc with the current term, and update the
            # total counter with the number of tokens in it, w.r.t. self.tokenizer
# we reassign cur_token and num_cur_tokens, because cur_idx
# may have changed
cur_token = splits[cur_idx]
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)
cur_total += num_cur_tokens
cur_idx += 1
overlap = 0
# after first round, check if last chunk ended after this chunk begins
if prev_idx > start_idx:
overlap = sum([len(splits[i]) for i in range(start_idx, prev_idx)]) + len(
range(start_idx, prev_idx)
)
docs.append(TextSplit(self._separator.join(splits[start_idx:cur_idx]), overlap))
        # run post-processing to drop chunks that are only whitespace
docs = self._postprocess_splits(docs)
return docs
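    # Illustrative output shape for split_text_with_overlaps above
    # (hypothetical text and tokenizer): a list of TextSplit objects such as
    # [TextSplit("chunk one ...", 0), TextSplit("... chunk two", 37)], where
    # the second field is the number of characters shared with the end of the
    # previous chunk.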
def truncate_text(self, text: str) -> str:
"""Truncate text in order to fit the underlying chunk size."""
if text == "":
return ""
# First we naively split the large input into a bunch of smaller ones.
splits = text.split(self._separator)
splits = self._preprocess_splits(splits, self._chunk_size)
start_idx = 0
cur_idx = 0
cur_total = 0
while cur_idx < len(splits):
cur_token = splits[cur_idx]
num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)
if cur_total + num_cur_tokens > self._chunk_size:
cur_idx = self._reduce_chunk_size(start_idx, cur_idx, splits)
break
cur_total += num_cur_tokens
cur_idx += 1
return self._separator.join(splits[start_idx:cur_idx])
__all__ = ["TextSplitter", "TokenTextSplitter"]
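

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library): the
    # whitespace tokenizer below is a stand-in for the default
    # globals_helper.tokenizer, and the small chunk sizes are chosen purely
    # to make the overlap behavior visible.
    splitter = TokenTextSplitter(
        chunk_size=16,
        chunk_overlap=4,
        tokenizer=lambda text: text.split(),
    )
    sample_text = " ".join(f"word{i}" for i in range(50))
    for text_split in splitter.split_text_with_overlaps(sample_text):
        print(len(text_split.text_chunk), text_split.num_char_overlap)
    print(splitter.truncate_text(sample_text))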