txt2audio's picture
update
fa25a07
import logging
import numpy as np
import scipy
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
logger = logging.getLogger(f'main.{__name__}')
def metrics(targets, outputs, topk=(1, 5)):
    """
    Calculate classification statistics: accuracy@k, mAP, mean ROC-AUC and d-prime.

    Adapted from https://github.com/hche11/VGGSound/blob/master/utils.py

    Args:
        targets: 1d integer tensor of class indices, shape (dataset_size, ).
        outputs: 2d tensor of scores before softmax, shape (dataset_size, classes_num).
        topk: tuple of k values for which accuracy@k is reported. A k larger than
            classes_num is allowed; every class is then inside the top-k, so the
            reported accuracy@k is 1.0.

    Returns:
        metrics_dict: a dict with one 'accuracy_{k}' entry per k in ``topk`` plus
            'mAP', 'mROCAUC' and 'dprime'. If ``roc_auc_score`` raises ValueError
            for some class (e.g. a class never occurs in ``targets``), a warning is
            logged and the sentinels mAP=0, mROCAUC=0.5 (hence dprime=0) are returned.
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.stats` is loaded on every SciPy version.
    from scipy.stats import norm

    metrics_dict = dict()
    # Work on CPU copies without autograd history so `.numpy()` below works for
    # CUDA tensors and for outputs that require grad. `one_hot` needs int64 indices.
    targets = targets.detach().cpu().long().reshape(-1)
    outputs = outputs.detach().cpu()
    num_cls = outputs.shape[-1]

    # accuracy@k: a single topk call at the largest k, then slice for smaller k.
    # Clamp k so topk does not fail when a requested k exceeds the number of classes.
    max_k = min(max(topk), num_cls)
    _, preds = torch.topk(outputs, k=max_k, dim=1)
    correct_for_maxtopk = preds == targets.view(-1, 1).expand_as(preds)
    for k in topk:
        metrics_dict[f'accuracy_{k}'] = float(correct_for_maxtopk[:, :k].sum() / correct_for_maxtopk.shape[0])

    # One-vs-rest metrics: one-hot ground truth vs. softmax class probabilities.
    targets_onehot = torch.nn.functional.one_hot(targets, num_classes=num_cls).numpy()
    targets_pred = torch.softmax(outputs, dim=1).numpy()
    avg_p = [average_precision_score(targets_onehot[:, c], targets_pred[:, c], average=None) for c in range(num_cls)]
    try:
        roc_aucs = [roc_auc_score(targets_onehot[:, c], targets_pred[:, c], average=None) for c in range(num_cls)]
    except ValueError:
        # roc_auc_score is undefined when a class column holds a single label value
        # (e.g. the class never occurs); fall back to chance-level sentinels.
        logger.warning('Weird... Some classes never occured in targets. Do not trust the metrics.')
        roc_aucs = np.array([0.5])
        avg_p = np.array([0])
    metrics_dict['mAP'] = np.mean(avg_p)
    metrics_dict['mROCAUC'] = np.mean(roc_aucs)
    # d' = sqrt(2) * Phi^{-1}(AUC), where norm.ppf is the inverse standard-normal CDF.
    metrics_dict['dprime'] = norm.ppf(metrics_dict['mROCAUC']) * np.sqrt(2)
    return metrics_dict
if __name__ == '__main__':
    # Quick sanity run: 6 samples over 4 classes, reporting accuracy@1 and accuracy@3.
    targets = torch.tensor([3, 3, 1, 2, 1, 0])
    logit_rows = [
        [1.2, 1.3, 1.1, 1.5],
        [1.3, 1.4, 1.0, 1.1],
        [1.5, 1.1, 1.4, 1.3],
        [1.0, 1.2, 1.4, 1.5],
        [1.2, 1.3, 1.1, 1.1],
        [1.2, 1.1, 1.1, 1.1],
    ]
    outputs = torch.tensor(logit_rows, dtype=torch.float32)
    metrics_dict = metrics(targets, outputs, topk=(1, 3))
    print(metrics_dict)