# 360Zhinao-7B-Chat-4K -- tokenization_zhinao.py
# (Hugging Face page header removed; original upload by zhaicunqi, commit fd64099.)
import os
import torch
import base64
import tiktoken
from typing import Collection, Optional, Dict, List, Set, Tuple, Union
from transformers import PreTrainedTokenizer
from transformers.utils import PaddingStrategy
from transformers.tokenization_utils import PreTrainedTokenizer
# GPT-style pre-tokenization pattern passed to tiktoken.Encoding below:
# English contractions, letter runs, single digits, punctuation clusters,
# and newline/whitespace handling.  Uses \p{L}/\p{N} classes, which require
# tiktoken's `regex`-based engine (not the stdlib `re` module).
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
class SPTokenizer:
    """tiktoken-backed BPE tokenizer used by ZhinaoTokenizer.

    Builds a ``tiktoken.Encoding`` from a base64-encoded rank file, appends
    the Zhinao special tokens right after the BPE ranks, and losslessly
    round-trips runs of 2/3/4/8 ASCII spaces through ``[spaceN]``
    placeholder tokens.
    """

    def __init__(self, model_path):
        # Path to the base64-encoded tiktoken rank file (e.g. vocab/360.tiktoken).
        self.vocab_file = model_path

        self.pad_token = '<pad>'
        self.unk_token = '<unk>'
        self.mask_token = '<mask>'
        self.eod_token = '<eod>'
        self.eop_token = '<eop>'
        self.im_start_token = '<|im_start|>'
        self.im_end_token = '<|im_end|>'

        ## special_tokens -- order matters: ids are assigned by enumeration
        ## starting immediately after the last BPE rank (see bulid_tokenizer).
        self.SPECIAL_TOKENS = (
            self.pad_token,
            self.unk_token,
            self.mask_token,
            self.eod_token,
            self.eop_token,
            '[space2]', '[space3]', '[space4]', '[space8]',
            self.im_start_token, self.im_end_token
        )
        self.bulid_tokenizer()
        self.out = self.output_core_token()

        # Placeholder token -> the literal run of spaces it stands for.
        self.token2strs = {
            "[space2]": "  ",
            "[space3]": "   ",
            "[space4]": "    ",
            "[space8]": "        ",
        }
        self.str2tokens = {v: k for k, v in self.token2strs.items()}
        # Longest-first so 8 spaces match [space8] before [space4]/[space2].
        self.sorted_strs = sorted(self.str2tokens.keys(), key=len, reverse=True)

        ## skip_special_tokens: ids dropped by ZhinaoTokenizer._decode when
        ## skip_special_tokens=True (note: [spaceN] tokens are deliberately kept).
        self.decode_skip_special_tokens = [
            self.pad_token,
            self.unk_token,
            self.mask_token,
            self.eod_token,
            self.eop_token,
            self.im_start_token,
            self.im_end_token]
        self.decode_skip_special_tokens_ids = [self.convert_token_to_id(token) for token in self.decode_skip_special_tokens]

    def _load_tiktoken_bpe(self, tiktoken_bpe_file: str):
        """Load a tiktoken rank file: one ``<base64-token> <rank>`` pair per line."""
        with open(tiktoken_bpe_file, "rb") as f:
            contents = f.read()
        return {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }

    def bulid_tokenizer(self):
        """Build the underlying ``tiktoken.Encoding`` and the id<->token tables.

        Misspelled name kept for backward compatibility; prefer the
        ``build_tokenizer`` alias defined below.
        """
        mergeable_ranks = self._load_tiktoken_bpe(self.vocab_file)
        # Special-token ids continue directly after the BPE ranks.
        special_tokens = {
            token: index
            for index, token in enumerate(
                self.SPECIAL_TOKENS, start=len(mergeable_ranks)
            )
        }
        encode = tiktoken.Encoding(
            "zhinao",
            pat_str=PAT_STR,
            mergeable_ranks=mergeable_ranks,
            special_tokens=special_tokens
        )
        decoder = {v: k for k, v in mergeable_ranks.items()}
        decoder.update({v: k for k, v in special_tokens.items()})
        decoder_token2id = {v: k for k, v in decoder.items()}

        self.tokenizer = encode
        self.decoder = decoder                      # id -> bytes/str token
        self.decoder_token2id = decoder_token2id    # bytes/str token -> id
        self.num_tokens = len(mergeable_ranks) + len(self.SPECIAL_TOKENS)

    # Correctly-spelled alias for the method above (same bound function).
    build_tokenizer = bulid_tokenizer

    def output_core_token(self):
        """Return a mapping of each special token to its id."""
        return {t: self.convert_token_to_id(t) for t in self.SPECIAL_TOKENS}

    def tokenize(
        self,
        text,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = ()):
        """Return the list of tokens (bytes for BPE pieces, str for specials)."""
        text = self.convert(text)
        return [
            self.decoder[idx]
            for idx in self.tokenizer.encode(
                text, allowed_special=allowed_special, disallowed_special=disallowed_special)
        ]

    def encode(self, text, allowed_special="all", disallowed_special=()):
        """Text to list of token ids (after [spaceN]/[br] normalization)."""
        text = self.convert(text)
        return self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)

    def decode(self, ids, errors="replace"):
        """Token ids to text, restoring literal spaces from [spaceN] tokens."""
        text = self.tokenizer.decode(ids, errors=errors)
        return self.deconvert(text)

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join a sequence of tokens (bytes and/or str) into a single string.

        Consecutive ``bytes`` pieces are accumulated and UTF-8 decoded
        together, since a multi-byte character may span several BPE tokens.
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors="replace")
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if temp:
            text += temp.decode("utf-8", errors="replace")
        return self.deconvert(text)

    def convert_id_to_token(self, idx):
        """Id to token; raises KeyError for out-of-vocabulary ids."""
        return self.decoder[idx]

    def convert_token_to_id(self, token):
        """Token to id; raises KeyError for unknown tokens."""
        return self.decoder_token2id[token]

    def convert(self, text):
        """Replace special character sequences in *text* with special tokens.

        ``[br]``/``<br>`` become newlines; runs of spaces become [spaceN]
        tokens (longest run first).
        """
        for k in ["[br]", "<br>"]:
            text = text.replace(k, "\n")
        for k in self.sorted_strs:
            if k in text:
                text = text.replace(k, self.str2tokens[k])
        return text

    def deconvert(self, text):
        """Restore the original characters in decoded text ([spaceN] -> spaces)."""
        for t in self.token2strs:
            if t in text:
                text = text.replace(t, self.token2strs[t])
        return text
