Spaces:
Build error
Build error
import os | |
import glob | |
import torch | |
from os.path import join | |
import numpy as np | |
from froc_by_pranjal import file_to_bbox, calc_froc_from_dict, pretty_print_fps | |
import sys | |
from ensemble_boxes import * | |
import json | |
import pickle | |
def get_file_id(x):
    """Return the second ``'_'``-separated field of *x*."""
    fields = x.split('_')
    return fields[1]


def get_acr_cat(x):
    """Return ``acr_cat[x]``, or ``'0'`` when *x* is not a key of ``acr_cat``.

    ``acr_cat`` is read from module scope at call time; it is only assigned
    by the script section of this file.
    """
    if x in acr_cat:
        return acr_cat[x]
    return '0'


# Maps a category letter to the numeric suffix used in 'AIIMS_C<n>' paths.
cat_to_idx = {
    'a': 1,
    'b': 2,
    'c': 3,
    'd': 4,
}
def _pad_predictions(pred, size):
    """Return *pred* as a 2-D array zero-padded at the bottom to *size* rows.

    An empty prediction list is first replaced by a single all-zero row of
    five values so that it can be stacked with the padding.

    Raises:
        ValueError: if *pred* already holds more than *size* rows.
    """
    arr = np.array(pred)
    if arr.shape == (0,):
        arr = np.array([[0., 0., 0., 0., 0.]])
    if len(arr) > size:
        raise ValueError(
            f'cannot pad {len(arr)} predictions to {size} rows'
        )
    return np.vstack((arr, np.zeros((size - len(arr), 5))))


def get_image_dict(dataset_paths, labels=('mal', 'ben'), allowed=None,
                   USE_ACR=False, acr_cat=None, mp_dict=None):
    """Collect ground truth and per-dataset predictions for each shared image.

    Args:
        dataset_paths: Path templates; ``template.format(label)`` must name a
            directory of prediction files. A sibling ``gt`` directory (next
            to the prediction directory) is read for ground truth.
        labels: Labels substituted into every template. Ground truth is only
            read for the ``'mal'`` label; other labels get an empty list.
        allowed: Indices into *dataset_paths* whose predictions are used.
            ``None`` or an empty sequence means "all of them". Note that the
            set of images is still the intersection over *every* template.
        USE_ACR: When true, templates containing ``'AIIMS_C'`` are only used
            for images whose category maps (via ``cat_to_idx``) to that
            template's numeric suffix, and are skipped for category ``'0'``.
        acr_cat: Mapping from file id (see ``get_file_id``) to category.
            When ``None`` the module-level ``get_acr_cat`` helper is used.
        mp_dict: Optional mapping ``substring -> callable``. The first entry
            whose substring occurs in a template post-processes the
            predictions loaded for that template.

    Returns:
        dict mapping the file name minus its last four characters to
        ``{'gt': ..., 'preds': [array, ...]}`` where every array is padded
        with zero rows to exactly 100 rows.

    Raises:
        AssertionError: if the ground truths loaded for one image differ.
        ValueError: if a dataset yields more than 100 predictions.
    """
    max_preds = 100
    image_dict = dict()
    if not dataset_paths:
        # set.intersection(*[]) below would raise; nothing to collect anyway.
        return image_dict

    if allowed is None or len(allowed) == 0:
        allowed = range(len(dataset_paths))
    allowed = set(allowed)
    post_processors = mp_dict if mp_dict is not None else {}

    for label in labels:
        listings = [set(os.listdir(dset.format(label))) for dset in dataset_paths]
        # Sorted so that the insertion order of image_dict is reproducible.
        images = sorted(set.intersection(*listings))
        for image in images:
            acr = None
            if USE_ACR:
                file_id = get_file_id(image)
                if acr_cat is not None:
                    acr = acr_cat[file_id] if file_id in acr_cat else '0'
                else:
                    acr = get_acr_cat(file_id)

            # Strips the last four characters (presumably a '.txt' suffix).
            key = image[:-4]
            gts = []
            preds = []
            for i, dset in enumerate(dataset_paths):
                if i not in allowed:
                    continue
                if USE_ACR and 'AIIMS_C' in dset:
                    if acr == '0':
                        continue
                    if f'AIIMS_C{cat_to_idx[acr]}' not in dset:
                        continue
                    # Switch to the category-specific split.
                    # NOTE(review): the original nesting of this rewrite was
                    # lost; it is applied to 'AIIMS_C' templates only here --
                    # confirm it should not apply to every template.
                    dset = dset.replace('/test', f'/test_{acr}')

                label_dir = dset.format(label)
                pred_file = join(label_dir, key + '.txt')
                gt_file = join(os.path.split(label_dir)[0], 'gt', key + '.txt')

                if label == 'mal':
                    gts.append(file_to_bbox(gt_file))
                else:
                    gts.append([])

                boxes = file_to_bbox(pred_file)
                # Matching is done against the unmodified template.
                for pattern, post_process in post_processors.items():
                    if pattern in dataset_paths[i]:
                        boxes = post_process(boxes)
                        break
                preds.append(boxes)

            # Every dataset must agree on the ground truth of this image.
            reference_gt = gts[0]
            for other_gt in gts[1:]:
                assert other_gt == reference_gt, (
                    f'ground truth mismatch for {key}'
                )

            image_dict[key] = {
                'gt': reference_gt,
                'preds': [_pad_predictions(p, max_preds) for p in preds],
            }
    return image_dict
