|
import itertools |
|
import re |
|
from typing import List, Optional, Tuple |
|
|
|
from transformers import PreTrainedTokenizer |
|
|
|
|
|
class DNAKmerTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer that splits DNA text into non-overlapping k-mers.

    The vocabulary is the fixed list of special tokens below (ids
    ``0 .. len(special_tokens) - 1``) followed by every length-``k`` string
    over ``"ATCG"`` in ``itertools.product`` order (``4 ** k`` k-mers).

    ``_tokenize`` scans the text left to right: a special token is emitted
    whole; a run of uppercase letters ``A``-``Z`` is cut into consecutive
    ``k``-letter chunks, with any shorter remainder at the end of the run kept
    as a single token; every other character becomes a one-character token.
    Tokens missing from the vocabulary (short remainders, chunks containing
    letters other than A/T/C/G, lowercase letters, whitespace, ...) map to
    ``<oov>``.

    Args:
        k: k-mer length; must be at least 1.
        **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. The special
            token arguments ``bos_token``, ``eos_token``, ``unk_token``,
            ``pad_token`` and ``mask_token`` default to ``"<s>"``, ``"</s>"``,
            ``"<oov>"``, ``"<pad>"`` and ``"<mask>"`` respectively.

    Raises:
        ValueError: If ``k < 1``.
    """

    def __init__(self, k: int, **kwargs):
        # With k == 0 the DNA pattern can match the empty string, so
        # _tokenize would never advance; reject it up front.
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        self.k = k

        self.special_tokens = [
            "<oov>",
            "<s>",
            "</s>",
            "<pad>",
            "<mask>",
            "<bog>",
            "<eog>",
            "<bok>",
            "<eok>",
            "<+>",
            "<->",
            "<mam>",
            "<vrt>",
            "<inv>",
            "<pln>",
            "<fng>",
            "<prt>",
            "<cds>",
            "<pseudo>",
            "<tRNA>",
            "<rRNA>",
            "<ncRNA>",
            "<misc_RNA>",
            "<sp0>",
            "<sp1>",
            "<sp2>",
            "<sp3>",
            "<sp4>",
            "<sp5>",
            "<sp6>",
            "<sp7>",
            "<sp8>",
        ]
        self.kmers = [
            "".join(kmer) for kmer in itertools.product("ATCG", repeat=self.k)
        ]

        # Build the vocabulary before the base constructor runs: the base
        # class may look the special tokens up (get_vocab / token->id) while
        # registering them.
        self.vocab = {
            token: idx for idx, token in enumerate(self.special_tokens + self.kmers)
        }
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}

        # Longest alternatives first so a special token can never be cut
        # short by another special token that happens to be its prefix.
        self.special_token_pattern = re.compile(
            "|".join(
                re.escape(token)
                for token in sorted(self.special_tokens, key=len, reverse=True)
            )
        )
        # Exactly k uppercase letters, or else the shorter run that remains.
        self.dna_pattern = re.compile(f"[A-Z]{{{self.k}}}|[A-Z]+")

        # Register special tokens through the base constructor rather than
        # assigning self.bos_token / self.bos_token_id beforehand: the base
        # constructor initialises its own special-token state and can discard
        # or bypass such early assignments, leaving special_tokens_map,
        # all_special_ids, padding and skip_special_tokens decoding unaware
        # of them. setdefault keeps values supplied by the caller or a saved
        # tokenizer config.
        kwargs.setdefault("bos_token", "<s>")
        kwargs.setdefault("eos_token", "</s>")
        kwargs.setdefault("unk_token", "<oov>")
        kwargs.setdefault("pad_token", "<pad>")
        kwargs.setdefault("mask_token", "<mask>")

        # Pass k through so the base class records it among its init kwargs
        # (presumably persisted by save_pretrained so that from_pretrained
        # can rebuild the tokenizer with the same k).
        super().__init__(k=k, **kwargs)

    @property
    def vocab_size(self):
        """Size of the base vocabulary (special tokens + k-mers), excluding added tokens."""
        return len(self.vocab)

    def get_vocab(self):
        """Return a copy of the token -> id mapping, including tokens added later.

        Tokens registered through ``add_tokens`` live in
        ``added_tokens_encoder``; merging them keeps ``get_vocab()[t]``
        consistent with ``convert_tokens_to_ids(t)``.
        """
        vocab = dict(self.vocab)
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs) -> List[str]:
        """Split ``text`` into special tokens, k-mer chunks and single characters.

        At each position a special token is tried first, then the DNA pattern
        (``k`` uppercase letters, or the shorter uppercase run that remains);
        if neither matches, the single character at that position is emitted.
        Neither pattern can match the empty string (``k >= 1``), so every
        step advances.
        """
        tokens = []
        pos = 0
        end = len(text)
        while pos < end:
            match = self.special_token_pattern.match(
                text, pos
            ) or self.dna_pattern.match(text, pos)
            if match:
                tokens.append(match.group())
                pos = match.end()
            else:
                tokens.append(text[pos])
                pos += 1
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the id of ``<oov>``."""
        return self.vocab.get(token, self.vocab["<oov>"])

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to ``"<oov>"``."""
        return self.ids_to_tokens.get(index, "<oov>")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate tokens with no separator."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Add BOS/EOS around one sequence, or join a pair with EOS separators.

        Single sequence: ``<s> A </s>``. Pair: ``<s> A </s> B </s>``.
        """
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return (
            [self.bos_token_id]
            + token_ids_0
            + [self.eos_token_id]
            + token_ids_1
            + [self.eos_token_id]
        )

    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        """Return a 0/1 mask marking special-token positions (1 = special).

        When ``already_has_special_tokens`` is False the mask mirrors the
        layout produced by ``build_inputs_with_special_tokens``; otherwise
        the base-class implementation inspects the ids themselves.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def prepare_for_model(self, *args, **kwargs):
        """Run the base ``prepare_for_model`` and drop ``token_type_ids`` if present."""
        encoding = super().prepare_for_model(*args, **kwargs)
        if "token_type_ids" in encoding:
            del encoding["token_type_ids"]
        return encoding

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Write the base vocabulary to ``[<prefix>-]vocab.txt`` in ``save_directory``.

        One token per line in id order, so (0-based) line ``i`` holds the
        token whose id is ``i``. Tokens added via ``add_tokens`` are not
        written here.

        Returns:
            A one-element tuple containing the path of the written file.
        """
        import os

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.txt")
        with open(vocab_file, "w", encoding="utf-8") as writer:
            writer.writelines(
                f"{token}\n" for token in sorted(self.vocab, key=self.vocab.__getitem__)
            )
        return (vocab_file,)
|
|