Spaces:
Runtime error
Runtime error
File size: 921 Bytes
24615d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
def binary_accuracy(y_pred, y_true):
    """Return the fraction of samples whose thresholded prediction matches the label.

    Args:
        y_pred: 1-D tensor of raw scores; a sample counts as predicted
            positive when ``torch.sigmoid(score) > 0.5``.
        y_true: 1-D tensor of labels with the same size as ``y_pred``; each
            label is compared element-wise against the boolean prediction.

    Returns:
        float: number of matching elements divided by ``y_true.size(0)``.
        An empty input raises ``ZeroDivisionError``.
    """
    assert y_true.ndim == 1 and y_true.size() == y_pred.size()
    predicted = torch.sigmoid(y_pred).gt(0.5)
    n_correct = predicted.eq(y_true).float().sum().item()
    return n_correct / y_true.size(0)
def precision(y_pred, y_true):
    """Return the precision of thresholded predictions: TP / predicted positives.

    Args:
        y_pred: 1-D tensor of raw scores; a sample counts as predicted
            positive when ``torch.sigmoid(score) > 0.5``.
        y_true: 1-D tensor of labels with the same size as ``y_pred``; a
            sample counts as an actual positive when its label is ``>= 1``.

    Returns:
        float: true positives divided by the number of predicted positives,
        or ``0.0`` when there are no true positives. That case includes
        "nothing predicted positive", so the division below never divides
        by zero.
    """
    assert y_true.ndim == 1 and y_true.size() == y_pred.size()
    y_prob = torch.sigmoid(y_pred) > 0.5
    y_positive = y_true >= 1
    # Both masks are boolean; `&` is an element-wise logical AND.
    tp = (y_positive & y_prob).float().sum().item()
    # Precision's denominator is what the model *predicted* positive
    # (dividing by the actual positives would give recall instead).
    n_pred_positive = y_prob.float().sum().item()
    if tp == 0:
        return 0.0
    return tp / n_pred_positive
def recall(y_pred, y_true):
    """Return the recall of thresholded predictions: TP / actual positives.

    Args:
        y_pred: 1-D tensor of raw scores; a sample counts as predicted
            positive when ``torch.sigmoid(score) > 0.5``.
        y_true: 1-D tensor of labels with the same size as ``y_pred``; a
            sample counts as an actual positive when its label is ``>= 1``.

    Returns:
        float: true positives divided by the number of actual positives,
        or ``0.0`` when there are no true positives. That case includes
        "no actual positives", so the division below never divides by zero.
    """
    assert y_true.ndim == 1 and y_true.size() == y_pred.size()
    y_prob = torch.sigmoid(y_pred) > 0.5
    y_positive = y_true >= 1
    # Both masks are boolean; `&` is an element-wise logical AND.
    tp = (y_positive & y_prob).float().sum().item()
    # Recall's denominator is the number of *actual* positives
    # (dividing by the predicted positives would give precision instead).
    n_positive = y_positive.float().sum().item()
    if tp == 0:
        return 0.0
    return tp / n_positive