import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Set seed for reproducibility
set_seed(42)

def fine_tune_model(dataset_url, model_name, epochs, batch_size, learning_rate):
    """
    Fine-tunes a pre-trained transformer model on a custom dataset.

    Parameters:
    - dataset_url (str): Hugging Face Hub dataset name or local path (passed to load_dataset).
    - model_name (str): Name of the pre-trained model.
    - epochs (int): Number of training epochs.
    - batch_size (int): Batch size for training.
    - learning_rate (float): Learning rate for the optimizer.

    Returns:
    - dict: Status message containing training completion status.
    """
    
    # Load the dataset (a Hub dataset name or a local path both work with load_dataset)
    dataset = load_dataset(dataset_url)

    # Tokenize the text so the model receives input IDs rather than raw strings.
    # NOTE: this assumes the dataset stores its input in a "text" column;
    # adjust the column name to match your dataset's schema.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(batch):
        return tokenizer(batch["text"], padding="max_length", truncation=True)

    dataset = dataset.map(tokenize, batched=True)

    # Load the pre-trained model for sequence classification (2 labels for binary classification)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # Define the training arguments
    training_args = TrainingArguments(
        output_dir='./results',                # Directory for storing results
        num_train_epochs=epochs,               # Number of training epochs
        per_device_train_batch_size=batch_size,  # Batch size for training
        learning_rate=learning_rate,          # Learning rate for the optimizer
        logging_dir='./logs',                  # Directory for storing logs
        logging_steps=10,                      # Log every 10 steps
        evaluation_strategy="epoch",           # Evaluate every epoch
        save_strategy="epoch",                 # Save checkpoint every epoch
        load_best_model_at_end=True,           # Load the best model at the end of training
        metric_for_best_model="accuracy",     # Metric to monitor for selecting the best model
        greater_is_better=True,                # Set to True if higher metric values are better
    )

    # Compute accuracy at evaluation time. This is required because
    # metric_for_best_model="accuracy" above looks that key up in the
    # evaluation metrics; without it, best-model selection would fail.
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return {"accuracy": float((predictions == labels).mean())}

    # Initialize the Trainer with the model, arguments, datasets, and metric function
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],        # Tokenized training split
        eval_dataset=dataset['validation'],    # Tokenized validation split (must exist in the dataset)
        compute_metrics=compute_metrics,
    )

    # Train the model
    trainer.train()

    # Return a status message after training completes
    return {"status": "Training complete"}
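

# Example invocation (a minimal sketch): the dataset name ("rotten_tomatoes"),
# model checkpoint ("distilbert-base-uncased"), and hyperparameters below are
# illustrative placeholders, not values prescribed by this script. The chosen
# dataset should expose "train" and "validation" splits with a "text" column,
# as the function above assumes.
if __name__ == "__main__":
    result = fine_tune_model(
        dataset_url="rotten_tomatoes",          # binary sentiment dataset with train/validation splits
        model_name="distilbert-base-uncased",   # any checkpoint usable for sequence classification
        epochs=3,
        batch_size=16,
        learning_rate=2e-5,
    )
    print(result)  # {"status": "Training complete"}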