File size: 1,476 Bytes
bf1f674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1da2a5a
 
 
bf1f674
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""Reusable metrics functions for evaluating models
"""
import multiprocessing as mp
from typing import List

import torch
from torch.utils.data import DataLoader
from transformers import default_data_collator
from tqdm import tqdm

def get_predictions(
    model: torch.nn.Module,
    dataset: torch.utils.data.Dataset,
    batch_size: int = 64,
) -> List[List[float]]:
    """Compute the model's logits for every example in `dataset`.

    The model is moved to the GPU when CUDA is available (CPU otherwise) and
    switched to evaluation mode; it is left on that device and in that mode
    when the function returns.

    Args:
        model (torch.nn.Module): Model to evaluate. It is called with
            `attention_mask` and `input_ids` keyword arguments, and the first
            element of its output is taken as the logits.
        dataset (torch.utils.data.Dataset): Dataset to get predictions for.
            After collation with `default_data_collator`, every batch must
            contain `attention_mask` and `input_ids` tensors.
        batch_size (int, optional): Number of examples per forward pass.
            Defaults to 64.

    Returns:
        List[List[float]]: One list per row of logits, in the order the
            loader yields them, with every value rounded to 4 decimal places.
            Labels are not returned.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=default_data_collator,
        drop_last=False,  # keep the final partial batch so no example is skipped
        num_workers=mp.cpu_count(),
    )
    input_keys = ("attention_mask", "input_ids")
    pred_labels: List[List[float]] = []
    for batch in tqdm(loader):
        inputs = {key: batch[key].to(device) for key in input_keys}
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs)
        del inputs  # to free up space on GPU
        # Move the logits to host memory and convert to plain Python floats.
        batch_logits = outputs[0].cpu().tolist()
        pred_labels.extend([round(value, 4) for value in row] for row in batch_logits)

    return pred_labels