import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Set seed for reproducibility
set_seed(42)

def fine_tune_model(dataset_url, model_name, epochs, batch_size, learning_rate):
    """
    Fine-tunes a pre-trained transformer model on a custom dataset.

    Parameters:
    - dataset_url (str): URL or path to the dataset.
    - model_name (str): Name of the pre-trained model.
    - epochs (int): Number of training epochs.
    - batch_size (int): Batch size for training.
    - learning_rate (float): Learning rate for the optimizer.

    Returns:
    - dict: Status message containing training completion status.
    """
    # Load the dataset
    dataset = load_dataset(dataset_url)

    # Tokenize the dataset so the Trainer receives input_ids/attention_mask
    # rather than raw strings (assumes the dataset stores its text in a
    # "text" column; adjust the column name for other datasets)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(batch):
        return tokenizer(batch["text"], truncation=True, padding="max_length")

    dataset = dataset.map(tokenize, batched=True)

    # Load the pre-trained model for sequence classification (2 labels for binary classification)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # Report accuracy at each evaluation; required because
    # metric_for_best_model="accuracy" is set below
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return {"accuracy": float((preds == labels).mean())}
    # Define the training arguments
    training_args = TrainingArguments(
        output_dir='./results',                    # Directory for storing results
        num_train_epochs=epochs,                   # Number of training epochs
        per_device_train_batch_size=batch_size,    # Batch size for training
        learning_rate=learning_rate,               # Learning rate for the optimizer
        logging_dir='./logs',                      # Directory for storing logs
        logging_steps=10,                          # Log every 10 steps
        evaluation_strategy="epoch",               # Evaluate every epoch
        save_strategy="epoch",                     # Save checkpoint every epoch
        load_best_model_at_end=True,               # Load the best model at the end of training
        metric_for_best_model="accuracy",          # Metric to monitor for selecting the best model
        greater_is_better=True,                    # Higher metric values are better
    )
    # Initialize the Trainer with the model, arguments, datasets, and metric
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],        # Training dataset
        eval_dataset=dataset['validation'],    # Validation dataset
        compute_metrics=compute_metrics,       # Accuracy for best-model selection
    )
    # Train the model
    trainer.train()

    # Return a status message after training completes
    return {"status": "Training complete"}