from pathlib import Path from typing import Any, List, Optional, Union from transformers import AutoTokenizer, PreTrainedTokenizerBase class TokenizerBase(PreTrainedTokenizerBase): ''' This is a protocol for the tokenizer. Users can implement their own tokenizer by inheriting this class. ''' class TransformersTokenizer(TokenizerBase): ''' A wrapper for the Transformers' tokenizer. This is the default tokenizer for LLM. ''' def __init__(self, tokenizer): self.tokenizer = tokenizer def __call__(self, text: str, *args, **kwargs) -> Any: return self.tokenizer(text, *args, **kwargs) @property def eos_token_id(self) -> int: return self.tokenizer.eos_token_id @property def pad_token_id(self) -> int: return self.tokenizer.pad_token_id def encode(self, text: str, *args, **kwargs) -> List[int]: return self.tokenizer.encode(text, *args, **kwargs) def decode(self, token_ids: List[int], *args, **kwargs) -> str: return self.tokenizer.decode(token_ids, *args, **kwargs) def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict: return self.tokenizer.batch_encode_plus(texts, *args, **kwargs) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.tokenizer})" @classmethod def from_pretrained(cls, pretrained_model_dir: str, **kwargs): tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, **kwargs) return cls(tokenizer) def tokenizer_factory(obj: Optional[Union[str, Path, PreTrainedTokenizerBase, TokenizerBase]] = None, **kwargs) -> Optional[TokenizerBase]: if obj is None: return None elif isinstance(obj, (str, Path)): default_kwargs = { 'legacy': False, 'padding_side': 'left', 'truncation_side': 'left', 'trust_remote_code': True, 'use_fast': True, } default_kwargs.update(kwargs) return TransformersTokenizer.from_pretrained(obj, **default_kwargs) elif isinstance(obj, PreTrainedTokenizerBase): return TransformersTokenizer(obj) elif isinstance(obj, TokenizerBase): return obj else: raise TypeError(f"Unrecognized tokenizer {obj}")