import json
import os
import re
import unicodedata
from functools import lru_cache

import nltk
import regex  # third-party "regex" package; needed for the \p{L}/\p{N} classes in the codegen pattern
import wget

# ENCODER_URL and VOCAB_URL are used by get_encoder_gpt2() below; they are assumed to be
# defined in constants.py alongside the other settings imported here.
from constants import (GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN,
                       ENCODER_URL, VOCAB_URL)

@lru_cache()
def bytes_to_unicode():
    """Map every byte value (0-255) to a printable unicode character, as in GPT-2's byte-level BPE."""
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

class Encoder:
    def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None,
                 byte_encoder=None, byte_decoder=None):
        # byte_encoder/byte_decoder are optional overrides so that
        # get_codegen_tokenizer_pure() below can pass them in explicitly.
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors
        self.byte_encoder = byte_encoder if byte_encoder is not None else bytes_to_unicode()
        self.byte_decoder = byte_decoder if byte_decoder is not None else {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        if tokenize is None:
            self.pat = re.compile(
                r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""",
                re.UNICODE,
            )
            self.tokenize = lambda text: re.findall(self.pat, text)
        else:
            self.tokenize = tokenize

    def bpe(self, token):
        """Apply the learned BPE merges to a single pre-token and return its space-joined subwords."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            return token
        while True:
            # Merge the lowest-ranked pair present in the word; stop when no known pair remains.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Normalize text, byte-encode each pre-token, and map its BPE pieces to token ids."""
        bpe_tokens = []
        normalized_text = unicodedata.normalize('NFKC', text)
        # Drop non-ASCII characters and tabs, then any remaining control characters.
        normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
        normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
        for token in self.tokenize(normalized_text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text, then lightly normalize spacing and line breaks."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
        decoded_text = (text.replace(" .", ".").replace(" ,", ",").replace(" '", "'")
                            .replace(" ?", "?").replace(" !", "!").replace(" :", ":")
                            .replace('\n', '<br>'))
        # nltk.sent_tokenize requires the "punkt" tokenizer data (nltk.download('punkt')).
        sentences = nltk.sent_tokenize(decoded_text)
        return ' '.join(sentences).replace("<br>", "<br>\n")

def get_encoder_gpt2():
    encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
    vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
    if not os.path.exists(GPT2_FOLDER):
        os.makedirs(GPT2_FOLDER)
    if not os.path.exists(encoder_path):
        wget.download(ENCODER_URL, out=encoder_path)
    if not os.path.exists(vocab_path):
        wget.download(VOCAB_URL, out=vocab_path)
    with open(encoder_path, 'r') as f:
        encoder = json.load(f)
    with open(vocab_path, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
    encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
    encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
    return encoder_obj

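# Example round trip (a sketch: downloads the GPT-2 encoder/vocab files configured in
# constants.py on first use, and decode() needs nltk's "punkt" data):
#   enc = get_encoder_gpt2()
#   ids = enc.encode("Hello, world!")
#   text = enc.decode(ids)
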
def get_codegen_tokenizer_pure(vocab_file, merges_file):
    """Build a CodeGen-style BPE tokenizer from local vocab/merges files."""
    with open(vocab_file, 'r', encoding="utf-8") as f:
        vocab = json.load(f)
    with open(merges_file, 'r', encoding="utf-8") as f:
        merges = f.read().split('\n')[1:-1]
    bpe_merges = [tuple(m.split()) for m in merges]
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    # The \p{L}/\p{N} character classes are not supported by the stdlib re module,
    # so this pattern is compiled with the third-party regex package.
    tokenizer_regex = regex.compile(
        r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+'''
    )
    tokenize = lambda text: regex.findall(tokenizer_regex, text)
    return Encoder(
        encoder=vocab,
        bpe_merges=bpe_merges,
        byte_encoder=byte_encoder,
        byte_decoder=byte_decoder,
        tokenize=tokenize,
    )

def codegen_tokenize(text, tokenizer):
    return tokenizer.encode(text)


def codegen_decode(tokens, tokenizer):
    return tokenizer.decode(tokens)
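

# Minimal self-check sketch. Assumptions: the GPT-2 encoder/vocab files referenced in
# constants.py can be downloaded, and nltk's "punkt" sentence tokenizer data is installed
# (nltk.download("punkt")). The codegen tokenizer is not exercised here because its
# vocab/merges file paths are project-specific.
if __name__ == "__main__":
    gpt2_encoder = get_encoder_gpt2()
    sample = "Byte pair encoding splits rare words into subword units."
    token_ids = gpt2_encoder.encode(sample)
    print("token ids:", token_ids)
    print("decoded  :", gpt2_encoder.decode(token_ids))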