"""Reusable metrics functions for evaluating models."""
import multiprocessing as mp
from typing import List

import torch
from torch.utils.data import DataLoader
from transformers import default_data_collator
from tqdm import tqdm


def get_predictions(
    model: torch.nn.Module,
    dataset: torch.utils.data.Dataset,
) -> List[List[float]]:
    """Compute model predictions for `dataset`.

    Runs the model in eval mode (no gradients) over `dataset` in batches
    of 64, on GPU if available, and collects the raw output logits.

    Args:
        model (torch.nn.Module): Model to evaluate. Must accept
            `input_ids` and `attention_mask` keyword tensors and return
            logits as the first element of its output.
        dataset (torch.utils.data.Dataset): Dataset to get predictions
            for. Each item must be collatable by
            `transformers.default_data_collator` and provide
            `input_ids` and `attention_mask` fields.

    Returns:
        List[List[float]]: One row per example; each row is the model's
        output logits for that example, rounded to 4 decimal places.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    loader = DataLoader(
        dataset,
        batch_size=64,
        collate_fn=default_data_collator,
        drop_last=False,  # keep the final partial batch so every example is scored
        num_workers=mp.cpu_count(),
    )

    pred_labels: List[List[float]] = []
    for batch in tqdm(loader):
        # Only the model inputs are moved to the device; any label columns
        # produced by the collator are intentionally ignored.
        inputs = {k: batch[k].to(device) for k in ["attention_mask", "input_ids"]}
        with torch.no_grad():
            outputs = model(**inputs)
        del inputs  # free GPU memory before processing the next batch

        # First element of the model output is assumed to be the logits
        # (e.g. a HuggingFace `*Output` tuple) — TODO confirm for custom models.
        logits = outputs[0].cpu().tolist()
        pred_labels.extend([round(value, 4) for value in row] for row in logits)

    return pred_labels