zaidmehdi committed
Commit cdcc3ac · 1 Parent(s): ff82938

defining training loop

Files changed (1)
  1. src/model_training.py +115 -20
src/model_training.py CHANGED
@@ -1,36 +1,131 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from transformers import AutoTokenizer, AutoModel, AutoConfig
+import torch.optim as optim
+from datasets import DatasetDict, Dataset
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import LabelEncoder
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 from utils import get_dataset
 
 
-class Model(nn.Module):
-    def __init__(self, model_name, config, num_labels):
-        super().__init__()
-        self.model = AutoModel.from_pretrained(model_name, config=config)
-        self.classification_head = nn.Linear(config.hidden_size, num_labels)
-
-    def forward(self, input_ids):
-        outputs = self.model(input_ids)
-        pooled_output = outputs.last_hidden_state[:, 0]
-        logits = self.classification_head(pooled_output)
-        probabilities = F.softmax(logits, dim=-1)
-
-        return probabilities
-
+def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model.to(device)
+
+    num_training_steps = num_epochs * len(train_loader)
+    progress_bar = tqdm(range(num_training_steps))
+
+    train_losses = []
+    valid_losses = []
+    train_accuracies = []
+    valid_accuracies = []
+
+    best_valid_loss = float("inf")
+    epochs_no_improve = 0
+    best_model = None
+
+    for epoch in range(num_epochs):
+        model.train()
+        train_loss = 0.0
+        correct_train = 0
+        total_train = 0
+        for batch in train_loader:
+            batch = {k: v.to(device) for k, v in batch.items()}
+            outputs = model(**batch)
+            loss = outputs.loss
+            loss.backward()
+
+            optimizer.step()
+            optimizer.zero_grad()
+            progress_bar.update(1)
+
+            train_loss += loss.item()
+
+            _, predicted_train = torch.max(outputs.logits, 1)
+            labels_train = batch["labels"]
+            correct_train += (predicted_train == labels_train).sum().item()
+            total_train += labels_train.size(0)
+
+        train_accuracy = correct_train / total_train
+        train_losses.append(train_loss / len(train_loader))
+        train_accuracies.append(train_accuracy)
+
+        model.eval()
+        valid_loss = 0.0
+        correct_valid = 0
+        total_valid = 0
+        with torch.no_grad():
+            for batch in val_loader:
+                batch = {k: v.to(device) for k, v in batch.items()}
+                outputs = model(**batch)
+                loss = outputs.loss
+                valid_loss += loss.item()
+
+                _, predicted_valid = torch.max(outputs.logits, 1)
+                labels_valid = batch["labels"]
+                correct_valid += (predicted_valid == labels_valid).sum().item()
+                total_valid += labels_valid.size(0)
+
+        valid_loss /= len(val_loader)
+        valid_losses.append(valid_loss)
+
+        valid_accuracy = correct_valid / total_valid
+        valid_accuracies.append(valid_accuracy)
+
+        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}')
+
+        if valid_loss < best_valid_loss:
+            best_valid_loss = valid_loss
+            epochs_no_improve = 0
+            best_model = model.state_dict()
+            torch.save(best_model, "best_model_checkpoint.pth")
+        else:
+            epochs_no_improve += 1
+            if epochs_no_improve == patience:
+                print(f'Early stopping after {epoch+1} epochs with no improvement.')
+                break
+
+    model.load_state_dict(best_model)
+    history = {"train_loss": train_losses,
+               "valid_loss": valid_losses,
+               "train_accuracies": train_accuracies,
+               "valid_accuracies": valid_accuracies}
+
+    return model, history
+
+
 def main():
     model_name = "moussaKam/AraBART"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    config = AutoConfig.from_pretrained(model_name)
 
-    dataset = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
-    num_labels = len(set(dataset["train"]["label"]))
-    model = Model(model_name, config, num_labels)
-
-    print(dataset["train"])
+    dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
+
+    for data in dataset:
+        dataset[data] = dataset[data].remove_columns(["tweet"])
+        dataset[data] = dataset[data].rename_column("label", "labels")
+        dataset[data].set_format("torch")
 
+    train_loader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
+    val_loader = DataLoader(dataset["val"], batch_size=8)
+    test_loader = DataLoader(dataset["test"], batch_size=8)
+
+    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)
+    for param in model.parameters():
+        param.requires_grad = False  # We don't retrain the pretrained model due to lack of GPU
+    for param in model.classification_head.parameters():
+        param.requires_grad = True
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
+    num_epochs = 100
+
+    model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
+
 
 if __name__ == "__main__":
     main()
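The training code leans on `utils.get_dataset`, which is not part of this diff. From how it is called, it must return a `DatasetDict` with "train", "val", and "test" splits, whose rows carry a raw `tweet` column, an integer-encoded `label` column, and the tokenizer outputs, plus the fitted `LabelEncoder`. Below is a minimal sketch of that contract; the TSV parsing, the 90/10 validation split, the use of the dev file as the test split, and the fixed-length padding are all assumptions for illustration, not the actual `utils.py`.

# Hypothetical sketch only -- the real get_dataset lives in utils.py and is
# not shown in this commit. Split sizes, padding strategy, and the
# dev-file-as-test mapping are assumptions.
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


def get_dataset(train_file, dev_file, tokenizer):
    train_df = pd.read_csv(train_file, sep="\t")
    dev_df = pd.read_csv(dev_file, sep="\t")

    # Integer-encode the dialect labels (main() expects 21 classes).
    label_encoder = LabelEncoder()
    train_df["label"] = label_encoder.fit_transform(train_df["label"])
    dev_df["label"] = label_encoder.transform(dev_df["label"])

    # Carve a validation split out of the training data; treat the labeled
    # dev file as the held-out test split (assumed 90/10 split).
    train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

    def tokenize(batch):
        # Fixed-length padding so the default DataLoader collation can
        # stack the tensors without a custom collator.
        return tokenizer(batch["tweet"], truncation=True, padding="max_length")

    dataset = DatasetDict({
        name: Dataset.from_pandas(df, preserve_index=False).map(tokenize, batched=True)
        for name, df in {"train": train_df, "val": val_df, "test": dev_df}.items()
    })
    return dataset, label_encoder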
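One note on the freezing logic in main(): `model.classification_head` is the head attribute exposed by `BartForSequenceClassification`, which AraBART loads as, and AdamW is constructed over all of `model.parameters()`, frozen ones included. That works, because the optimizer skips parameters whose gradient stays `None`, but restricting it to the trainable parameters makes the intent explicit:

# Equivalent but more explicit: optimize only the unfrozen head.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),  # classification head only
    lr=2e-5,
)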
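matplotlib is imported but not yet used in this commit; presumably it is meant for plotting the curves that `train_model` returns in `history`. A minimal sketch of such a helper (the name `plot_history` is hypothetical), using only the keys the function actually populates:

import matplotlib.pyplot as plt


def plot_history(history):
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))

    # Loss curves per epoch.
    loss_ax.plot(history["train_loss"], label="train")
    loss_ax.plot(history["valid_loss"], label="valid")
    loss_ax.set_xlabel("epoch")
    loss_ax.set_ylabel("loss")
    loss_ax.legend()

    # Accuracy curves per epoch.
    acc_ax.plot(history["train_accuracies"], label="train")
    acc_ax.plot(history["valid_accuracies"], label="valid")
    acc_ax.set_xlabel("epoch")
    acc_ax.set_ylabel("accuracy")
    acc_ax.legend()

    fig.tight_layout()
    plt.show()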