nroggendorff committed (verified)
Commit 0f9e3cc · 1 parent: b508b1f

Update app.py

Files changed (1): app.py (+138, -0)
app.py CHANGED
@@ -0,0 +1,138 @@
+ from transformers import TrainingArguments, AutoTokenizer
+ from transformers import LlamaConfig, LlamaForCausalLM
+ from transformers import PreTrainedTokenizerFast
+ from datasets import load_dataset
+ from tokenizers import ByteLevelBPETokenizer
+ import trl
+ import torch
+ import requests as rq
+ import gc
+
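+ # Load the chat-formatted dataset; the commented-out .select() presumably subsampled 50k rows for quick test runs.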
+ dataset = load_dataset("nroggendorff/openhermes", split="train")#.select(range(int(5e+4)))
+
+ def get_training_corpus():
+     for i in range(0, len(dataset), 1000):
+         yield dataset[i : i + 1000]["text"]
+
+ training_corpus = get_training_corpus()
+
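+ # Train a byte-level BPE tokenizer from scratch, streaming the corpus in batches of 1000 texts.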
+ tokenizer = ByteLevelBPETokenizer()
+
+ tokenizer.train_from_iterator(
+     training_corpus,
+     vocab_size=3200,
+     min_frequency=2,
+     special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>", "<|user|>", "<|bot|>", "<|end|>"]
+ )
+
+ tokenizer.save("custom_tokenizer.json")
+
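+ # Wrap the raw tokenizers-library tokenizer so transformers can use it, then register the special tokens and chat template.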
+ tokenizer = PreTrainedTokenizerFast(tokenizer_file="custom_tokenizer.json")
+
+ tokenizer.bos_token = "<s>"
+ tokenizer.eos_token = "</s>"
+ tokenizer.unk_token = "<unk>"
+ tokenizer.pad_token = "<pad>"
+ tokenizer.mask_token = "<mask>"
+
+ chat_template = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}{{ eos_token }}"
+
+ tokenizer.chat_template = chat_template
+
+ tokenizer.add_special_tokens({
+     "additional_special_tokens": ["<|user|>", "<|bot|>", "<|end|>"]
+ })
+
+ tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
+ tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
+
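+ # Round-trip through disk so AutoTokenizer picks up the saved special-token config; the print is a sanity check of the chat template.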
+ tokenizer.save_pretrained("llama-tokenizer")
+
+ tokenizer = AutoTokenizer.from_pretrained("llama-tokenizer")
+ print(tokenizer.apply_chat_template([{"role": "user", "content": "Why is the sky blue?"}, {"role": "assistant", "content": "Due to Rayleigh scattering."}, {"role": "user", "content": "That's cool."}, {"role": "assistant", "content": "Yeah, I agree."}], tokenize=False))
+
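+ # A tiny Llama config (roughly 25M parameters) sized to the 3200-token vocabulary; weights are randomly initialized below.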
+ config = LlamaConfig(
+     vocab_size=tokenizer.vocab_size,
+     hidden_size=512,
+     intermediate_size=1024,
+     num_hidden_layers=8,
+     num_attention_heads=8,
+     max_position_embeddings=512,
+     rms_norm_eps=1e-6,
+     initializer_range=0.02,
+     use_cache=True,
+     pad_token_id=tokenizer.pad_token_id,
+     bos_token_id=tokenizer.bos_token_id,
+     eos_token_id=tokenizer.eos_token_id,
+     tie_word_embeddings=False,
+ )
+
+ model = LlamaForCausalLM(config)
+
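+ # Split each raw example on <|end|> and rebuild it through the chat template; assumes strictly alternating <|user|>/<|bot|> turns.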
+ def format_prompts(examples):
+     texts = []
+     for text in examples['text']:
+         conversation = []
+         parts = text.split('<|end|>')
+         for i in range(0, len(parts) - 1, 2):
+             prompt = parts[i].replace("<|user|>", "").strip()
+             response = parts[i + 1].replace("<|bot|>", "").strip()
+             conversation.append({"role": "user", "content": prompt})
+             conversation.append({"role": "assistant", "content": response})
+         formatted_conversation = tokenizer.apply_chat_template(conversation, tokenize=False)
+         texts.append(formatted_conversation)
+     return {"text": texts}
+
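+ # Rewrite the 'text' column in place through the template; the print spot-checks one formatted example.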
+ dataset = dataset.map(format_prompts, batched=True)
+
+ print(dataset['text'][2])
+
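+ # Note: optim_target_modules was introduced for GaLore-style optimizers; with plain SGD it is likely ignored (or rejected) by the Trainer.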
+ args = TrainingArguments(
+     output_dir="mayo",
+     num_train_epochs=4,
+     gradient_accumulation_steps=4,
+     per_device_train_batch_size=1,
+     learning_rate=1e-5,
+     save_steps=100000,
+     fp16=True,
+     optim="sgd",
+     optim_target_modules=["attn", "mlp"],
+     max_grad_norm=0.3
+ )
+
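+ # Passing dataset_text_field/max_seq_length directly assumes an older trl API; newer releases expect them on an SFTConfig.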
+ trainer = trl.SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     args=args,
+     train_dataset=dataset,
+     dataset_text_field='text',
+     max_seq_length=512,
+ )
+
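+ # Free any cached GPU memory before training starts.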
+ torch.cuda.set_device(0)
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
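+ # Report failures to a Discord webhook, then re-raise so the script does not continue to push an untrained model.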
+ try:
+     trainer.train()
+ except Exception as e:
+     rq.post("https://discord.com/api/webhooks/1245084721923358730/pVHUf2PR4Wst52KVNxVSeAHnSIKxx-PLdd90OHASegb30cNoGZe9N476LzCDVLQXDbT0", json={"content": str(e)})
+     raise
+
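+ # Push the trained model and tokenizer to the Hub, then ping the webhook that the run finished.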
+ #trainer.push_to_hub()
+ trained_model = trainer.model
+ trained_tokenizer = trainer.tokenizer
+
+ repo_id = "makeshift-mayo"
+ trained_model.push_to_hub(repo_id)
+ trained_tokenizer.push_to_hub(repo_id)
+
+ rq.post("https://discord.com/api/webhooks/1245084721923358730/pVHUf2PR4Wst52KVNxVSeAHnSIKxx-PLdd90OHASegb30cNoGZe9N476LzCDVLQXDbT0", json={"content": "that shit is finally done"})