import numpy as np from sklearn.metrics import auc, roc_auc_score, precision_recall_curve, average_precision_score def rescale(x): return (x - x.min()) / (x.max() - x.min()) def is_one_class(gt: np.ndarray): gt_ravel = gt.ravel() return gt_ravel.sum() == 0 or gt_ravel.sum() == gt_ravel.shape[0] def calculate_px_metrics(gt_px, pr_px): if is_one_class(gt_px): # In case there are only normal pixels or no pixel-level labels return 0, 0, 0 auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel()) precisions, recalls, _ = precision_recall_curve(gt_px.ravel(), pr_px.ravel()) f1_scores = (2 * precisions * recalls) / (precisions + recalls) f1_px = np.max(f1_scores[np.isfinite(f1_scores)]) ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel()) return auroc_px * 100, f1_px * 100, ap_px * 100 def calculate_im_metrics(gt_im, pr_im): if is_one_class(gt_im): # In case there are only normal samples or no image-level labels return 0, 0, 0 auroc_im = roc_auc_score(gt_im.ravel(), pr_im.ravel()) precisions, recalls, _ = precision_recall_curve(gt_im.ravel(), pr_im.ravel()) f1_scores = (2 * precisions * recalls) / (precisions + recalls) f1_im = np.max(f1_scores[np.isfinite(f1_scores)]) ap_im = average_precision_score(gt_im, pr_im) return ap_im * 100, auroc_im * 100, f1_im * 100 def calculate_average_metric(metrics: dict): average = {} for obj, metric in metrics.items(): for k, v in metric.items(): if k not in average: average[k] = [] average[k].append(v) for k, v in average.items(): average[k] = np.mean(v) return average def calculate_metric(results, obj): gt_px = [] pr_px = [] gt_im = [] pr_im = [] for idx in range(len(results['cls_names'])): if results['cls_names'][idx] == obj: gt_px.append(results['imgs_masks'][idx]) pr_px.append(results['anomaly_maps'][idx]) gt_im.append(results['imgs_gts'][idx]) pr_im.append(results['anomaly_scores'][idx]) gt_px = np.array(gt_px) pr_px = np.array(pr_px) gt_im = np.array(gt_im) pr_im = np.array(pr_im) auroc_px, f1_px, ap_px = calculate_px_metrics(gt_px, pr_px) ap_im, auroc_im, f1_im = calculate_im_metrics(gt_im, pr_im) metric = { 'auroc_px': auroc_px, 'auroc_im': auroc_im, 'f1_px': f1_px, 'f1_im': f1_im, 'ap_px': ap_px, 'ap_im': ap_im, } return metric