# Transformers-Fine-Tuner / fine_tuner.py
# Author: Canstralian
# Last commit: 6da5491 "Update fine_tuner.py" (verified)
# File size: 1.32 kB
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
def fine_tune_model(dataset, model_name, epochs, batch_size, learning_rate,
                    num_labels=2, output_dir='./results', logging_dir='./logs'):
    """Fine-tune a pre-trained sequence-classification model with ``Trainer``.

    Args:
        dataset: Mapping from split name to dataset (e.g. a ``DatasetDict``).
            It must contain a ``'train'`` split. A ``'validation'`` split is
            passed to ``Trainer`` as the evaluation set when it exists. The
            splits are handed to ``Trainer`` unchanged, so they presumably
            must already be tokenized. TODO confirm with callers.
        model_name: Model identifier or local path given to ``from_pretrained``.
        epochs: Number of training epochs.
        batch_size: Per-device training batch size.
        learning_rate: Learning rate for the optimizer.
        num_labels: Number of classification labels for the classifier head.
            Defaults to 2, the value previously hard-coded here.
        output_dir: Directory where the Trainer stores results/checkpoints.
        logging_dir: Directory for storing training logs.

    Returns:
        dict: ``{"status": "Training complete"}`` after ``trainer.train()``
        returns.

    Raises:
        KeyError: If ``dataset`` has no ``'train'`` split.
    """
    # Load the pre-trained model with a classification head of `num_labels` outputs.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )

    # Define the training arguments.
    training_args = TrainingArguments(
        output_dir=output_dir,                   # Directory for storing results
        num_train_epochs=epochs,                 # Number of training epochs
        per_device_train_batch_size=batch_size,  # Batch size for training
        learning_rate=learning_rate,             # Learning rate for the optimizer
        logging_dir=logging_dir,                 # Directory for storing logs
        logging_steps=10,                        # Log every 10 steps
    )
    # NOTE(review): no evaluation strategy is set above. With the library's
    # default, Trainer presumably does not evaluate during train(), so the
    # eval set is only used if the trainer is evaluated elsewhere. Confirm
    # whether periodic evaluation is intended.

    # A validation split is optional. Trainer accepts eval_dataset=None, so
    # don't fail with a KeyError on datasets that only have a train split.
    eval_split = dataset['validation'] if 'validation' in dataset else None

    # Initialize the Trainer with the model, arguments, and dataset splits.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],  # Training dataset (required)
        eval_dataset=eval_split,         # Validation dataset, if any
    )

    # Train the model.
    trainer.train()

    # Return a status message after training completes.
    return {"status": "Training complete"}