File size: 921 Bytes
24615d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch

def binary_accuracy(y_pred, y_true):
    """Return the fraction of samples whose thresholded prediction equals the target.

    A sample is predicted positive when ``sigmoid(y_pred) > 0.5``; that boolean
    prediction is compared element-wise against ``y_true``.

    Args:
        y_pred: 1-D tensor of raw scores, same size as ``y_true``.
        y_true: 1-D tensor of targets.

    Returns:
        Python float in [0, 1]: number of matches divided by ``y_true.size(0)``.
    """
    assert y_true.ndim == 1 and y_true.size() == y_pred.size()
    predicted_labels = torch.sigmoid(y_pred) > 0.5
    matches = y_true == predicted_labels
    n_correct = matches.float().sum().item()
    n_samples = y_true.size(0)
    return n_correct / n_samples

def precision(y_pred, y_true):
    """Return precision, TP / (TP + FP), for binary predictions given as raw scores.

    Args:
        y_pred: 1-D tensor of raw scores, same size as ``y_true``. A sample is
            predicted positive when ``sigmoid(y_pred) > 0.5``.
        y_true: 1-D tensor of targets. Entries ``>= 1`` count as positive.

    Returns:
        Fraction of predicted positives that are actually positive, as a
        Python float. Returns 0.0 when there are no true positives. That case
        includes "no predicted positives", so the division below never divides
        by zero.
    """
    assert y_true.ndim == 1 and y_true.size() == y_pred.size()
    pred_positive = torch.sigmoid(y_pred) > 0.5
    actual_positive = y_true >= 1
    tp = (actual_positive & pred_positive).float().sum().item()
    # Precision's denominator is the number of *predicted* positives (TP + FP).
    # Dividing by actual positives here would compute recall instead.
    n_pred_positive = pred_positive.float().sum().item()
    if tp == 0:
        return 0.0
    return tp / n_pred_positive

def recall(y_pred, y_true):
    """Return recall, TP / (TP + FN), for binary predictions given as raw scores.

    Args:
        y_pred: 1-D tensor of raw scores, same size as ``y_true``. A sample is
            predicted positive when ``sigmoid(y_pred) > 0.5``.
        y_true: 1-D tensor of targets. Entries ``>= 1`` count as positive.

    Returns:
        Fraction of actual positives that were predicted positive, as a
        Python float. Returns 0.0 when there are no true positives. That case
        includes "no actual positives", so the division below never divides
        by zero.
    """
    assert y_true.ndim == 1 and y_true.size() == y_pred.size()
    pred_positive = torch.sigmoid(y_pred) > 0.5
    actual_positive = y_true >= 1
    tp = (actual_positive & pred_positive).float().sum().item()
    # Recall's denominator is the number of *actual* positives (TP + FN).
    # Dividing by predicted positives here would compute precision instead.
    n_actual_positive = actual_positive.float().sum().item()
    if tp == 0:
        return 0.0
    return tp / n_actual_positive