nroggendorff committed on
Commit 84b97dd · verified · 1 Parent(s): d2b80b3

Update train.py

Files changed (1)
  1. train.py +3 -3
train.py CHANGED
@@ -3,7 +3,7 @@ import os
 import torch
 import trl
 
-from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW, get_linear_schedule_with_warmup
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW, get_cosine_schedule_with_warmup
 from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
 
@@ -108,10 +108,10 @@ def train_model(model, tokenizer, dataset, push):
     )
 
     optimizer = AdamW(model.parameters(), lr=args.learning_rate)
-    scheduler = get_linear_schedule_with_warmup(
+    scheduler = get_cosine_schedule_with_warmup(
         optimizer,
         num_warmup_steps=args.warmup_steps,
-        num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
+        num_training_steps=(len(dataset) // args.per_device_train_batch_size) * args.num_train_epochs
     )
 
     dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
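For context, a minimal sketch of how the scheduler setup reads after this commit: a cosine schedule with linear warmup replaces the previous linear schedule, and the total step count is now estimated from the dataset length and per-device batch size. The model, dataset, and args objects below are hypothetical stand-ins; the real ones come from the rest of train.py (not shown in this diff), which also imports AdamW from transformers rather than torch.optim.

    import torch
    from torch.optim import AdamW
    from transformers import get_cosine_schedule_with_warmup

    # Hypothetical placeholders for objects defined elsewhere in train.py.
    model = torch.nn.Linear(8, 8)       # placeholder model with parameters
    dataset = list(range(1000))         # placeholder dataset; only len() is used here

    class Args:                         # placeholder for the TrainingArguments fields used below
        learning_rate = 1e-4
        warmup_steps = 100
        per_device_train_batch_size = 4
        num_train_epochs = 2

    args = Args()

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)

    # Total optimizer steps estimated as (examples per batch-size chunk) * epochs,
    # mirroring the new num_training_steps expression in this commit.
    num_training_steps = (len(dataset) // args.per_device_train_batch_size) * args.num_train_epochs

    # Cosine decay after a linear warmup phase, replacing the previous linear schedule.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )

Note that, like the diff itself, this step estimate does not divide by gradient_accumulation_steps or the number of devices, so the schedule length corresponds to per-device forward passes rather than optimizer updates in a multi-step or multi-GPU setup.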