# arabic-dialect-classifier/src/model_training.py
import copy

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from utils import get_dataset


def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
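    """Train `model` with per-epoch validation and early stopping.

    Stops once the validation loss has not improved for `patience`
    consecutive epochs, then restores the best checkpoint (lowest
    validation loss) and returns it with a history of losses and accuracies.
    """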
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    num_training_steps = num_epochs * len(train_loader)
    progress_bar = tqdm(range(num_training_steps))

    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []

    best_valid_loss = float("inf")
    epochs_no_improve = 0
    best_model = None
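
    # Each epoch: one optimization pass over train_loader, tracking
    # running loss and accuracy from the classification logits.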
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            progress_bar.update(1)

            train_loss += loss.item()
            _, predicted_train = torch.max(outputs.logits, 1)
            labels_train = batch["labels"]
            correct_train += (predicted_train == labels_train).sum().item()
            total_train += labels_train.size(0)

        train_accuracy = correct_train / total_train
        train_losses.append(train_loss / len(train_loader))
        train_accuracies.append(train_accuracy)
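
        # Validation pass: no gradient tracking, same loss/accuracy bookkeeping.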
        model.eval()
        valid_loss = 0.0
        correct_valid = 0
        total_valid = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss

                valid_loss += loss.item()
                _, predicted_valid = torch.max(outputs.logits, 1)
                labels_valid = batch["labels"]
                correct_valid += (predicted_valid == labels_valid).sum().item()
                total_valid += labels_valid.size(0)

        valid_loss /= len(val_loader)
        valid_losses.append(valid_loss)
        valid_accuracy = correct_valid / total_valid
        valid_accuracies.append(valid_accuracy)

        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.4f}, '
              f'Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}')
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            epochs_no_improve = 0
            # deepcopy: state_dict() returns live references, so without a
            # copy later training steps would overwrite the "best" weights.
            best_model = copy.deepcopy(model.state_dict())
            torch.save(best_model, "best_model_checkpoint.pth")
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f'Early stopping after {epoch+1} epochs with no improvement.')
                break

    model.load_state_dict(best_model)
    history = {"train_loss": train_losses,
               "valid_loss": valid_losses,
               "train_accuracy": train_accuracies,
               "valid_accuracy": valid_accuracies}

    return model, history
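

# Hypothetical helper, not in the original file: a minimal sketch of how the
# `history` dict returned by train_model could be visualised, assuming
# matplotlib is installed. The function name and figure layout are illustrative.
def plot_history(history):
    import matplotlib.pyplot as plt

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(history["train_loss"], label="train")
    ax_loss.plot(history["valid_loss"], label="valid")
    ax_loss.set_xlabel("epoch")
    ax_loss.set_ylabel("loss")
    ax_loss.legend()
    ax_acc.plot(history["train_accuracy"], label="train")
    ax_acc.plot(history["valid_accuracy"], label="valid")
    ax_acc.set_xlabel("epoch")
    ax_acc.set_ylabel("accuracy")
    ax_acc.legend()
    fig.tight_layout()
    plt.show()
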
def main():
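    """Fine-tune only the classification head of AraBART on the labeled
    dialect data (21 classes), keeping the pretrained weights frozen."""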
    model_name = "moussaKam/AraBART"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)

    for split in dataset:
        dataset[split] = dataset[split].remove_columns(["tweet"])
        dataset[split] = dataset[split].rename_column("label", "labels")
        dataset[split].set_format("torch")

    train_loader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
    val_loader = DataLoader(dataset["val"], batch_size=8)
    test_loader = DataLoader(dataset["test"], batch_size=8)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)

    # Freeze the pretrained backbone (no GPU available for full fine-tuning);
    # only the classification head is trained.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.classification_head.parameters():
        param.requires_grad = True

    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)

    num_epochs = 100
    model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
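
    # Illustrative follow-up using the hypothetical helper defined above:
    # plot_history(history)
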
if __name__ == "__main__":
    main()