nroggendorff committed on
Commit
3304e16
·
verified ·
1 Parent(s): e728d3c

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +20 -9
train.py CHANGED
@@ -16,6 +16,7 @@ VOCAB_SIZE = 32000
16
  INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
17
  INSTRUCT_DATASET = "nroggendorff/elephant"
18
  OUTPUT_REPO = "smallama"
 
19
  FP16 = False
20
  WARMUP_STEPS = 0
21
  DECAY = 0
@@ -24,9 +25,9 @@ PUSH_TO_HUB = True
24
 
25
  def load_data():
26
  pretrain = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
27
- pretrain = Dataset.from_generator(lambda: pretrain.take(int(3e+4)))
28
  instruct = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
29
- instruct = Dataset.from_generator(lambda: instruct.take(int(5e+4)))
30
  dataset_dict = DatasetDict({
31
  'pretrain': pretrain,
32
  'instruct': instruct
@@ -91,6 +92,10 @@ def create_model(tokenizer):
91
  model = LlamaForCausalLM(config)
92
  return model
93
 
 
 
 
 
94
  def configure_tokenizer(tokenizer):
95
  special_tokens = {
96
  "bos_token": "<s>",
@@ -145,7 +150,10 @@ def train_model(model, tokenizer, dataset, push, isinst):
145
  trained_tokenizer = trainer.tokenizer
146
 
147
  if push:
148
- repo_id = OUTPUT_REPO
 
 
 
149
  msg = str(train.training_loss)
150
  trained_model.push_to_hub(repo_id, commit_message=msg, force=True)
151
  trained_tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
@@ -153,17 +161,20 @@ def train_model(model, tokenizer, dataset, push, isinst):
153
  trained_model.save_pretrained("model")
154
  trained_tokenizer.save_pretrained("tokenizer")
155
 
156
- def main(push_to_hub=True):
157
  dataset = load_data()
158
  pretrain = dataset['pretrain']
159
  instruct = dataset['instruct']
160
  training_corpus = get_training_corpus(dataset)
161
  tokenizer = create_tokenizer(training_corpus)
162
- configure_tokenizer(tokenizer)
163
- model = create_model(tokenizer)
164
- train_model(model, tokenizer, pretrain, False, False)
165
- train_model(model, tokenizer, instruct, push_to_hub, True)
 
 
 
166
 
167
  if __name__ == "__main__":
168
- main(PUSH_TO_HUB)
169
  raise RuntimeError("The script is finished.")
 
16
  INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
17
  INSTRUCT_DATASET = "nroggendorff/elephant"
18
  OUTPUT_REPO = "smallama"
19
+ INSTRUCT_FINETUNE_BOOL = False
20
  FP16 = False
21
  WARMUP_STEPS = 0
22
  DECAY = 0
 
25
 
26
  def load_data():
27
  pretrain = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
28
+ pretrain = Dataset.from_generator(lambda: pretrain.take(int(3e+5)))
29
  instruct = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
30
+ instruct = Dataset.from_generator(lambda: instruct.take(int(5e+5)))
31
  dataset_dict = DatasetDict({
32
  'pretrain': pretrain,
33
  'instruct': instruct
 
92
  model = LlamaForCausalLM(config)
93
  return model
94
 
95
def load_model():
    """Load and return the Llama causal-LM identified by ``OUTPUT_REPO``.

    Returns:
        The ``LlamaForCausalLM`` instance produced by ``from_pretrained``.
    """
    return LlamaForCausalLM.from_pretrained(OUTPUT_REPO)
98
+
99
  def configure_tokenizer(tokenizer):
100
  special_tokens = {
101
  "bos_token": "<s>",
 
150
  trained_tokenizer = trainer.tokenizer
151
 
152
  if push:
153
+ if INSTRUCT_FINETUNE_BOOL:
154
+ repo_id = OUTPUT_REPO + "-it"
155
+ else:
156
+ repo_id = OUTPUT_REPO
157
  msg = str(train.training_loss)
158
  trained_model.push_to_hub(repo_id, commit_message=msg, force=True)
159
  trained_tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
 
161
  trained_model.save_pretrained("model")
162
  trained_tokenizer.save_pretrained("tokenizer")
163
 
164
def main(push_to_hub=True, is_inst_finetune=False):
    """Run a single training stage: base pretraining or instruction fine-tuning.

    Args:
        push_to_hub: Forwarded to ``train_model`` as its ``push`` argument.
        is_inst_finetune: When true, configure the tokenizer, load the existing
            model through ``load_model`` and train on the ``instruct`` split.
            Otherwise build a new model with ``create_model`` and train on the
            ``pretrain`` split.
    """
    # ``is_inst_finetune`` needs a default: a parameter without one may not
    # follow a parameter that has one (SyntaxError at import time). ``False``
    # selects the pretraining path; positional callers are unaffected.
    dataset = load_data()
    pretrain = dataset['pretrain']
    instruct = dataset['instruct']
    training_corpus = get_training_corpus(dataset)
    tokenizer = create_tokenizer(training_corpus)

    if is_inst_finetune:
        # NOTE(review): the tokenizer is rebuilt from the corpus here while the
        # model weights are loaded from an earlier run - confirm both stages
        # produce the same vocabulary.
        configure_tokenizer(tokenizer)
        model = load_model()
        train_model(model, tokenizer, instruct, push_to_hub, True)
    else:
        model = create_model(tokenizer)
        train_model(model, tokenizer, pretrain, push_to_hub, False)


if __name__ == "__main__":
    main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
    # NOTE(review): this raises after a successful run, so the process exits
    # with an error status - presumably deliberate to stop the host; confirm.
    raise RuntimeError("The script is finished.")