import os

# silence TensorFlow's C++ logging (must be set before transformers potentially imports TF)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from transformers import (
    LlamaConfig, LlamaForSequenceClassification, LlamaForCausalLM,
    GPT2Config, GPT2ForSequenceClassification, GPT2LMHeadModel,
    PreTrainedTokenizerFast
)
from tokenizers import Tokenizer
from tokenizers.models import BPE

from src.const import ACTION_SPACE, VOCAB

class RookTokenizer(PreTrainedTokenizerFast):
    # TODO: make it easier to use checkpoints from the hub
    # https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
    def __call__(self, *args, **kwargs):
        # the Llama/GPT-2 models used here take no token_type_ids, so never return them
        kwargs["return_token_type_ids"] = False
        return super().__call__(*args, **kwargs)

def make_model(config_dict, arch="llama"):
    if config_dict["finetuning_task"] == "text-classification":
        return make_model_clf(config_dict, arch=arch)
    elif config_dict["finetuning_task"] == "text-generation":
        return make_model_lm(config_dict, arch=arch)
    else:
        raise ValueError(f"Unknown config finetuning_task: {config_dict['finetuning_task']}")

def make_model_clf(config_dict, arch):
    if arch == "llama":
        Config = LlamaConfig
        Model = LlamaForSequenceClassification
    elif arch == "gpt2":
        Config = GPT2Config
        Model = GPT2ForSequenceClassification
    else:
        raise ValueError(f"Unknown arch: {arch}")

    # pad the vocab size up to the next multiple of 128
    config_dict["vocab_size"] = ((len(VOCAB) + 127) // 128) * 128
    config = Config(**config_dict)
    # one classification label per action (move) in the action space
    label_to_id = {label: i for i, label in enumerate(ACTION_SPACE)}
    config.num_labels = len(ACTION_SPACE)
    config.label2id = label_to_id
    config.id2label = {i: label for label, i in label_to_id.items()}
    model = Model(config=config)
    return model

def make_model_lm(config_dict, arch):
    if arch == "llama":
        Config = LlamaConfig
        Model = LlamaForCausalLM
    elif arch == "gpt2":
        Config = GPT2Config
        Model = GPT2LMHeadModel
    else:
        raise ValueError(f"Unknown arch: {arch}")

    # vocab = base tokens + action tokens + the 4 extra markers added in
    # make_tokenizer_lm ([OPTIONS], [VALUES], [ACTION], 0000), padded to a multiple of 128
    config_dict["vocab_size"] = ((len(VOCAB) + len(ACTION_SPACE) + 4 + 127) // 128) * 128
    config = Config(**config_dict)
    model = Model(config=config)
    return model
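
# Illustrative only (not used above): the vocab-size padding in make_model_clf and
# make_model_lm is a plain round-up to the next multiple of 128; a hypothetical
# helper spelling out the arithmetic:
def _pad_to_multiple(n: int, multiple: int = 128) -> int:
    """Round n up to the nearest multiple of `multiple`, e.g. 200 -> 256, 256 -> 256."""
    return ((n + multiple - 1) // multiple) * multiple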


def make_tokenizer(task="clf"):
    # each task uses a different fixed maximum sequence length
    if task == "clf":
        return make_tokenizer_clf(model_max_length=78)
    elif task == "lm":
        return make_tokenizer_lm(model_max_length=79)
    elif task == "lm-cot":
        return make_tokenizer_lm(model_max_length=116)
    else:
        raise ValueError(f"Unknown task: {task}")
    
def make_tokenizer_clf(model_max_length):
    # single characters form the base BPE vocab; multi-character entries become merges
    single_char_vocab = [e for e in VOCAB if len(e) == 1]
    multi_char_vocab = [e for e in VOCAB if len(e) > 1]
    merges = [tuple(e) for e in multi_char_vocab]
    print(merges[:5])  # debug: show a few of the derived merges

    tokenizer = Tokenizer(BPE(
        vocab=dict(zip(single_char_vocab, range(len(single_char_vocab)))),
        merges=merges,
    ))

    fast_tokenizer = RookTokenizer(
        tokenizer_object=tokenizer,
        model_max_length=model_max_length,
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
        clean_up_tokenization_spaces=False
    )
    return fast_tokenizer

def make_tokenizer_lm(model_max_length):
    # LM vocab = base tokens + action tokens + markers used in the generation format
    vocab = VOCAB + ACTION_SPACE
    vocab += ["[OPTIONS]", "[VALUES]", "[ACTION]", "0000"]

    # single characters form the base BPE vocab; multi-character entries are added
    # as special tokens (rather than merges) so they are never split
    single_char_vocab = [e for e in vocab if len(e) == 1]
    multi_char_vocab = [e for e in vocab if len(e) > 1]
    merges = []

    tokenizer = Tokenizer(BPE(
        vocab=dict(zip(single_char_vocab, range(len(single_char_vocab)))),
        merges=merges,
    ))
    tokenizer.add_special_tokens(multi_char_vocab)

    fast_tokenizer = RookTokenizer(
        tokenizer_object=tokenizer,
        model_max_length=model_max_length,
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
        clean_up_tokenization_spaces=False
    )
    return fast_tokenizer
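
if __name__ == "__main__":
    # Minimal smoke test, illustrative only: the hyper-parameters below are made up
    # for this sketch and are not the repo's training configuration.
    tok = make_tokenizer(task="clf")
    model = make_model(
        {
            "finetuning_task": "text-classification",
            "hidden_size": 256,
            "intermediate_size": 512,
            "num_hidden_layers": 4,
            "num_attention_heads": 4,
        },
        arch="llama",
    )
    print(type(model).__name__, len(tok), model.config.num_labels, model.config.vocab_size)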