import itertools
import os
import re
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizer


class DNAKmerTokenizer(PreTrainedTokenizer):
    """k-mer tokenizer over the DNA alphabet ATCG with a fixed set of special tokens."""

    def __init__(self, k, **kwargs):
        self.k = k
        # Control tokens; everything else in the vocabulary is a k-mer over ATCG.
        self.special_tokens = [
            "<oov>",
            "<s>",
            "</s>",
            "<pad>",
            "<mask>",
            "<bog>",
            "<eog>",
            "<bok>",
            "<eok>",
            "<+>",
            "<->",
            "<mam>",
            "<vrt>",
            "<inv>",
            "<pln>",
            "<fng>",
            "<prt>",
            "<cds>",
            "<pseudo>",
            "<tRNA>",
            "<rRNA>",
            "<ncRNA>",
            "<misc_RNA>",
            "<sp0>",
            "<sp1>",
            "<sp2>",
            "<sp3>",
            "<sp4>",
            "<sp5>",
            "<sp6>",
            "<sp7>",
            "<sp8>",
        ]
        # All 4**k k-mers, enumerated in a fixed order so ids are reproducible.
        self.kmers = [
            "".join(kmer) for kmer in itertools.product("ATCG", repeat=self.k)
        ]
        # Ids are assigned to special tokens first, then to the k-mers.
        self.vocab = {
            token: i for i, token in enumerate(self.special_tokens + self.kmers)
        }
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        # Special tokens are matched before any k-mer chunking.
        self.special_token_pattern = re.compile(
            "|".join(re.escape(token) for token in self.special_tokens)
        )
        # Consume uppercase runs k characters at a time; a shorter trailing
        # run is kept whole and later maps to <oov>.
        self.dna_pattern = re.compile(f"[A-Z]{{{self.k}}}|[A-Z]+")
        # Register bos/eos with the base class so that bos_token_id and
        # eos_token_id resolve through this vocabulary after initialization.
        kwargs.setdefault("bos_token", "<s>")
        kwargs.setdefault("eos_token", "</s>")
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)

    def _tokenize(self, text, **kwargs) -> List[str]:
        tokens = []
        pos = 0
        while pos < len(text):
            # Special tokens take precedence over k-mer chunking.
            special_match = self.special_token_pattern.match(text, pos)
            if special_match:
                tokens.append(special_match.group())
                pos = special_match.end()
            else:
                dna_match = self.dna_pattern.match(text, pos)
                if dna_match:
                    # Either a full k-mer or a shorter trailing run of letters.
                    dna_seq = dna_match.group()
                    tokens.append(dna_seq)
                    pos = dna_match.end()
                else:
                    # Any other character becomes a single-character token
                    # (and maps to <oov> during id conversion).
                    tokens.append(text[pos])
                    pos += 1
        return tokens
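
    # Rough sketch of the behaviour above, assuming k=3 (illustrative only):
    #   "<s>ATCGAT<mask>" -> ["<s>", "ATC", "GAT", "<mask>"]
    #   "ATCGA"           -> ["ATC", "GA"]   # "GA" is shorter than k and maps to <oov>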

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.vocab["<oov>"])

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, "<oov>")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return (
            [self.bos_token_id]
            + token_ids_0
            + [self.eos_token_id]
            + token_ids_1
            + [self.eos_token_id]
        )
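
    # Layouts produced above: a single sequence becomes <s> X </s>,
    # and a pair becomes <s> X </s> Y </s> (one shared separator, no second <s>).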

    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def prepare_for_model(self, *args, **kwargs):
        encoding = super().prepare_for_model(*args, **kwargs)
        # Drop token_type_ids; this tokenizer does not use segment ids.
        if "token_type_ids" in encoding:
            del encoding["token_type_ids"]
        return encoding

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
        )
        # One token per line, ordered by token id.
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)
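

# Minimal usage sketch, assuming a recent `transformers` release; k=6 and the
# input sequence are illustrative choices only.
if __name__ == "__main__":
    tokenizer = DNAKmerTokenizer(k=6, model_max_length=512)
    print(tokenizer.vocab_size)  # 32 special tokens + 4**6 k-mers = 4128
    encoding = tokenizer("ATCGATCGATCG")
    print(encoding["input_ids"])  # <s>, two 6-mer ids, </s>
    print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
    print(tokenizer.decode(encoding["input_ids"], skip_special_tokens=True))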