File size: 2,833 Bytes
2359bda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from . import SentenceEvaluator
import torch
from torch.utils.data import DataLoader
import logging
from ..util import batch_to_device
import os
import csv
# Module-level logger named after this module, so log output can be filtered per module.
logger = logging.getLogger(__name__)
class LabelAccuracyEvaluator(SentenceEvaluator):
    """
    Evaluate a model based on its accuracy on a labeled dataset

    This requires a model with LossFunction.SOFTMAX

    The results are written in a CSV. If a CSV already exists, then values are appended.
    """

    def __init__(self, dataloader: DataLoader, name: str = "", softmax_model=None, write_csv: bool = True):
        """
        Constructs an evaluator for the given dataset

        :param dataloader:
            the data for the evaluation
        :param name:
            optional dataset name; used in log messages and, prefixed with "_", in the CSV file name
        :param softmax_model:
            callable invoked as ``softmax_model(features, labels=None)``; the second element of its
            return value is used as the prediction scores. It must be set before the evaluator is called.
        :param write_csv:
            if True, results are written to a CSV file whenever an output path is passed to the call
        """
        self.dataloader = dataloader
        self.name = name
        self.softmax_model = softmax_model
        self.write_csv = write_csv

        # Only the file name carries the "_" separator; self.name stays as passed in.
        suffix = "_" + name if name else ""
        self.csv_file = "accuracy_evaluation" + suffix + "_results.csv"
        self.csv_headers = ["epoch", "steps", "accuracy"]

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """
        Computes the label accuracy over the dataloader and optionally appends it to the CSV file

        :param model:
            the model under evaluation; it is switched to eval mode, its ``smart_batching_collate``
            is installed as the dataloader's ``collate_fn`` and its ``device`` receives the batches
        :param output_path:
            directory for the CSV file (created if it does not exist); nothing is written when None
        :param epoch:
            epoch number for logging and the CSV row; -1 means "not inside a training epoch"
        :param steps:
            step number for logging and the CSV row; -1 means "at the end of the epoch"
        :return:
            fraction of correctly predicted labels, or 0.0 if the dataloader yielded no samples
        """
        model.eval()
        total = 0
        correct = 0

        if epoch == -1:
            out_txt = ":"
        elif steps == -1:
            out_txt = " after epoch {}:".format(epoch)
        else:
            out_txt = " in epoch {} after {} steps:".format(epoch, steps)

        logger.info("Evaluation on the %s dataset%s", self.name, out_txt)

        # NOTE: this replaces the collate function on the dataloader passed to the constructor.
        self.dataloader.collate_fn = model.smart_batching_collate
        for features, label_ids in self.dataloader:
            features = [batch_to_device(feature, model.device) for feature in features]
            label_ids = label_ids.to(model.device)
            with torch.no_grad():
                _, prediction = self.softmax_model(features, labels=None)

            total += prediction.size(0)
            # The predicted label is the index of the highest score along dim 1.
            correct += torch.argmax(prediction, dim=1).eq(label_ids).sum().item()

        if total == 0:
            # An empty dataloader would otherwise raise ZeroDivisionError below.
            logger.warning("The %s dataset yielded no samples; reporting an accuracy of 0.0", self.name)
            accuracy = 0.0
        else:
            accuracy = correct / total

        logger.info("Accuracy: %.4f (%d/%d)\n", accuracy, correct, total)

        if output_path is not None and self.write_csv:
            # An empty output_path means "current directory"; os.makedirs("") would raise.
            if output_path:
                os.makedirs(output_path, exist_ok=True)
            csv_path = os.path.join(output_path, self.csv_file)
            is_new_file = not os.path.isfile(csv_path)
            with open(csv_path, newline='', mode="w" if is_new_file else "a", encoding="utf-8") as f:
                writer = csv.writer(f)
                if is_new_file:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, accuracy])

        return accuracy
|