Spaces: Sleeping

defining training loop
Browse files: src/model_training.py (+115, -20)

src/model_training.py CHANGED
@@ -1,36 +1,131 @@
 import torch.nn as nn
-import torch.
-from
 
 from utils import get_dataset
 
 
-        self.model = AutoModel.from_pretrained(model_name, config=config)
-        self.classification_head = nn.Linear(config.hidden_size, num_labels)
 
 def main():
     model_name = "moussaKam/AraBART"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    config = AutoConfig.from_pretrained(model_name)
 
-    dataset = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
 
 
 if __name__ == "__main__":
     main()

+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
 import torch.nn as nn
+import torch.optim as optim
+from datasets import DatasetDict, Dataset
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import LabelEncoder
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 from utils import get_dataset
 
 
+def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model.to(device)
 
+    num_training_steps = num_epochs * len(train_loader)
+    progress_bar = tqdm(range(num_training_steps))
+
+    train_losses = []
+    valid_losses = []
+    train_accuracies = []
+    valid_accuracies = []
+
+    best_valid_loss = float("inf")
+    epochs_no_improve = 0
+    best_model = None
+
+    for epoch in range(num_epochs):
+        model.train()
+        train_loss = 0.0
+        correct_train = 0
+        total_train = 0
+        for batch in train_loader:
+            batch = {k: v.to(device) for k, v in batch.items()}
+            outputs = model(**batch)
+            loss = outputs.loss
+            loss.backward()
+
+            optimizer.step()
+            optimizer.zero_grad()
+            progress_bar.update(1)
+
+            train_loss += loss.item()
+
+            _, predicted_train = torch.max(outputs.logits, 1)
+            labels_train = batch["labels"]
+            correct_train += (predicted_train == labels_train).sum().item()
+            total_train += labels_train.size(0)
 
+        train_accuracy = correct_train / total_train
+        train_losses.append(train_loss / len(train_loader))
+        train_accuracies.append(train_accuracy)
+
+        model.eval()
+        valid_loss = 0.0
+        correct_valid = 0
+        total_valid = 0
+        with torch.no_grad():
+            for batch in val_loader:
+                batch = {k: v.to(device) for k, v in batch.items()}
+                outputs = model(**batch)
+                loss = outputs.loss
+                valid_loss += loss.item()
+
+                _, predicted_valid = torch.max(outputs.logits, 1)
+                labels_valid = batch["labels"]
+                correct_valid += (predicted_valid == labels_valid).sum().item()
+                total_valid += labels_valid.size(0)
+
+        valid_loss /= len(val_loader)
+        valid_losses.append(valid_loss)
+
+        valid_accuracy = correct_valid / total_valid
+        valid_accuracies.append(valid_accuracy)
+
+        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}')
 
+        if valid_loss < best_valid_loss:
+            best_valid_loss = valid_loss
+            epochs_no_improve = 0
+            best_model = model.state_dict()
+            torch.save(best_model, "best_model_checkpoint.pth")
+        else:
+            epochs_no_improve += 1
+            if epochs_no_improve == patience:
+                print(f'Early stopping after {epoch+1} epochs with no improvement.')
+                break
+
+    model.load_state_dict(best_model)
+    history = {"train_loss": train_losses,
+               "valid_loss": valid_losses,
+               "train_accuracies": train_accuracies,
+               "valid_accuracies": valid_accuracies}
+
+    return model, history
+
+
 def main():
     model_name = "moussaKam/AraBART"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+    dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
+
+    for data in dataset:
+        dataset[data] = dataset[data].remove_columns(["tweet"])
+        dataset[data] = dataset[data].rename_column("label", "labels")
+        dataset[data].set_format("torch")
 
+    train_loader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
+    val_loader = DataLoader(dataset["val"], batch_size=8)
+    test_loader = DataLoader(dataset["test"], batch_size=8)
+
+    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)
+    for param in model.parameters():
+        param.requires_grad = False  # We don't retrain the pretrained model due to lack of GPU
+    for param in model.classification_head.parameters():
+        param.requires_grad = True
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
+    num_epochs = 100
+
+    model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
+
 
 if __name__ == "__main__":
     main()
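`main()` builds a `test_loader`, but this commit never consumes it, and the best checkpoint is only written to disk. Below is a minimal sketch of how the saved weights could be evaluated on the test split, mirroring the validation pass inside `train_model`; the `evaluate` helper and the usage lines are illustrative assumptions, not part of the repository.

```python
import torch

def evaluate(model, data_loader, device):
    # Same pattern as the validation loop in train_model:
    # no gradients, argmax over logits, accumulate loss and accuracy.
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch in data_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            total_loss += outputs.loss.item()
            _, predicted = torch.max(outputs.logits, 1)
            correct += (predicted == batch["labels"]).sum().item()
            total += batch["labels"].size(0)
    return total_loss / len(data_loader), correct / total

# Example usage after training in main():
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.load_state_dict(torch.load("best_model_checkpoint.pth", map_location=device))
# test_loss, test_acc = evaluate(model, test_loader, device)
# print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
```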
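The commit also adds `matplotlib.pyplot`, `numpy`, and `pandas` imports that are not yet used; the `history` dict returned by `train_model` is the natural input for loss and accuracy curves. A minimal sketch, assuming only the keys defined in `train_model`; the `plot_history` helper is hypothetical.

```python
import matplotlib.pyplot as plt

def plot_history(history):
    # Keys come from train_model: "train_loss", "valid_loss",
    # "train_accuracies", "valid_accuracies" (one entry per completed epoch).
    epochs = range(1, len(history["train_loss"]) + 1)

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(epochs, history["train_loss"], label="train")
    ax_loss.plot(epochs, history["valid_loss"], label="valid")
    ax_loss.set_xlabel("epoch")
    ax_loss.set_ylabel("loss")
    ax_loss.legend()

    ax_acc.plot(epochs, history["train_accuracies"], label="train")
    ax_acc.plot(epochs, history["valid_accuracies"], label="valid")
    ax_acc.set_xlabel("epoch")
    ax_acc.set_ylabel("accuracy")
    ax_acc.legend()

    fig.tight_layout()
    plt.show()

# plot_history(history)  # e.g. right after train_model returns in main()
```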