Spaces:
Build error
Build error
import os | |
import glob | |
import sys | |
from os.path import join | |
''' | |
Note: Anywhere empty boxes means [] and not [[]] | |
''' | |
def remove_true_positives(gts, preds): | |
def true_positive(gt, pred): | |
# If center of pred is inside the gt, it is a true positive | |
c_pred = ((pred[0]+pred[2])/2., (pred[1]+pred[3])/2.) | |
if (c_pred[0] >= gt[0] and c_pred[0] <= gt[2] and | |
c_pred[1] >= gt[1] and c_pred[1] <= gt[3]): | |
return True | |
return False | |
tps = 0 | |
fns = 0 | |
for gt in gts: | |
# First check if any true positive exists | |
# If more than one exists, do not include it in next set of preds | |
add_tp = False | |
new_preds = [] | |
for pred in preds: | |
if true_positive(gt, pred): | |
add_tp = True | |
else: | |
new_preds.append(pred) | |
preds = new_preds | |
if add_tp: | |
tps += 1 | |
else: | |
fns += 1 | |
return preds, tps, fns | |
def calc_metric_single(gts, preds, threshold,): | |
''' | |
Returns fp, tp, tn, fn | |
''' | |
preds = list(filter(lambda x: x[0] >= threshold, preds)) | |
preds = [pred[1:] for pred in preds] # Remove the scores | |
if len(gts) == 0: | |
return len(preds), 0, 1 if len(preds) == 0 else 0, 0 | |
preds, tps, fns = remove_true_positives(gts, preds) | |
# All remaining will have to fps | |
fps = len(preds) | |
return fps, tps, 0, fns | |
def calc_metrics_at_thresh(im_dict, threshold): | |
''' | |
Returns fp, tp, tn, fn | |
''' | |
fps, tps, tns, fns = 0, 0, 0, 0 | |
for key in im_dict: | |
fp,tp,tn,fn = calc_metric_single(im_dict[key]['gt'], | |
im_dict[key]['preds'], threshold) | |
fps+=fp | |
tps+=tp | |
tns+=tn | |
fns+=fn | |
return fps, tps, tns, fns | |
from joblib import Parallel, delayed | |
def calc_metrics(inp): | |
im_dict, tr = inp | |
out = dict() | |
for t in tr: | |
fp, tp, tn, fn = calc_metrics_at_thresh(im_dict, t) | |
out[t] = [fp, tp, tn, fn] | |
return out | |
def calc_froc_from_dict(im_dict, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3], save_to = None): | |
num_images = len(im_dict) | |
gap = 0.005 | |
n = int(1/gap) | |
thresholds = [i * gap for i in range(n)] | |
fps = [0 for _ in range(n)] | |
tps = [0 for _ in range(n)] | |
tns = [0 for _ in range(n)] | |
fns = [0 for _ in range(n)] | |
for i,t in enumerate(thresholds): | |
fps[i], tps[i], tns[i], fns[i] = calc_metrics_at_thresh(im_dict, t) | |
# Now calculate the sensitivities | |
senses = [] | |
for t,f in zip(tps, fns): | |
try: senses.append(t/(t+f)) | |
except: senses.append(0.) | |
if save_to is not None: | |
f = open(save_to, 'w') | |
for fp,s in zip(fps, senses): | |
f.write(f'{fp/num_images} {s}\n') | |
f.close() | |
senses_req = [] | |
for fp_req in fps_req: | |
for i,f in enumerate(fps): | |
if f/num_images < fp_req: | |
if fp_req == 0.1: | |
print(fps[i], tps[i], tns[i], fns[i]) | |
prec = tps[i]/(tps[i] + fps[i]) | |
recall = tps[i]/(tps[i] + fns[i]) | |
f1 = 2*prec*recall/(prec+recall) | |
spec = tns[i]/ (tns[i] + fps[i]) | |
print(f'Specificity: {spec}') | |
print(f'Precision: {prec}') | |
print(f'Recall: {recall}') | |
print(f'F1: {f1}') | |
senses_req.append(senses[i-1]) | |
break | |
return senses_req, fps_req | |
def file_to_bbox(file_name): | |
try: | |
content = open(file_name, 'r').readlines() | |
st = 0 | |
if len(content) == 0: | |
# Empty File Should Return [] | |
return [] | |
if content[0].split()[0].isalpha(): | |
st = 1 | |
return [[float(x) for x in line.split()[st:]] for line in content] | |
except FileNotFoundError: | |
print(f'No Corresponding Box Found for file {file_name}, using [] as preds') | |
return [] | |
except Exception as e: | |
print('Some Error',e) | |
return [] | |
def generate_image_dict(preds_folder_name='preds_42', | |
root_fol='/home/pranjal/densebreeast_datasets/AIIMS_C1', | |
mal_path=None, ben_path=None, gt_path=None, | |
mal_img_path = None, ben_img_path = None | |
): | |
mal_path = join(root_fol, mal_path) if mal_path else join( | |
root_fol, 'mal', preds_folder_name) | |
ben_path = join(root_fol, ben_path) if ben_path else join( | |
root_fol, 'ben', preds_folder_name) | |
mal_img_path = join(root_fol, mal_img_path) if mal_img_path else join( | |
root_fol, 'mal', 'images') | |
ben_img_path = join(root_fol, ben_img_path) if ben_img_path else join( | |
root_fol, 'ben', 'images') | |
gt_path = join(root_fol, gt_path) if gt_path else join( | |
root_fol, 'mal', 'gt') | |
''' | |
image_dict structure: | |
'image_name(without txt/png)' : {'gt' : [[...]], 'preds' : [[]]} | |
''' | |
image_dict = dict() | |
# GT Might be sightly different from images, therefore we will index gts based on | |
# the images folder instead. | |
for file in os.listdir(mal_img_path): | |
if not file.endswith('.png'): | |
continue | |
file = file[:-4] + '.txt' | |
file = join(gt_path, file) | |
key = os.path.split(file)[-1][:-4] | |
image_dict[key] = dict() | |
image_dict[key]['gt'] = file_to_bbox(file) | |
image_dict[key]['preds'] = [] | |
for file in glob.glob(join(mal_path, '*.txt')): | |
key = os.path.split(file)[-1][:-4] | |
assert key in image_dict | |
image_dict[key]['preds'] = file_to_bbox(file) | |
for file in os.listdir(ben_img_path): | |
if not file.endswith('.png'): | |
continue | |
file = file[:-4] + '.txt' | |
file = join(ben_path, file) | |
key = os.path.split(file)[-1][:-4] | |
if key == 'Calc-Test_P_00353_LEFT_CC' or key == 'Calc-Training_P_00600_LEFT_CC': # Corrupt Files in Dataset | |
continue | |
if key in image_dict: | |
print(key) | |
# assert key not in image_dict | |
if key in image_dict: | |
print(f'Unexpected Error. {key} exists in multiple splits') | |
continue | |
image_dict[key] = dict() | |
image_dict[key]['preds'] = file_to_bbox(file) | |
image_dict[key]['gt'] = [] | |
return image_dict | |
def pretty_print_fps(senses,fps): | |
for s,f in zip(senses,fps): | |
print(f'Sensitivty at {f}: {s}') | |
def get_froc_points(preds_image_folder, root_fol, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3], save_to = None): | |
im_dict = generate_image_dict(preds_image_folder, root_fol = root_fol) | |
# print(im_dict) | |
print(len(im_dict)) | |
senses, fps = calc_froc_from_dict(im_dict, fps_req, save_to = save_to) | |
return senses, fps | |
if __name__ == '__main__': | |
seed = '42' if len(sys.argv)== 1 else sys.argv[1] | |
root_fol = '../bilateral_new/MammoDatasets/AIIMS_highres_reliable/test_2' | |
if len(sys.argv) <= 2: | |
save_to = None | |
else: | |
save_to = sys.argv[2] | |
senses, fps = get_froc_points(f'preds_{seed}',root_fol, save_to = save_to) | |
pretty_print_fps(senses, fps) | |