# File: VRIS_vip/my_datasets/refexp_eval.py
# Uploaded by dianecy ("Add files using upload-large-folder tool"), revision 3b5fc39, 3.25 kB.
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
import copy
from collections import defaultdict
from pathlib import Path
import torch
import torch.utils.data
import util.misc as utils
from util.box_ops import generalized_box_iou
class RefExpEvaluator(object):
    """Precision@k evaluator for referring-expression box predictions.

    For every image in the ground truth, the predicted boxes are ranked by
    score and the image counts as a hit at cut-off ``k`` when any of the top
    ``k`` boxes overlaps the single target box with an overlap of at least
    ``thresh_iou``. The overlap is whatever ``generalized_box_iou`` from
    ``util.box_ops`` returns. Hits are averaged per dataset ("refcoco",
    "refcoco+", "refcocog").

    NOTE(review): ``refexp_gt`` is used like a COCO-style API object (it must
    expose ``imgs``, ``getAnnIds``, ``loadImgs`` and ``loadAnns``) - confirm
    against the caller.
    """

    def __init__(self, refexp_gt, iou_types, k=(1, 5, 10), thresh_iou=0.5):
        """Store the ground truth and the evaluation settings.

        Args:
            refexp_gt: Ground-truth annotation object; it is deep-copied.
            iou_types: Stored as-is; not used by any method of this class.
            k: List or tuple of cut-offs for precision@k.
            thresh_iou: Minimum overlap for a predicted box to count as a hit.
        """
        assert isinstance(k, (list, tuple))
        # Evaluate against a private copy so later changes made by the caller
        # to the ground-truth object cannot affect the results.
        self.refexp_gt = copy.deepcopy(refexp_gt)
        self.iou_types = iou_types
        self.img_ids = self.refexp_gt.imgs.keys()
        self.predictions = {}
        self.k = k
        self.thresh_iou = thresh_iou

    def accumulate(self):
        """Do nothing; presumably kept so the evaluator API matches others."""
        pass

    def update(self, predictions):
        """Merge a mapping of ``image_id -> prediction`` into the stored ones.

        Each prediction is later indexed with ``"scores"`` and ``"boxes"`` by
        ``summarize``. An existing entry for the same image id is overwritten.
        """
        self.predictions.update(predictions)

    def synchronize_between_processes(self):
        """Replace the local predictions with the union gathered via ``utils.all_gather``."""
        merged_predictions = {}
        for process_predictions in utils.all_gather(self.predictions):
            merged_predictions.update(process_predictions)
        self.predictions = merged_predictions

    def summarize(self):
        """Compute, print and return precision@k for each dataset.

        Returns:
            On the main process, a dict mapping each dataset name to a list of
            precision values ordered by ascending ``k``. ``None`` on every
            other process.

        Raises:
            KeyError: If an image has no stored prediction, or its
                ``dataset_name`` is not one of the three known datasets.
            AssertionError: If an image does not have exactly one annotation.
        """
        if not utils.is_main_process():
            return None

        dataset_names = ("refcoco", "refcoco+", "refcocog")
        # De-duplicate the cut-offs so a repeated k is neither counted nor
        # normalised twice; ascending order also fixes the result order.
        cutoffs = sorted(set(self.k))
        dataset2score = {name: {k: 0.0 for k in cutoffs} for name in dataset_names}
        dataset2count = {name: 0.0 for name in dataset_names}

        for image_id in self.img_ids:
            ann_ids = self.refexp_gt.getAnnIds(imgIds=image_id)
            assert len(ann_ids) == 1
            img_info = self.refexp_gt.loadImgs(image_id)[0]
            target = self.refexp_gt.loadAnns(ann_ids[0])
            prediction = self.predictions[image_id]
            assert prediction is not None

            # Rank boxes by descending score (ties fall back to comparing the
            # box coordinates, as tuples are compared element-wise).
            ranked = sorted(
                zip(prediction["scores"].tolist(), prediction["boxes"].tolist()),
                reverse=True,
            )
            _, ranked_boxes = zip(*ranked)
            ranked_boxes = torch.as_tensor(ranked_boxes).view(-1, 4)

            # The target is turned from [x, y, w, h] into corner form by adding
            # the origin to the extent. Presumably the predicted boxes are
            # already in the same corner form - verify against the caller.
            x, y, w, h = target[0]["bbox"][:4]
            target_box = torch.as_tensor([x, y, w + x, h + y]).view(-1, 4)

            giou = generalized_box_iou(ranked_boxes, target_box)

            dataset_name = img_info["dataset_name"]
            for k in cutoffs:
                if giou[:k].max().item() >= self.thresh_iou:
                    dataset2score[dataset_name][k] += 1.0
            dataset2count[dataset_name] += 1.0

        results = {}
        k_label = ", ".join(str(k) for k in cutoffs)
        for name, scores in dataset2score.items():
            count = dataset2count[name]
            # A dataset without samples keeps its 0.0 scores instead of
            # dividing by zero.
            if count > 0:
                for k in cutoffs:
                    scores[k] /= count
            results[name] = [scores[k] for k in cutoffs]
            print(f" Dataset: {name} - Precision @ {k_label}: {results[name]} \n")
        return results