"""
Minimal (byte-level) Byte Pair Encoding tokenizer.
Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
But:
- Does not handle the regular expression splitting pattern.
- Does not handle any special tokens.
"""
from .base import Tokenizer, get_stats, merge


class BasicTokenizer(Tokenizer):

    def __init__(self):
        super().__init__()
        # total number of merges this tokenizer has minted across all train() calls
        self.merge_counter = 0

    def train(self, text, vocab_size, verbose=False):
        # hard check: the vocabulary must at least cover the 256 raw byte values
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        current_batch_merge_counter = 0  # merges actually minted; a batch may yield fewer than num_merges
        # input text preprocessing
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        # reuse the merge dict from earlier train() calls if it exists
        self.merges = {} if self.merges is None else self.merges  # (int, int) -> int
        # likewise reuse this tokenizer's vocab if it exists: int -> bytes
        self.vocab = {idx: bytes([idx]) for idx in range(256)} if self.vocab is None else self.vocab
        # iteratively merge the most common pair in the text
        for i in range(num_merges):
            if len(ids) < 2:
                break  # nothing left to merge in this batch
            # count how many times every consecutive pair appears
            stats = get_stats(ids)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # if the winning pair was already merged in an earlier batch, apply the
            # existing merge to ids first (no new token is needed), re-count, and
            # repeat until we find a pair that has not been merged yet
            while pair in self.merges:
                already_merged_idx = self.merges[pair]
                ids = merge(ids, pair, already_merged_idx)
                stats = get_stats(ids)
                if stats:
                    pair = max(stats, key=stats.get)
                else:
                    pair = None  # fewer than two ids left, nothing more to merge
                    break
            if pair is None:
                # no new byte pair can be minted from this batch: stop early
                print("stopping merges: no new byte pair found in the current batch")
                break
            # this pair has not been merged in any batch yet: mint a new token whose
            # id accounts for every token this tokenizer has created so far
            idx = len(self.vocab)  # 256 base bytes + all previously minted tokens
            current_batch_merge_counter += 1
            # replace all occurrences of `pair` in `ids` with the new `idx` token,
            # and record it in both merges and vocab
            ids = merge(ids, pair, idx)
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({self.vocab[idx]}) had {stats[pair]} occurrences")
        self.merge_counter += current_batch_merge_counter

    def decode(self, ids):
        # given ids (list of integers), return Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        # given a string, return the list of token ids
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if none of the pairs is in the merges dict, the key returns
            # inf for every pair and min() just picks the first pair arbitrarily;
            # detect this terminating case with a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise, merge the best pair (the one with the lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
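

# A minimal usage sketch (an addition, not part of the original module). It assumes
# the sibling `base` module provides Tokenizer, get_stats, and merge as in
# karpathy/minbpe, and that this file is run as part of its package (e.g. via
# `python -m <package>.basic`; the relative import above prevents running it
# directly as a script). The two train() calls illustrate the incremental behavior:
# the second batch reuses merges learned from the first and only mints new tokens.
if __name__ == "__main__":
    tokenizer = BasicTokenizer()

    # batch 1: vocab_size 266 asks for up to 10 merges beyond the 256 base bytes
    tokenizer.train("hello hello hello world world world", vocab_size=266, verbose=True)
    print("total merges after batch 1:", tokenizer.merge_counter)

    # batch 2: grow the same tokenizer by up to 10 more merges; pairs such as
    # ('h', 'e') already merged in batch 1 are reused, not minted again
    tokenizer.train("hello there, worldly readers", vocab_size=266, verbose=True)
    print("total merges after batch 2:", tokenizer.merge_counter)

    # the encode/decode round trip is lossless
    ids = tokenizer.encode("hello world")
    print("encoded:", ids)
    assert tokenizer.decode(ids) == "hello world"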