import copy

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from utils import get_dataset, serialize_data, plot_training_history, get_model_accuracy

def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
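    """Fine-tune `model` with early stopping on validation loss.

    Returns the model restored to its best checkpoint, along with a history
    dict of per-epoch losses and accuracies.
    """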
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    num_training_steps = num_epochs * len(train_loader)
    progress_bar = tqdm(range(num_training_steps))
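
    # Per-epoch metric history for plotting, plus early-stopping bookkeeping.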
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []
    best_valid_loss = float("inf")
    epochs_no_improve = 0
    best_model = None
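
    # One training pass followed by one validation pass per epoch.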
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0
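        # AutoModelForSequenceClassification computes the cross-entropy loss
        # itself whenever the batch includes a "labels" key.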
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            progress_bar.update(1)

            train_loss += loss.item()
            _, predicted_train = torch.max(outputs.logits, 1)
            labels_train = batch["labels"]
            correct_train += (predicted_train == labels_train).sum().item()
            total_train += labels_train.size(0)
        train_accuracy = correct_train / total_train
        train_losses.append(train_loss / len(train_loader))
        train_accuracies.append(train_accuracy)
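
        # Validation pass: eval mode plus no_grad disables dropout and skips
        # gradient bookkeeping.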
        model.eval()
        valid_loss = 0.0
        correct_valid = 0
        total_valid = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                valid_loss += loss.item()
                _, predicted_valid = torch.max(outputs.logits, 1)
                labels_valid = batch["labels"]
                correct_valid += (predicted_valid == labels_valid).sum().item()
                total_valid += labels_valid.size(0)
        valid_loss /= len(val_loader)
        valid_losses.append(valid_loss)
        valid_accuracy = correct_valid / total_valid
        valid_accuracies.append(valid_accuracy)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}')
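
        # Early stopping: snapshot the best weights and stop once the
        # validation loss has not improved for `patience` consecutive epochs.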
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live tensors, so deep-copy
            # the snapshot; otherwise later epochs would overwrite it.
            best_model = copy.deepcopy(model.state_dict())
            torch.save(best_model, "../models/best_model_checkpoint.pth")
        else:
            epochs_no_improve += 1
        if epochs_no_improve == patience:
            print(f'Early stopping at epoch {epoch+1}: no improvement for {patience} epochs.')
            break
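
    # Restore the best snapshot before returning.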
    model.load_state_dict(best_model)
    history = {"train_loss": train_losses,
               "valid_loss": valid_losses,
               "train_accuracies": train_accuracies,
               "valid_accuracies": valid_accuracies}
    return model, history


def main():
    model_name = "moussaKam/AraBART"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
    serialize_data(label_encoder, "../models/label_encoder.pkl")
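
    # Drop the raw text column and rename "label" to "labels", the argument
    # name the model's forward() expects; set_format yields PyTorch tensors.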
    for split in dataset:
        dataset[split] = dataset[split].remove_columns(["tweet"])
        dataset[split] = dataset[split].rename_column("label", "labels")
        dataset[split].set_format("torch")
    train_loader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
    val_loader = DataLoader(dataset["val"], batch_size=8)
    test_loader = DataLoader(dataset["test"], batch_size=8)
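
    # Load AraBART with a fresh 21-way classification head, then freeze the
    # pretrained backbone so only the head is trained (limited GPU budget).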
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)
    for param in model.parameters():
        param.requires_grad = False  # freeze the pretrained weights
    for param in model.classification_head.parameters():
        param.requires_grad = True  # train only the classification head
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.0001)
    num_epochs = 100
    model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
    plot_training_history(history)
    test_accuracy = get_model_accuracy(model, test_loader)
    print("The accuracy of the model on the test set is:", test_accuracy)


if __name__ == "__main__":
    main()