from . import SentenceEvaluator
import torch
from torch.utils.data import DataLoader
import logging
from ..util import batch_to_device
import os
import csv


logger = logging.getLogger(__name__)


class LabelAccuracyEvaluator(SentenceEvaluator):
    """
    Evaluate a model based on its accuracy on a labeled dataset

    This requires a model with LossFunction.SOFTMAX

    The results are written in a CSV. If a CSV already exists, then values are appended.
    """

    def __init__(self, dataloader: DataLoader, name: str = "", softmax_model = None, write_csv: bool = True):
        """
        Constructs an evaluator for the given dataset

        :param dataloader: the data for the evaluation
        :param name: optional dataset name; used in the log message and, prefixed
            with an underscore, in the CSV file name
        :param softmax_model: callable invoked as ``softmax_model(features, labels=None)``;
            the second element of its return value is used as the prediction scores
        :param write_csv: whether results are written to a CSV file when an
            output path is passed to ``__call__``
        """
        self.dataloader = dataloader
        self.name = name
        self.softmax_model = softmax_model

        # Only the file name gets the "_" separator; self.name stays as given.
        if name:
            name = "_" + name

        self.write_csv = write_csv
        self.csv_file = "accuracy_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps", "accuracy"]

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """
        Compute the classification accuracy over the evaluator's dataloader.

        :param model: the model under evaluation; it is switched to eval mode and
            provides ``smart_batching_collate`` and ``device``
        :param output_path: directory for the CSV results file; nothing is written
            when this is None or ``write_csv`` is False
        :param epoch: epoch number for logging and the CSV row (-1 means not given)
        :param steps: step number for logging and the CSV row (-1 means not given)
        :return: the fraction of correctly predicted labels, or 0.0 when the
            dataloader yields no samples
        """
        model.eval()
        total = 0
        correct = 0

        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logger.info("Evaluation on the " + self.name + " dataset" + out_txt)

        self.dataloader.collate_fn = model.smart_batching_collate
        for features, label_ids in self.dataloader:
            features = [batch_to_device(feature, model.device) for feature in features]
            label_ids = label_ids.to(model.device)
            with torch.no_grad():
                _, prediction = self.softmax_model(features, labels=None)

            total += prediction.size(0)
            correct += torch.argmax(prediction, dim=1).eq(label_ids).sum().item()

        if total == 0:
            # An empty dataloader would otherwise raise ZeroDivisionError below.
            logger.warning("No samples were evaluated; reporting an accuracy of 0.0")
            accuracy = 0.0
        else:
            accuracy = correct / total

        logger.info("Accuracy: {:.4f} ({}/{})\n".format(accuracy, correct, total))

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            # A new file is created with a header row; an existing one is appended to.
            is_new_file = not os.path.isfile(csv_path)
            with open(csv_path, newline='', mode="w" if is_new_file else "a", encoding="utf-8") as f:
                writer = csv.writer(f)
                if is_new_file:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, accuracy])

        return accuracy