Update fine_tuner.py
Browse files- fine_tuner.py +18 -9
fine_tuner.py
CHANGED
@@ -3,20 +3,29 @@ from transformers import AutoModelForSequenceClassification, Trainer, TrainingAr
|
|
3 |
from datasets import load_dataset
|
4 |
|
5 |
def fine_tune_model(dataset, model_name, epochs, batch_size, learning_rate):
|
|
|
6 |
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
|
|
|
|
|
7 |
training_args = TrainingArguments(
|
8 |
-
output_dir='./results',
|
9 |
-
num_train_epochs=epochs,
|
10 |
-
per_device_train_batch_size=batch_size,
|
11 |
-
learning_rate=learning_rate,
|
12 |
-
logging_dir='./logs',
|
13 |
-
logging_steps=10,
|
14 |
)
|
|
|
|
|
15 |
trainer = Trainer(
|
16 |
model=model,
|
17 |
args=training_args,
|
18 |
-
train_dataset=dataset['train'],
|
19 |
-
eval_dataset=dataset['validation'],
|
20 |
)
|
|
|
|
|
21 |
trainer.train()
|
22 |
-
|
|
|
|
|
|
3 |
from datasets import load_dataset
|
4 |
|
5 |
def fine_tune_model(dataset, model_name, epochs, batch_size, learning_rate):
|
6 |
+
# Load the pre-trained model for sequence classification
|
7 |
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
|
8 |
+
|
9 |
+
# Define the training arguments
|
10 |
training_args = TrainingArguments(
|
11 |
+
output_dir='./results', # Directory for storing results
|
12 |
+
num_train_epochs=epochs, # Number of training epochs
|
13 |
+
per_device_train_batch_size=batch_size, # Batch size for training
|
14 |
+
learning_rate=learning_rate, # Learning rate for the optimizer
|
15 |
+
logging_dir='./logs', # Directory for storing logs
|
16 |
+
logging_steps=10, # Log every 10 steps
|
17 |
)
|
18 |
+
|
19 |
+
# Initialize the Trainer with the model, arguments, and dataset
|
20 |
trainer = Trainer(
|
21 |
model=model,
|
22 |
args=training_args,
|
23 |
+
train_dataset=dataset['train'], # Training dataset
|
24 |
+
eval_dataset=dataset['validation'], # Validation dataset
|
25 |
)
|
26 |
+
|
27 |
+
# Train the model
|
28 |
trainer.train()
|
29 |
+
|
30 |
+
# Return a status message after training completes
|
31 |
+
return {"status": "Training complete"}
|