class ZhinaoTokenizer(PreTrainedTokenizer):
    """HuggingFace ``PreTrainedTokenizer`` wrapper around :class:`SPTokenizer`."""

    vocab_files_names = {"vocab_file": "vocab/360.tiktoken"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
        self.name = "ZhinaoTokenizer"
        self.errors = "replace"  # default `errors` for bytes decoding in _decode
        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(model_path=vocab_file)
        # These tokens are exposed via read-only properties below; drop any
        # conflicting values a caller (or a saved config) passed in.  The old
        # bare try/except stopped at the first missing key, so e.g. a missing
        # 'eos_token' left 'pad_token'/'unk_token' unpopped -- pop each
        # independently instead.
        for key in ('eos_token', 'pad_token', 'unk_token'):
            kwargs.pop(key, None)
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
        self.pad_token_id = self.tokenizer.convert_token_to_id(self.tokenizer.pad_token)
        self.eod_id = self.tokenizer.convert_token_to_id(self.tokenizer.eod_token)
        self.im_start_id = self.tokenizer.convert_token_to_id(self.tokenizer.im_start_token)
        self.im_end_id = self.tokenizer.convert_token_to_id(self.tokenizer.im_end_token)
        # NOTE(review): removed leftover `icecream` debug import and ic(...)
        # dump of eos/pad/im ids -- it added a third-party dependency and
        # printed on every tokenizer load.

    @property
    def unk_token(self) -> str:
        return self.tokenizer.unk_token

    @property
    def pad_token(self) -> str:
        return self.tokenizer.pad_token

    @property
    def eos_token(self) -> str:
        # The <eod> token doubles as end-of-sequence.
        return self.tokenizer.eod_token

    @property
    def eos_token_id(self):
        return self.tokenizer.convert_token_to_id(self.tokenizer.eod_token)

    @property
    def eop_token(self) -> str:
        return self.tokenizer.eop_token

    @property
    def eop_token_id(self):
        return self.tokenizer.convert_token_to_id(self.tokenizer.eop_token)

    @property
    def vocab_size(self):
        return self.tokenizer.num_tokens

    def get_vocab(self):
        """Return the full vocab (token -> id), including added tokens."""
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
    ) -> List[Union[bytes, str]]:
        """Split *text* into tokens (bytes for BPE pieces, str for specials).

        NOTE(review): this bypasses SPTokenizer.convert(), so [spaceN]/[br]
        normalization is NOT applied here, unlike SPTokenizer.tokenize --
        presumably intentional; confirm before unifying.
        """
        return [
            self.tokenizer.decoder[t]
            for t in self.tokenizer.encode(
                text, allowed_special=allowed_special, disallowed_special=disallowed_special
            )
        ]

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Decode ids to text, optionally dropping the skip-list special ids."""
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            token_ids = [i for i in token_ids if i not in self.tokenizer.decode_skip_special_tokens_ids]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)

    def _tokenize(self, text, **kwargs):
        # Tokenization goes through tokenize()/SPTokenizer directly.
        raise NotImplementedError

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens into a single string."""
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Copy the tiktoken vocab file into *save_directory*.

        Returns a 1-tuple with the path of the written file.  If
        *save_directory* is not a directory it is treated as the target
        file path itself.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
            # vocab_files_names points into a "vocab/" subfolder; create it
            # here (the old code did `save_directory + "/vocab"` even when
            # save_directory was a plain file path).
            os.makedirs(os.path.dirname(vocab_file), exist_ok=True)
        else:
            vocab_file = save_directory
        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()
        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)
        return (vocab_file,)