# arabic-dialect-classifier/src/model_training.py
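"""Fine-tune a pretrained AraBART model to classify the dialect of Arabic tweets.

Only the classification head is trained (the backbone stays frozen), with
early stopping on validation loss.
"""
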
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from utils import get_dataset, serialize_data, plot_training_history, get_model_accuracy


def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
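    """Fine-tune `model` with early stopping on validation loss.

    Tracks per-epoch loss and accuracy for both splits, checkpoints the best
    weights to ../models/best_model_checkpoint.pth, and stops after `patience`
    epochs without improvement. Returns the best model and a history dict.
    """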
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
num_training_steps = num_epochs * len(train_loader)
progress_bar = tqdm(range(num_training_steps))
train_losses = []
valid_losses = []
train_accuracies = []
valid_accuracies = []
best_valid_loss = float("inf")
epochs_no_improve = 0
best_model = None
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
correct_train = 0
total_train = 0
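        # Training phase: one optimization step per batch (forward, backward, update)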
for batch in train_loader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
progress_bar.update(1)
train_loss += loss.item()
_, predicted_train = torch.max(outputs.logits, 1)
labels_train = batch["labels"]
correct_train += (predicted_train == labels_train).sum().item()
total_train += labels_train.size(0)
train_accuracy = correct_train / total_train
train_losses.append(train_loss / len(train_loader))
train_accuracies.append(train_accuracy)
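        # Validation phase: score the held-out split without gradient tracking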
model.eval()
valid_loss = 0.0
correct_valid = 0
total_valid = 0
with torch.no_grad():
for batch in val_loader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
valid_loss += loss.item()
_, predicted_valid = torch.max(outputs.logits, 1)
labels_valid = batch["labels"]
correct_valid += (predicted_valid == labels_valid).sum().item()
total_valid += labels_valid.size(0)
valid_loss /= len(val_loader)
valid_losses.append(valid_loss)
valid_accuracy = correct_valid / total_valid
valid_accuracies.append(valid_accuracy)
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}')
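        # Early stopping: checkpoint on improvement, otherwise count stalled epochs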
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
epochs_no_improve = 0
best_model = model.state_dict()
torch.save(best_model, "../models/best_model_checkpoint.pth")
else:
epochs_no_improve += 1
if epochs_no_improve == patience:
print(f'Early stopping after {epoch+1} epochs with no improvement.')
break
model.load_state_dict(best_model)
history = {"train_loss": train_losses,
"valid_loss": valid_losses,
"train_accuracies": train_accuracies,
"valid_accuracies": valid_accuracies}
return model, history


def main():
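    """Train the classification head of AraBART for Arabic dialect identification."""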
model_name = "moussaKam/AraBART"
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
serialize_data(label_encoder, "../models/label_encoder.pkl")
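    # Persist the fitted label encoder so inference can map class ids back to labels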
    for split in dataset:
        dataset[split] = dataset[split].remove_columns(["tweet"])
        dataset[split] = dataset[split].rename_column("label", "labels")
        dataset[split].set_format("torch")
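    # Batch the tokenized splits; shuffle only the training data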
train_loader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
val_loader = DataLoader(dataset["val"], batch_size=8)
test_loader = DataLoader(dataset["test"], batch_size=8)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)
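    # AraBART with a newly initialized classification head producing 21 dialect logits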
    for param in model.parameters():
        param.requires_grad = False  # Freeze the pretrained backbone: no GPU available for full fine-tuning
for param in model.classification_head.parameters():
param.requires_grad = True
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.0001)  # only the unfrozen head is updated
num_epochs = 100
model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
plot_training_history(history)
test_accuracy = get_model_accuracy(model, test_loader)
print("The accuracy of the model on the test set is:", test_accuracy)
if __name__ == "__main__":
main()