# ContextDet-Demo / models / post_process.py
# yuhangzang
# update
# a059c46
import torch
import torch.nn as nn
from torchvision.ops.boxes import batched_nms
from util import box_ops
class CondNMSPostProcess(nn.Module):
    """Convert raw detector outputs into per-image, NMS-filtered detections.

    For every image the flattened predictions are reshaped into
    ``(num_patch, num_queries, ...)`` groups. Group ``t`` is paired with
    ``pred_names[b][t]`` and with the ``t``-th ``(start, end)`` key of
    ``mask_infos[b]``. Within each group the highest-scoring queries are
    kept, passed through NMS, and truncated to a fixed number of boxes.

    Args:
        num_queries: Number of consecutive predictions that form one group.
        pre_nms_topk: Maximum number of top-scoring queries per group that
            are handed to NMS. Clamped to the number of available queries.
        nms_iou_threshold: IoU threshold forwarded to ``batched_nms``.
        max_dets_per_patch: Maximum number of boxes kept per group after NMS.
    """

    def __init__(self, num_queries, pre_nms_topk=100, nms_iou_threshold=0.7,
                 max_dets_per_patch=20):
        super().__init__()
        self.num_queries = num_queries
        self.pre_nms_topk = pre_nms_topk
        self.nms_iou_threshold = nms_iou_threshold
        self.max_dets_per_patch = max_dets_per_patch

    @torch.no_grad()
    def forward(self, outputs, target_sizes, pred_names, mask_infos):
        """Post-process one batch of predictions.

        Args:
            outputs: Mapping with ``'pred_logits'`` and ``'pred_boxes'``.
                For each image ``b`` only element ``[b][0]`` is read; the
                last logit channel is used (after a sigmoid) as the score.
                Boxes are converted with ``box_ops.box_cxcywh_to_xyxy`` and
                multiplied by the image width/height.
                NOTE(review): assumes boxes are normalized cxcywh -- confirm.
            target_sizes: Per-image ``(img_h, img_w)`` pairs; the two values
                must be tensors because they are combined with
                ``torch.stack``.
            pred_names: Per-image sequence of names, one per group.
            mask_infos: Per-image dict whose keys are ``(start, end)`` pairs,
                one per group, in iteration order.

        Returns:
            A list with exactly one dict per image, with keys ``'scores'``
            (1-D CPU tensor), ``'boxes'`` (``(N, 4)`` CPU tensor),
            ``'names'``, ``'start_id'`` and ``'end_id'`` (lists of length
            ``N``). Images without any prediction yield an entry whose
            tensors and lists are empty, so ``results[i]`` always belongs to
            image ``i``.
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        results = []
        for b in range(len(out_logits)):
            names = list(pred_names[b])
            spans = list(mask_infos[b].keys())

            # Only the last logit channel is used as the detection score.
            prob = out_logits[b][0][:, -1:].sigmoid()
            if len(prob) == 0:
                # Keep the output aligned with the batch: emit an empty
                # entry instead of silently dropping this image.
                results.append({
                    'scores': prob.new_zeros((0,)).cpu(),
                    'boxes': prob.new_zeros((0, 4)).cpu(),
                    'names': [],
                    'start_id': [],
                    'end_id': [],
                })
                continue

            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox[b][0])
            img_h, img_w = target_sizes[b]
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=0)
            boxes = boxes * scale_fct[None, :]

            num_patch = len(prob) // self.num_queries
            prob = prob.view(num_patch, self.num_queries, -1)
            boxes = boxes.view(num_patch, self.num_queries, -1)

            b_scores, b_boxes = [], []
            b_names, b_start_id, b_end_id = [], [], []
            for t in range(num_patch):
                patch_scores = prob[t].squeeze(1)
                # topk raises when k exceeds the number of candidates.
                k = min(self.pre_nms_topk, patch_scores.numel())
                ind = patch_scores.topk(k).indices
                prob_prenms = prob[t][ind]
                box_prenms = boxes[t][ind]

                # One shared group id: NMS runs across all boxes of the group.
                group_ids = torch.zeros_like(prob_prenms[:, 0],
                                             dtype=torch.long)
                keep = batched_nms(
                    box_prenms, prob_prenms[:, 0], group_ids,
                    self.nms_iou_threshold,
                )[:self.max_dets_per_patch]

                b_scores.append(prob_prenms[keep])
                b_boxes.append(box_prenms[keep])

                # Every kept box of group t shares the same name and span,
                # so replicate them instead of indexing per element.
                num_kept = len(keep)
                start, end = spans[t]
                b_names += [names[t]] * num_kept
                b_start_id += [start] * num_kept
                b_end_id += [end] * num_kept

            results.append({
                'scores': torch.cat(b_scores).cpu().squeeze(1),
                'boxes': torch.cat(b_boxes).cpu(),
                'names': b_names,
                'start_id': b_start_id,
                'end_id': b_end_id,
            })
        return results