|
|
|
import copy |
|
from collections import defaultdict |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.utils.data |
|
|
|
import util.misc as utils |
|
from util.box_ops import generalized_box_iou |
|
|
|
|
|
class RefExpEvaluator(object):
    """Referring-expression evaluator reporting Precision@k per sub-dataset.

    Each ground-truth image must carry exactly one annotation (asserted in
    :meth:`summarize`). Predictions are collected per image with
    :meth:`update`, optionally merged across processes with
    :meth:`synchronize_between_processes`, and scored on the main process by
    :meth:`summarize`.

    Args:
        refexp_gt: Ground-truth object exposing a COCO-style API (``imgs``,
            ``getAnnIds``, ``loadImgs``, ``loadAnns``) -- presumably a
            ``pycocotools.coco.COCO`` instance; confirm against callers.
            It is deep-copied so later mutation by the caller has no effect.
        iou_types: Stored as ``self.iou_types``; not read anywhere in this class.
        k: Top-k cutoffs at which precision is reported; must be a list or tuple.
        thresh_iou: Minimum overlap for a predicted box to count as correct.
    """

    # Sub-datasets that always appear in the summary, even when no image of
    # that sub-dataset is evaluated (their precision is then reported as 0.0).
    _DEFAULT_DATASETS = ("refcoco", "refcoco+", "refcocog")

    def __init__(self, refexp_gt, iou_types, k=(1, 5, 10), thresh_iou=0.5):
        assert isinstance(k, (list, tuple))
        self.refexp_gt = copy.deepcopy(refexp_gt)
        self.iou_types = iou_types
        self.img_ids = self.refexp_gt.imgs.keys()
        # Maps image id -> prediction dict with "scores" and "boxes" tensors.
        self.predictions = {}
        self.k = k
        self.thresh_iou = thresh_iou

    def accumulate(self):
        """No-op; present so the evaluator matches the interface of other evaluators."""
        pass

    def update(self, predictions):
        """Merge ``predictions`` (image id -> prediction dict) into the stored ones.

        An entry for an image id that is already stored overwrites the old entry.
        """
        self.predictions.update(predictions)

    def synchronize_between_processes(self):
        """Gather predictions from every process and merge them into one dict.

        On duplicate image ids, the entry from the process that appears later
        in the gathered list wins.
        """
        all_predictions = utils.all_gather(self.predictions)
        merged_predictions = {}
        for p in all_predictions:
            merged_predictions.update(p)
        self.predictions = merged_predictions

    @staticmethod
    def _xywh_to_xyxy(bbox):
        """Convert an ``[x, y, w, h]`` box to ``[x1, y1, x2, y2]``."""
        return [bbox[0], bbox[1], bbox[2] + bbox[0], bbox[3] + bbox[1]]

    @staticmethod
    def _rank_boxes(prediction):
        """Return predicted boxes as an ``(N, 4)`` tensor in descending score order.

        Ties in score are broken by comparing the box coordinates, exactly as
        sorting ``(score, box)`` tuples does. Returns ``None`` when the
        prediction contains no boxes.
        """
        ranked = sorted(
            zip(prediction["scores"].tolist(), prediction["boxes"].tolist()), reverse=True
        )
        if not ranked:
            return None
        return torch.as_tensor([box for _, box in ranked]).view(-1, 4)

    @staticmethod
    def _hits_at_k(ious, ks, thresh):
        """Return ``{k: hit}``; ``hit`` is True when any of the first ``k`` rows of
        ``ious`` (boxes already in rank order) reaches ``thresh``."""
        return {k: bool(ious[:k].max() >= thresh) for k in ks}

    def summarize(self):
        """Compute and print Precision@k for each sub-dataset (main process only).

        An image is a hit at cutoff ``k`` when any of its ``k`` highest-scoring
        predicted boxes overlaps the single ground-truth box by at least
        ``self.thresh_iou``. An image with no predicted boxes is a miss at
        every cutoff.

        Returns:
            On the main process, a dict mapping sub-dataset name to a list of
            precisions ordered by ascending ``k``. ``None`` on other processes.
        """
        if not utils.is_main_process():
            return None

        dataset2score = {name: {k: 0.0 for k in self.k} for name in self._DEFAULT_DATASETS}
        dataset2count = {name: 0.0 for name in self._DEFAULT_DATASETS}

        for image_id in self.img_ids:
            ann_ids = self.refexp_gt.getAnnIds(imgIds=image_id)
            assert len(ann_ids) == 1
            img_info = self.refexp_gt.loadImgs(image_id)[0]
            target = self.refexp_gt.loadAnns(ann_ids[0])
            prediction = self.predictions[image_id]
            assert prediction is not None

            dataset_name = img_info["dataset_name"]
            # Unknown sub-datasets get their own bucket instead of raising KeyError.
            per_k_hits = dataset2score.setdefault(dataset_name, {k: 0.0 for k in self.k})
            dataset2count[dataset_name] = dataset2count.get(dataset_name, 0.0) + 1.0

            ranked_boxes = self._rank_boxes(prediction)
            if ranked_boxes is None:
                continue  # nothing predicted: counted, but a miss at every k

            # Ground truth is stored as xywh; predictions are compared in xyxy.
            # Cast to the prediction dtype so integer bboxes do not yield an int tensor.
            target_box = torch.as_tensor(
                self._xywh_to_xyxy(target[0]["bbox"]), dtype=ranked_boxes.dtype
            ).view(-1, 4)
            # NOTE(review): this thresholds *generalized* IoU, which never exceeds
            # plain IoU, so it is stricter than the usual "IoU >= 0.5" grounding
            # criterion. Kept as-is so reported numbers do not change -- confirm intent.
            giou = generalized_box_iou(ranked_boxes, target_box)
            for k, hit in self._hits_at_k(giou, self.k, self.thresh_iou).items():
                if hit:
                    per_k_hits[k] += 1.0

        for name, per_k in dataset2score.items():
            count = dataset2count[name]
            if count > 0:  # an empty sub-dataset keeps precision 0.0
                # Iterate the bucket's own keys so a duplicated k is divided only once.
                for k in per_k:
                    per_k[k] /= count

        sorted_ks = sorted(set(self.k))
        cutoffs = ", ".join(str(k) for k in sorted_ks)
        results = {}
        for name, per_k in dataset2score.items():
            # Order by k explicitly rather than sorting the values themselves.
            results[name] = [per_k[k] for k in sorted_ks]
            print(f" Dataset: {name} - Precision @ {cutoffs}: {results[name]} \n")

        return results
|
|
|
|
|
|