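''' Tokenizer abstractions: TokenizerBase (the interface), TransformersTokenizer
(a wrapper around Hugging Face Transformers tokenizers), and tokenizer_factory
(normalizes paths and tokenizer objects into TokenizerBase instances). '''
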
from pathlib import Path
from typing import Any, List, Optional, Union

from transformers import AutoTokenizer, PreTrainedTokenizerBase


class TokenizerBase(PreTrainedTokenizerBase):
    ''' The tokenizer interface. Users can plug in their own tokenizer by
    subclassing this class and implementing the methods their workflow
    needs (e.g. encode and decode). '''
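

# Illustrative only, not part of the module API: a toy TokenizerBase subclass
# showing the pattern the docstring above describes. It maps each UTF-8 byte
# to one token id; a real implementation would wrap an actual tokenizer
# backend instead.
class _ExampleByteTokenizer(TokenizerBase):
    ''' Toy tokenizer: one token per UTF-8 byte (illustration only). '''

    def encode(self, text: str, *args, **kwargs) -> List[int]:
        return list(text.encode('utf-8'))

    def decode(self, token_ids: List[int], *args, **kwargs) -> str:
        return bytes(token_ids).decode('utf-8', errors='replace')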


class TransformersTokenizer(TokenizerBase):
    ''' A thin wrapper around a Hugging Face Transformers tokenizer.
    This is the default tokenizer used by LLM. '''

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def __call__(self, text: str, *args, **kwargs) -> Any:
        return self.tokenizer(text, *args, **kwargs)

    @property
    def eos_token_id(self) -> Optional[int]:
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self) -> Optional[int]:
        # May be None: many base checkpoints define no padding token.
        return self.tokenizer.pad_token_id

    def encode(self, text: str, *args, **kwargs) -> List[int]:
        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids: List[int], *args, **kwargs) -> str:
        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
        return self.tokenizer.batch_encode_plus(texts, *args, **kwargs)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tokenizer})"

    @classmethod
    def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
                                                  **kwargs)
        return cls(tokenizer)


def tokenizer_factory(obj: Optional[Union[str, Path, PreTrainedTokenizerBase,
                                          TokenizerBase]] = None,
                      **kwargs) -> Optional[TokenizerBase]:
    ''' Normalize the supported tokenizer inputs into a TokenizerBase.

    Accepts a checkpoint directory or model id (str or Path), an
    already-constructed Transformers tokenizer, or a TokenizerBase instance;
    returns None when obj is None. '''
    if obj is None:
        return None
    elif isinstance(obj, (str, Path)):
        # Defaults suited to left-padded, left-truncated decoding; each can
        # be overridden through **kwargs.
        default_kwargs = {
            'legacy': False,
            'padding_side': 'left',
            'truncation_side': 'left',
            'trust_remote_code': True,
            'use_fast': True,
        }
        default_kwargs.update(kwargs)
        return TransformersTokenizer.from_pretrained(obj, **default_kwargs)
    elif isinstance(obj, TokenizerBase):
        # Checked before PreTrainedTokenizerBase, which TokenizerBase
        # subclasses; otherwise this branch would be unreachable and existing
        # TokenizerBase instances would get wrapped a second time.
        return obj
    elif isinstance(obj, PreTrainedTokenizerBase):
        return TransformersTokenizer(obj)
    else:
        raise TypeError(
            f"Unrecognized tokenizer {obj!r} of type {type(obj).__name__}")
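

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming access to a tokenizer checkpoint; "gpt2"
# below is only a placeholder model id, and any local checkpoint directory
# works the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # A string (or Path) is resolved through AutoTokenizer and wrapped in
    # TransformersTokenizer.
    tok = tokenizer_factory("gpt2")
    ids = tok.encode("Hello, world!")
    print(ids)
    print(tok.decode(ids))

    # Passing an existing TokenizerBase back through the factory is a no-op.
    assert tokenizer_factory(tok) is tok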