"""

Minimal (byte-level) Byte Pair Encoding tokenizer.



Algorithmically follows along the GPT tokenizer:

https://github.com/openai/gpt-2/blob/master/src/encoder.py



But:

- Does not handle the regular expression splitting pattern.

- Does not handle any special tokens.

"""

from .base import Tokenizer, get_stats, merge


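# NOTE: the commented-out class below is the original single-batch implementation,
# kept here for comparison with the batch-aware version defined further down.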
# class BasicTokenizer(Tokenizer):
#
#     def __init__(self):
#         super().__init__()
#
#     def train(self, text, vocab_size, verbose=False):
#         assert vocab_size >= 256
#         num_merges = vocab_size - 256
#
#         # input text preprocessing
#         text_bytes = text.encode("utf-8")  # raw bytes
#         ids = list(text_bytes)  # list of integers in range 0..255
#
#         # iteratively merge the most common pairs to create new tokens
#         merges = {}  # (int, int) -> int
#         vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
#         for i in range(num_merges):
#             # count up the number of times every consecutive pair appears
#             stats = get_stats(ids)
#             # find the pair with the highest count
#             pair = max(stats, key=stats.get)
#             # mint a new token: assign it the next available id
#             idx = 256 + i
#             # replace all occurrences of pair in ids with idx
#             ids = merge(ids, pair, idx)
#             # save the merge
#             merges[pair] = idx
#             vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
#             # prints
#             if verbose:
#                 print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
#
#         # save class variables
#         self.merges = merges  # used in encode()
#         self.vocab = vocab  # used in decode()
#
#     def decode(self, ids):
#         # given ids (list of integers), return Python string
#         text_bytes = b"".join(self.vocab[idx] for idx in ids)
#         text = text_bytes.decode("utf-8", errors="replace")
#         return text
#
#     def encode(self, text):
#         # given a string text, return the token ids
#         text_bytes = text.encode("utf-8")  # raw bytes
#         ids = list(text_bytes)  # list of integers in range 0..255
#         while len(ids) >= 2:
#             # find the pair with the lowest merge index
#             stats = get_stats(ids)
#             pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
#             # subtle: if there are no more merges available, the key will
#             # result in an inf for every single pair, and the min will be
#             # just the first pair in the list, arbitrarily
#             # we can detect this terminating case by a membership check
#             if pair not in self.merges:
#                 break  # nothing else can be merged anymore
#             # otherwise let's merge the best pair (lowest merge index)
#             idx = self.merges[pair]
#             ids = merge(ids, pair, idx)
#         return ids


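# BasicTokenizer, modified so that train() can be called repeatedly on successive
# text batches: previously learned merges and vocab entries are reused, newly
# minted token ids continue from the current vocab size, and self.merge_counter
# accumulates how many tokens have been created across all batches.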
class BasicTokenizer(Tokenizer):

    def __init__(self):
        super().__init__()
        self.merge_counter = 0

    def train(self, text, vocab_size, verbose=False):
        # hard check: the target vocab size must at least cover the 256 raw byte tokens
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        current_batch_merge_counter = 0  # not all `num_merges` may happen for this batch

        # input text preprocessing
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255

        # reuse the merge dict from previous training batches, if any
        if self.merges is None:
            self.merges = {}  # (int, int) -> int

        # reuse the vocab of this Tokenizer object if it already exists
        if self.vocab is None:
            self.vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes

        # iteratively merge the most common pair in the text
        for i in range(num_merges):
            # count how many times every consecutive pair appears
            stats = get_stats(ids)

            # find the pair with the highest count
            pair = max(stats, key=stats.get)

            # if this pair was already merged in a previous batch, apply the
            # stored merge to `ids` first instead of minting a duplicate token
            while pair in self.merges:
                already_merged_idx = self.merges[pair]
                ids = merge(ids, pair, already_merged_idx)

                stats = get_stats(ids)
                if stats and len(ids) >= 2:
                    pair = max(stats, key=stats.get)
                else:
                    # nothing left to merge in this incoming data batch
                    pair = None
                    break

            if pair is None:
                print("stopping merges: no new byte pair found in the current batch")
                break

            # this pair has not been merged in any batch yet: mint a new token,
            # whose id is the next free slot after everything minted so far
            idx = len(self.vocab)

            # track how many new tokens this batch adds (added to self.merge_counter below)
            current_batch_merge_counter += 1

            # replace all occurrences of `pair` in `ids` with the new `idx` token,
            # and record the merge in both `merges` and `vocab`
            ids = merge(ids, pair, idx)
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({self.vocab[idx]}) had {stats[pair]} occurrences")

        self.merge_counter += current_batch_merge_counter

    def decode(self, ids):
        # given ids (list of integers), return Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        # given a string, return the encoded token ids
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if none of the remaining pairs is in the merges dict,
            # the key is `inf` for every pair and min() just returns the first
            # pair in the list arbitrarily; detect this terminating case with
            # a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
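

# A minimal usage sketch (not part of the original module). It assumes the module
# is run in a context where the relative import of `.base` resolves, e.g. via
# `python -m <package>.<module>`. The example strings and the vocab_size value
# are illustrative choices, not values taken from the original code.
if __name__ == "__main__":
    tokenizer = BasicTokenizer()

    # first batch: request up to 4 new merges (vocab_size 260 -> 260 - 256 = 4)
    tokenizer.train("aaabdaaabac aaabdaaabac", vocab_size=260, verbose=True)

    # second batch: merges learned above are reused on the new text, and any
    # genuinely new pairs get token ids continuing from the current vocab size
    tokenizer.train("aaabdaaabac with some new text", vocab_size=260, verbose=True)

    print("total merges learned across batches:", tokenizer.merge_counter)

    # encode/decode round trip is lossless
    sample = "aaabdaaabac"
    ids = tokenizer.encode(sample)
    print(sample, "->", ids, "->", tokenizer.decode(ids))
    assert tokenizer.decode(ids) == sample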