File size: 1,476 Bytes
bf1f674 1da2a5a bf1f674 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
"""Reusable metrics functions for evaluating models
"""
import multiprocessing as mp
from typing import List
import torch
from torch.utils.data import DataLoader
from transformers import default_data_collator
from tqdm import tqdm
def get_predictions(
    model: torch.nn.Module,
    dataset: torch.utils.data.Dataset,
    batch_size: int = 64,
) -> List[List[float]]:
    """Compute model predictions (logits) for every example in `dataset`.

    The model is moved to GPU when one is available and put in eval mode.
    Inference runs under `torch.no_grad()`, so no gradients are stored.

    Args:
        model (torch.nn.Module): Model to evaluate. Called as
            `model(input_ids=..., attention_mask=...)`; its first output
            is taken as the logits tensor.
        dataset (torch.utils.data.Dataset): Dataset to get predictions for.
            Each item must be collatable by `default_data_collator` and
            provide "input_ids" and "attention_mask" fields.
        batch_size (int, optional): Inference batch size. Defaults to 64.

    Returns:
        List[List[float]]: One list of logits per example, each value
        rounded to 4 decimal places.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=default_data_collator,
        drop_last=False,  # keep the final partial batch so every example is scored
        num_workers=mp.cpu_count(),
    )
    pred_labels: List[List[float]] = []
    for batch in tqdm(loader):
        # Only forward the tensors the model consumes; ignore labels etc.
        inputs = {k: batch[k].to(device) for k in ["attention_mask", "input_ids"]}
        with torch.no_grad():
            outputs = model(**inputs)
        del inputs  # free GPU memory before the next batch
        # First output is assumed to be the logits tensor (HF convention).
        logits = outputs[0].cpu().tolist()
        pred_labels.extend([round(value, 4) for value in row] for row in logits)
    return pred_labels