import copy

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from utils import get_dataset, serialize_data, plot_training_history, get_model_accuracy


def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
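    """Fine-tune `model` with early stopping on the validation loss.

    Returns the model restored to its best checkpoint, together with a history
    dict of per-epoch training/validation losses and accuracies.
    """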
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    num_training_steps = num_epochs * len(train_loader)
    progress_bar = tqdm(range(num_training_steps))

    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []

    best_valid_loss = float("inf")
    epochs_no_improve = 0
    best_model = None

    for epoch in range(num_epochs):
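        # Training phase: one full pass over the training set.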
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            
            train_loss += loss.item()
            
            _, predicted_train = torch.max(outputs.logits, 1)
            labels_train = batch["labels"]
            correct_train += (predicted_train == labels_train).sum().item()
            total_train += labels_train.size(0)
        
        train_accuracy = correct_train / total_train
        train_losses.append(train_loss / len(train_loader))
        train_accuracies.append(train_accuracy)
        
        model.eval()
        valid_loss = 0.0
        correct_valid = 0
        total_valid = 0
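        # No gradients are needed for validation; torch.no_grad() saves memory and compute.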
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                valid_loss += loss.item()
                
                _, predicted_valid = torch.max(outputs.logits, 1)
                labels_valid = batch["labels"]
                correct_valid += (predicted_valid == labels_valid).sum().item()
                total_valid += labels_valid.size(0)
        
        valid_loss /= len(val_loader)
        valid_losses.append(valid_loss)
        
        valid_accuracy = correct_valid / total_valid
        valid_accuracies.append(valid_accuracy)
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}')    

        # Checkpoint the best model seen so far; stop early once the validation
        # loss has not improved for `patience` consecutive epochs.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live tensors, so take a copy;
            # otherwise later epochs would overwrite the "best" weights.
            best_model = copy.deepcopy(model.state_dict())
            torch.save(best_model, "../models/best_model_checkpoint.pth")
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f'Early stopping after {epoch+1} epochs with no improvement.')
                break
                
    # Restore the best weights observed during training before returning.
    model.load_state_dict(best_model)
    history = {"train_loss": train_losses,
               "valid_loss": valid_losses,
               "train_accuracies": train_accuracies,
               "valid_accuracies": valid_accuracies}

    return model, history

    
def main():
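    """Train an AraBART sequence classifier and report its accuracy on the test set."""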
    model_name = "moussaKam/AraBART"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
    serialize_data(label_encoder, "../models/label_encoder.pkl")

    # Keep only the tokenized model inputs: drop the raw tweet text, rename the
    # label column to the name the model expects, and return PyTorch tensors.
    for data in dataset:
        dataset[data] = dataset[data].remove_columns(["tweet"])
        dataset[data] = dataset[data].rename_column("label", "labels")
        dataset[data].set_format("torch")

    train_loader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
    val_loader = DataLoader(dataset["val"], batch_size=8)
    test_loader = DataLoader(dataset["test"], batch_size=8)

    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)
    # Freeze the pretrained backbone and fine-tune only the classification head
    # (full fine-tuning was not feasible with the available GPU resources).
    for param in model.parameters():
        param.requires_grad = False
    for param in model.classification_head.parameters():
        param.requires_grad = True

    # Only the classification-head parameters require gradients, so restrict the optimizer to them.
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-4)
    num_epochs = 100

    model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
    plot_training_history(history)

    test_accuracy = get_model_accuracy(model, test_loader)
    print("The accuracy of the model on the test set is:", test_accuracy)

if __name__ == "__main__":
    main()