Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -3,7 +3,7 @@ import os

 import torch
 import trl

-from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW,
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast, AdamW, get_cosine_schedule_with_warmup
 from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer

@@ -108,10 +108,10 @@ def train_model(model, tokenizer, dataset, push):
     )

     optimizer = AdamW(model.parameters(), lr=args.learning_rate)
-    scheduler =
+    scheduler = get_cosine_schedule_with_warmup(
         optimizer,
         num_warmup_steps=args.warmup_steps,
-        num_training_steps=len(dataset)
+        num_training_steps=(len(dataset) // args.per_device_train_batch_size) * args.num_train_epochs
     )

     dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)