import itertools
import os
import re
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizer


class DNAKmerTokenizer(PreTrainedTokenizer):
    """Tokenizer that splits DNA sequences into fixed-length k-mers over the ATCG alphabet."""

    def __init__(self, k, **kwargs):
        self.k = k
        # NOTE: the original source declared 32 special tokens, but most of the
        # angle-bracketed names were lost in extraction; only "<+>" and "<->"
        # (strand markers) survived. The names below are assumed stand-ins that
        # preserve the original count of 32 and the positions (9 and 10) of the
        # two surviving tokens.
        self.special_tokens = (
            ["<pad>", "<unk>", "<bos>", "<eos>", "<mask>"]   # assumed standard tokens
            + [f"<special_{i}>" for i in range(5, 9)]        # placeholders for lost tokens
            + ["<+>", "<->"]                                 # strand tokens (preserved)
            + [f"<special_{i}>" for i in range(11, 32)]      # placeholders for lost tokens
        )
        # All 4^k possible k-mers over the DNA alphabet.
        self.kmers = [
            "".join(kmer) for kmer in itertools.product("ATCG", repeat=self.k)
        ]
        # Special tokens occupy the first ids, followed by the k-mers.
        self.vocab = {
            token: i for i, token in enumerate(self.special_tokens + self.kmers)
        }
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        # Matches any special token literally.
        self.special_token_pattern = re.compile(
            "|".join(re.escape(token) for token in self.special_tokens)
        )
        # Matches a full k-mer, or a shorter trailing run of uppercase letters
        # (which later maps to the unknown token).
        self.dna_pattern = re.compile(f"[A-Z]{{{self.k}}}|[A-Z]+")
        # Special-token names are assumptions (see NOTE above); callers can
        # override them through kwargs.
        super().__init__(
            bos_token=kwargs.pop("bos_token", "<bos>"),
            eos_token=kwargs.pop("eos_token", "<eos>"),
            unk_token=kwargs.pop("unk_token", "<unk>"),
            pad_token=kwargs.pop("pad_token", "<pad>"),
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)

    def _tokenize(self, text, **kwargs) -> List[str]:
        # Scan left to right, preferring special tokens, then k-mers, then
        # single characters (which map to the unknown token later).
        tokens = []
        pos = 0
        while pos < len(text):
            special_match = self.special_token_pattern.match(text, pos)
            if special_match:
                tokens.append(special_match.group())
                pos = special_match.end()
            else:
                dna_match = self.dna_pattern.match(text, pos)
                if dna_match:
                    tokens.append(dna_match.group())
                    pos = dna_match.end()
                else:
                    tokens.append(text[pos])
                    pos += 1
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index: int) -> str:
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # Single sequence: <bos> A <eos>; pair: <bos> A <eos> B <eos>.
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return (
            [self.bos_token_id]
            + token_ids_0
            + [self.eos_token_id]
            + token_ids_1
            + [self.eos_token_id]
        )

    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def prepare_for_model(self, *args, **kwargs):
        # This tokenizer does not use token type ids, so drop them from the encoding.
        encoding = super().prepare_for_model(*args, **kwargs)
        if "token_type_ids" in encoding:
            del encoding["token_type_ids"]
        return encoding

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
        )
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # One token per line, ordered by token id.
            for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)
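

# --- Usage sketch (not part of the original file) ---
# A minimal example of how the tokenizer above might be used, assuming the
# reconstructed special-token names ("<bos>", "<eos>", "<unk>", ...) and k=3.
if __name__ == "__main__":
    tokenizer = DNAKmerTokenizer(k=3)

    # 32 special tokens plus 4^3 = 64 k-mers.
    print(tokenizer.vocab_size)  # 96

    # "ATGCGT" splits into the 3-mers "ATG" and "CGT"; encode() then wraps the
    # ids with <bos> ... <eos> via build_inputs_with_special_tokens().
    print(tokenizer.tokenize("ATGCGT"))          # ['ATG', 'CGT']
    ids = tokenizer.encode("ATGCGT")
    print(tokenizer.convert_ids_to_tokens(ids))  # ['<bos>', 'ATG', 'CGT', '<eos>']

    # A trailing run shorter than k (here "GT") becomes a single token that
    # falls back to the unknown-token id.
    print(tokenizer.tokenize("ATGGT"))           # ['ATG', 'GT']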