dalybuilds commited on
Commit
3031c65
·
verified ·
1 Parent(s): a4102d1

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +26 -0
train.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 4. train.py
2
+ # Fine-tuning logic
3
+ from transformers import Trainer, TrainingArguments
4
+ from model_utils import BugClassifier
5
+
6
+ def fine_tune_model(dataset):
7
+ model = BugClassifier().model
8
+ training_args = TrainingArguments(
9
+ output_dir="./results",
10
+ evaluation_strategy="epoch",
11
+ save_strategy="epoch",
12
+ per_device_train_batch_size=8,
13
+ num_train_epochs=3,
14
+ logging_dir="./logs",
15
+ logging_steps=10,
16
+ )
17
+
18
+ trainer = Trainer(
19
+ model=model,
20
+ args=training_args,
21
+ train_dataset=dataset["train"],
22
+ eval_dataset=dataset["test"],
23
+ )
24
+
25
+ trainer.train()
26
+ model.save_pretrained("./fine_tuned_model")