"""Reusable metric functions for evaluating models."""

import multiprocessing as mp
from typing import List

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import default_data_collator

def get_predictions(
    model: torch.nn.Module,
    dataset: torch.utils.data.Dataset,
) -> List[List[float]]:
    """Compute model predictions for `dataset`.

    Args:
        model (torch.nn.Module): Model to evaluate.
        dataset (torch.utils.data.Dataset): Dataset to get predictions for. Each item
            must provide `input_ids` and `attention_mask`.

    Returns:
        List[List[float]]: Predicted logits for each example, rounded to 4 decimal places.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    loader = DataLoader(
        dataset,
        batch_size=64,
        collate_fn=default_data_collator,
        drop_last=False,
        num_workers=mp.cpu_count(),
    )

    pred_labels = []
    for batch in tqdm(loader):
        # Move only the model inputs to the target device.
        inputs = {k: batch[k].to(device) for k in ["attention_mask", "input_ids"]}
        with torch.no_grad():
            outputs = model(**inputs)
        del inputs
        # The first element of the model output holds the logits.
        logits = outputs[0].cpu().tolist()
        for row in logits:
            pred_labels.append([round(e, 4) for e in row])

    return pred_labels
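

class _ExampleTextDataset(torch.utils.data.Dataset):
    """Illustrative helper (an assumption, not part of the original module):
    wraps pre-tokenized encodings so each item is a dict with the `input_ids`
    and `attention_mask` keys that `get_predictions` expects."""

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        return {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}


if __name__ == "__main__":
    # Minimal usage sketch, not a definitive workflow: the checkpoint name and
    # example texts below are placeholders; any Hugging Face
    # sequence-classification checkpoint and tokenized text dataset should work.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "distilbert-base-uncased-finetuned-sst-2-english"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    clf = AutoModelForSequenceClassification.from_pretrained(model_name)

    texts = ["This module works well.", "The predictions look off."]
    encodings = tokenizer(texts, padding=True, truncation=True)

    preds = get_predictions(clf, _ExampleTextDataset(encodings))
    print(preds)  # one list of rounded logits per input text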