"""Text splitter implementations."""
from dataclasses import dataclass
from typing import Callable, List, Optional

from langchain.text_splitter import TextSplitter

from gpt_index.utils import globals_helper


@dataclass
class TextSplit:
    """Text split with overlap.

    Attributes:
        text_chunk: The text string.
        num_char_overlap: The number of overlapping characters with the previous chunk.
    """

    text_chunk: str
    num_char_overlap: Optional[int] = None


class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at word tokens."""

    def __init__(
        self,
        separator: str = " ",
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        tokenizer: Optional[Callable] = None,
        backup_separators: Optional[List[str]] = ["\n"],
    ):
        """Initialize with parameters."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._separator = separator
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
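        # Fall back to the globally configured tokenizer when none is provided.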
        self.tokenizer = tokenizer or globals_helper.tokenizer
        self._backup_separators = backup_separators

    def _reduce_chunk_size(
        self, start_idx: int, cur_idx: int, splits: List[str]
    ) -> int:
        """Reduce the chunk size by reducing cur_idx.

        Return the new cur_idx.

        """
        current_doc_total = len(
            self.tokenizer(self._separator.join(splits[start_idx:cur_idx]))
        )
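        # While the joined chunk is still too large, drop a number of trailing
        # splits proportional to the token overage, then re-measure.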
        while current_doc_total > self._chunk_size:
            percent_to_reduce = (
                current_doc_total - self._chunk_size
            ) / current_doc_total
            num_to_reduce = int(percent_to_reduce * (cur_idx - start_idx)) + 1
            cur_idx -= num_to_reduce
            current_doc_total = len(
                self.tokenizer(self._separator.join(splits[start_idx:cur_idx]))
            )
        return cur_idx

    def _preprocess_splits(self, splits: List[str], chunk_size: int) -> List[str]:
        """Process splits.

        Specifically search for tokens that are too large for chunk size,
        and see if we can separate those tokens more
        (via backup separators if specified, or force chunking).

        """
        new_splits = []
        for split in splits:
            num_cur_tokens = len(self.tokenizer(split))
            if num_cur_tokens <= chunk_size:
                new_splits.append(split)
            else:
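                # The split itself exceeds the chunk size: try to break it up
                # further with any backup separators, falling back to hard chunking.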
                cur_splits = [split]
                if self._backup_separators:
                    for sep in self._backup_separators:
                        if sep in split:
                            cur_splits = split.split(sep)
                            break

                cur_splits2 = []
                for cur_split in cur_splits:
                    num_cur_tokens = len(self.tokenizer(cur_split))
                    if num_cur_tokens <= chunk_size:
                        cur_splits2.append(cur_split)
                    else:
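                        # No separator helped, so fall back to hard chunking by
                        # characters (an approximation, since chunk_size is
                        # measured in tokens rather than characters).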
                        cur_split_chunks = [
                            cur_split[i : i + chunk_size]
                            for i in range(0, len(cur_split), chunk_size)
                        ]
                        cur_splits2.extend(cur_split_chunks)

                new_splits.extend(cur_splits2)
        return new_splits

    def _postprocess_splits(self, docs: List[TextSplit]) -> List[TextSplit]:
        """Post-process splits."""
        # TODO: prune text splits, remove empty spaces
        new_docs = []
        for doc in docs:
            if doc.text_chunk.replace(" ", "") == "":
                continue
            new_docs.append(doc)
        return new_docs

    def split_text(self, text: str, extra_info_str: Optional[str] = None) -> List[str]:
        """Split incoming text and return chunks."""
        text_splits = self.split_text_with_overlaps(text, extra_info_str=extra_info_str)
        return [text_split.text_chunk for text_split in text_splits]

    def split_text_with_overlaps(
        self, text: str, extra_info_str: Optional[str] = None
    ) -> List[TextSplit]:
        """Split incoming text and return chunks with overlap size."""
        if text == "":
            return []

        # NOTE: Consider extra info str that will be added to the chunk at query time
        #       This reduces the effective chunk size that we can have
        if extra_info_str is not None:
            # NOTE: extra 2 newline chars for formatting when prepending in query
            num_extra_tokens = len(self.tokenizer(f"{extra_info_str}\n\n")) + 1
            effective_chunk_size = self._chunk_size - num_extra_tokens

            if effective_chunk_size <= 0:
                raise ValueError(
                    "Effective chunk size is non positive after considering extra_info"
                )
        else:
            effective_chunk_size = self._chunk_size

        # First we naively split the large input into a bunch of smaller ones.
        splits = text.split(self._separator)
        splits = self._preprocess_splits(splits, effective_chunk_size)
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        docs: List[TextSplit] = []

        start_idx = 0
        cur_idx = 0
        cur_total = 0
        prev_idx = 0  # store the previous end index
        while cur_idx < len(splits):
            cur_token = splits[cur_idx]
            num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)
            if num_cur_tokens > effective_chunk_size:
                raise ValueError(
                    "A single term is larger than the allowed chunk size.\n"
                    f"Term size: {num_cur_tokens}\n"
                    f"Chunk size: {self._chunk_size}"
                    f"Effective chunk size: {effective_chunk_size}"
                )
            # If adding this token to the current doc would exceed the chunk size:
            # 1. Verify the size of the current doc with the tokenizer (reducing
            #    it if needed) and append it to the docs list.
            if cur_total + num_cur_tokens > effective_chunk_size:
                # NOTE: since we use a proxy for counting tokens, we want to
                # run tokenizer across all of current_doc first. If
                # the chunk is too big, then we will reduce text in pieces
                cur_idx = self._reduce_chunk_size(start_idx, cur_idx, splits)
                overlap = 0
                # after first round, check if last chunk ended after this chunk begins
                if prev_idx > 0 and prev_idx > start_idx:
                    # count one separator character per overlapping split, matching
                    # the overlap computation for the final chunk below
                    overlap = sum(
                        len(splits[i]) for i in range(start_idx, prev_idx)
                    ) + len(range(start_idx, prev_idx))

                docs.append(
                    TextSplit(self._separator.join(splits[start_idx:cur_idx]), overlap)
                )
                prev_idx = cur_idx
                # 2. Shrink the current doc (from the front) until it is smaller
                #    than the overlap size.
                # NOTE: because counting tokens individually is an imperfect
                # proxy (but a much faster one) for the total number of tokens consumed,
                # we need to enforce that start_idx <= cur_idx, otherwise
                # start_idx has a chance of going out of bounds.
                while cur_total > self._chunk_overlap and start_idx < cur_idx:
                    cur_num_tokens = max(len(self.tokenizer(splits[start_idx])), 1)
                    cur_total -= cur_num_tokens
                    start_idx += 1
                # NOTE: This is a hack, make more general
                if start_idx == cur_idx:
                    cur_total = 0
            # Build up the current doc with this term, and update the running total
            # with its number of tokens (w.r.t. self.tokenizer).

            # we reassign cur_token and num_cur_tokens, because cur_idx
            # may have changed
            cur_token = splits[cur_idx]
            num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)

            cur_total += num_cur_tokens
            cur_idx += 1
        overlap = 0
        # after first round, check if last chunk ended after this chunk begins
        if prev_idx > start_idx:
            overlap = sum([len(splits[i]) for i in range(start_idx, prev_idx)]) + len(
                range(start_idx, prev_idx)
            )
        docs.append(TextSplit(self._separator.join(splits[start_idx:cur_idx]), overlap))

        # run postprocessing to remove blank spaces
        docs = self._postprocess_splits(docs)
        return docs

    def truncate_text(self, text: str) -> str:
        """Truncate text in order to fit the underlying chunk size."""
        if text == "":
            return ""
        # First we naively split the large input into a bunch of smaller ones.
        splits = text.split(self._separator)
        splits = self._preprocess_splits(splits, self._chunk_size)

        start_idx = 0
        cur_idx = 0
        cur_total = 0
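        # Greedily accumulate whole splits until the approximate token budget
        # is exhausted, then verify and trim the result with the tokenizer.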
        while cur_idx < len(splits):
            cur_token = splits[cur_idx]
            num_cur_tokens = max(len(self.tokenizer(cur_token)), 1)
            if cur_total + num_cur_tokens > self._chunk_size:
                cur_idx = self._reduce_chunk_size(start_idx, cur_idx, splits)
                break
            cur_total += num_cur_tokens
            cur_idx += 1
        return self._separator.join(splits[start_idx:cur_idx])


__all__ = ["TextSplitter", "TokenTextSplitter"]
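

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API);
    # it relies on the default tokenizer from gpt_index.utils.globals_helper
    # when no tokenizer argument is supplied.
    sample_text = "The quick brown fox jumps over the lazy dog. " * 200
    splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=16)
    for text_split in splitter.split_text_with_overlaps(sample_text):
        print(len(text_split.text_chunk), text_split.num_char_overlap)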