def apply_merge(image_dict, METHOD='wbf', weights=None, conf_type=None):
    """Fuse the per-model predictions of every image into a single list.

    For each entry of *image_dict* the value under ``'preds'`` (one 2-D
    array per model; column 0 is used as the score and the remaining
    columns as the box) is replaced in place by the fused result, a list of
    ``[score, b0, b1, b2, b3]`` rows.

    Args:
        image_dict: Mapping ``key -> {'preds': [...], ...}``; mutated in place.
        METHOD: ``'wbf'`` selects ``weighted_boxes_fusion``. Any other value
            (including ``'nms'``) selects ``non_maximum_weighted``.
        weights: Per-model weights passed to the fusion function. ``None``
            means a weight of 1 for each model of the image being fused.
        conf_type: Forwarded to ``weighted_boxes_fusion`` only, and only
            when it is not ``None``.

    Returns:
        The same *image_dict* object.
    """
    # Boxes are divided by FACTOR before fusion and multiplied back after it
    # (presumably to map pixel coordinates into the [0, 1] range that the
    # fusion functions work with -- TODO confirm coordinates stay below it).
    FACTOR = 5000
    use_wbf = METHOD == 'wbf'
    fusion_func = weighted_boxes_fusion if use_wbf else non_maximum_weighted

    for entry in image_dict.values():
        preds = np.array(entry['preds'])
        if len(preds) == 0:
            continue

        boxes_list = [pred[:, 1:] / FACTOR for pred in preds]
        scores_list = [pred[:, 0] for pred in preds]
        # Single-class problem: every box gets label 0.
        labels_list = [[0.] * len(pred) for pred in preds]

        # Resolve default weights per image instead of overwriting the
        # argument, so the length always matches this image's model count.
        if weights is None:
            model_weights = [1] * len(preds)
        else:
            model_weights = weights

        extra = {}
        if use_wbf and conf_type is not None:
            extra['conf_type'] = conf_type

        boxes, scores, _ = fusion_func(
            boxes_list, scores_list, labels_list,
            weights=model_weights, iou_thr=0.5, **extra
        )
        entry['preds'] = [
            [score, FACTOR * box[0], FACTOR * box[1], FACTOR * box[2], FACTOR * box[3]]
            for score, box in zip(scores, boxes)
        ]
    return image_dict
def manipulate_preds(preds):
    """Identity post-processing hook: hand *preds* back unchanged."""
    unchanged = preds
    return unchanged
def manipulate_preds_4(preds):
    """Identity post-processing hook: hand *preds* back unchanged."""
    unchanged = preds
    return unchanged
# Module-level integer; not read or updated in the visible part of this file.
tot = 0


def manipulate_preds_t1(preds, threshold=0.6):
    """Keep only the predictions whose first element exceeds *threshold*.

    Args:
        preds: Iterable of indexable rows; ``row[0]`` is compared against
            *threshold* (the rest of this file treats it as the score).
        threshold: Exclusive lower bound. Defaults to 0.6.

    Returns:
        A new list with the surviving rows in their original order.
    """
    return [pred for pred in preds if pred[0] > threshold]
def manipulate_preds_t2(preds):
    """Post-process *preds* exactly as ``manipulate_preds_t1`` does."""
    filtered = manipulate_preds_t1(preds)
    return filtered
if __name__ == '__main__':
    # Merge the predictions of several detectors and report FROC numbers.
    USE_ACR = False
    dataset_paths = [
        'MammoDatasets/AIIMS_C1/test/{0}/preds_frcnn_AIIMS_C1',
        'MammoDatasets/AIIMS_C2/test/{0}/preds_frcnn_AIIMS_C2',
        'MammoDatasets/AIIMS_C3/test/{0}/preds_frcnn_AIIMS_C3',
        'MammoDatasets/AIIMS_C4/test/{0}/preds_frcnn_AIIMS_C4',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_bilateral_BILATERAL',
        'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_frcnn_16',
    ]

    # Every dataset takes part in the merge. NOTE: a window parsed from
    # sys.argv[1:3] used to be computed here but was overwritten right away
    # by the full index list, so command-line arguments are not required.
    allowed = list(range(len(dataset_paths)))

    OUT_FILE = 'contrast_frcnn.txt'
    if OUT_FILE is not None:
        fol = os.path.split(OUT_FILE)[0]
        if fol != '':
            os.makedirs(fol, exist_ok=True)

    # Must stay a module-level name: get_acr_cat reads `acr_cat` from module
    # scope.
    with open('aiims_categories.json', 'r') as acr_file:
        acr_cat = json.load(acr_file)
    print(allowed)

    # Substring of a dataset path -> post-processing applied to its boxes.
    mp_dict = {
        'preds_frcnn_AIIMS_C3': manipulate_preds,
        'preds_frcnn_AIIMS_C4': manipulate_preds_4,
        'AIIMS_T2': manipulate_preds_t2,
        'AIIMS_T1': manipulate_preds_t1,
    }

    image_dict = get_image_dict(
        dataset_paths,
        allowed=allowed,
        USE_ACR=USE_ACR,
        acr_cat=acr_cat,
        mp_dict=mp_dict,
    )
    # apply_merge uses weighted_boxes_fusion for 'wbf' and
    # non_maximum_weighted for every other value, including 'nms'.
    image_dict = apply_merge(image_dict, METHOD='nms')

    if OUT_FILE:
        with open(OUT_FILE.replace('.txt', '.pkl'), 'wb') as pkl_file:
            pickle.dump(image_dict, pkl_file)

    senses, fps = calc_froc_from_dict(
        image_dict,
        fps_req=[0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 1.],
        save_to=OUT_FILE,
    )
    pretty_print_fps(senses, fps)