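"""Byte-level BPE tokenizer utilities for GPT-2 and CodeGen vocabularies.

Provides a pure-Python Encoder with encode()/decode(), plus loaders that
fetch or read the vocabulary and merge files.
"""
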
import json, os, re, unicodedata
from functools import lru_cache

import nltk   # sentence re-splitting in Encoder.decode; requires the 'punkt' data package
import regex  # third-party; supports the \p{L}/\p{N} classes used by the CodeGen pattern
import wget

from constants import (GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN,
                       ENCODER_URL, VOCAB_URL)  # URLs assumed defined in constants

@lru_cache()
def bytes_to_unicode():
    """Reversible map from byte values to printable unicode characters."""
    bs = (list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:  # shift unprintable bytes to code points above 255
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))
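
# Illustrative check: printable bytes map to themselves, unprintable ones shift
# above 255, e.g. bytes_to_unicode()[ord('a')] == 'a' and
# bytes_to_unicode()[ord(' ')] == 'Ġ' (chr(288)).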

def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word` (a tuple of symbols)."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
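
# Example: get_pairs(('h', 'e', 'l', 'l', 'o'))
#   -> {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}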

class Encoder:
    def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None,
                 byte_encoder=None, byte_decoder=None):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors
        self.byte_encoder = bytes_to_unicode() if byte_encoder is None else byte_encoder
        self.byte_decoder = ({v: k for k, v in self.byte_encoder.items()}
                             if byte_decoder is None else byte_decoder)
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        if tokenize is None:
            # Simplified GPT-2 pre-tokenizer: \w stands in for the \p{L}/\p{N} classes.
            self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
            self.tokenize = lambda text: self.pat.findall(text)
        else:
            self.tokenize = tokenize

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            return token
        while True:
            # Merge the adjacent pair with the lowest (best) rank first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word, i = [], 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
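
    # Illustrative example: with bpe_ranks == {('l', 'l'): 0} (a hypothetical
    # merge table), bpe('hello') returns 'h e ll o' -- only the ranked pair merges.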

    def encode(self, text):
        bpe_tokens = []
        normalized_text = unicodedata.normalize('NFKC', text)
        # Keep printable ASCII only: drop tabs, non-ASCII, and control characters.
        normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
        normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
        for token in self.tokenize(normalized_text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join(self.decoder[token] for token in tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        # Tidy spacing around punctuation and protect newlines through sent_tokenize.
        decoded_text = (text.replace(" .", ".").replace(" ,", ",").replace(" '", "'")
                            .replace(" ?", "?").replace(" !", "!").replace(" :", ":")
                            .replace('\n', '<br>'))
        sentences = nltk.sent_tokenize(decoded_text)
        return ' '.join(sentences).replace("<br>", "<br>\n")

def get_encoder_gpt2():
    encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
    vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
    os.makedirs(GPT2_FOLDER, exist_ok=True)
    if not os.path.exists(encoder_path):
        wget.download(ENCODER_URL, out=encoder_path)
    if not os.path.exists(vocab_path):
        wget.download(VOCAB_URL, out=vocab_path)
    with open(encoder_path, 'r') as f:
        encoder = json.load(f)
    with open(vocab_path, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # Skip the version header (first line) and the trailing empty line.
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
    # Register the end-of-text token as the final vocabulary entry.
    encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
    encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
    return encoder_obj

def get_codegen_tokenizer_pure(vocab_file, merges_file):
    with open(vocab_file) as f:
        vocab = json.load(f)
    with open(merges_file, encoding="utf-8") as f:
        bpe_merges = [tuple(m.split()) for m in f.read().split('\n')[1:-1]]
    byte_encoder = bytes_to_unicode()
    # \p{L}/\p{N} require the third-party `regex` module; stdlib `re` rejects them.
    pat = regex.compile(r"""<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
    return Encoder(encoder=vocab, bpe_merges=bpe_merges, byte_encoder=byte_encoder,
                   byte_decoder={v: k for k, v in byte_encoder.items()},
                   tokenize=lambda text: pat.findall(text))

def codegen_tokenize(text, tokenizer):
    return tokenizer.encode(text)

def codegen_decode(tokens, tokenizer):
    return tokenizer.decode(tokens)
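
# Minimal usage sketch (illustrative only, not part of the module API): assumes
# the constants module points at valid GPT-2 encoder/vocab URLs and that nltk's
# 'punkt' sentence model has been downloaded via nltk.download('punkt').
if __name__ == "__main__":
    enc = get_encoder_gpt2()
    ids = enc.encode("Hello world. This is a round-trip test!")
    print(ids)              # BPE token ids
    print(enc.decode(ids))  # text reconstructed from the ids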