import torch
import torch.nn as nn
from torchvision.ops.boxes import batched_nms

from util import box_ops


class CondNMSPostProcess(nn.Module):
    """Per-patch top-k + class-agnostic NMS post-processing.

    The flat list of predictions for each image is split into consecutive
    chunks ("patches") of ``num_queries`` predictions. For each patch the
    ``topk`` highest-scoring predictions are kept. NMS is then run at IoU
    ``nms_iou_threshold``, and at most ``max_dets_per_patch`` survivors are
    kept. Survivors are tagged with that patch's name and (start, end) span.
    """

    def __init__(self, num_queries, topk=100, nms_iou_threshold=0.7,
                 max_dets_per_patch=20):
        """
        Args:
            num_queries: number of predictions per patch.
            topk: candidates kept per patch before NMS (clamped to
                ``num_queries`` at run time).
            nms_iou_threshold: IoU threshold passed to ``batched_nms``.
            max_dets_per_patch: maximum detections kept per patch after NMS.
        """
        super().__init__()
        self.num_queries = num_queries
        self.topk = topk
        self.nms_iou_threshold = nms_iou_threshold
        self.max_dets_per_patch = max_dets_per_patch

    @staticmethod
    def _empty_result():
        """Return the result dict used for an image with no usable patches."""
        return {'scores': torch.zeros(0), 'boxes': torch.zeros((0, 4)),
                'names': [], 'start_id': [], 'end_id': []}

    @torch.no_grad()
    def forward(self, outputs, target_sizes, pred_names, mask_infos):
        """Convert raw model outputs into per-image detection lists.

        Args:
            outputs: dict with 'pred_logits' and 'pred_boxes'. Entry ``[b][0]``
                of each holds the flat predictions of image ``b``. The last
                logit column is used as the score. Boxes are in cx, cy, w, h
                format, presumably normalized to [0, 1] (TODO confirm).
            target_sizes: per-image ``(img_h, img_w)`` used to scale boxes.
            pred_names: per-image sequence with one name per patch.
            mask_infos: per-image mapping whose keys are ``(start, end)``
                pairs, one per patch, in the same order as ``pred_names[b]``.

        Returns:
            One dict per image (always ``len(out_logits)`` entries), with
            keys 'scores' (1-D CPU tensor), 'boxes' (N x 4 CPU tensor, xyxy in
            pixels), 'names', 'start_id' and 'end_id' (lists of length N).
            Images with fewer than ``num_queries`` predictions get an empty
            result rather than being skipped.
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        results = []
        for b in range(len(out_logits)):
            names = list(pred_names[b])
            spans = list(mask_infos[b].keys())

            prob = out_logits[b][0][:, -1:].sigmoid()
            num_patch = len(prob) // self.num_queries
            # Check before .view(): an empty tensor cannot be viewed with -1,
            # and torch.cat([]) below would raise for num_patch == 0.
            if num_patch == 0:
                results.append(self._empty_result())
                continue

            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox[b][0])
            img_h, img_w = target_sizes[b]
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=0)
            boxes = boxes * scale_fct[None, :]

            # Raises if the prediction count is not a multiple of num_queries.
            prob = prob.view(num_patch, self.num_queries, -1)
            boxes = boxes.view(num_patch, self.num_queries, -1)

            k = min(self.topk, self.num_queries)
            b_scores, b_boxes, b_names = [], [], []
            b_start_id, b_end_id = [], []
            for t in range(num_patch):
                patch_scores = prob[t, :, 0]
                cand = patch_scores.topk(k).indices
                cand_scores = patch_scores[cand]
                cand_boxes = boxes[t][cand]
                # A single category id makes batched_nms class-agnostic NMS.
                cand_labels = torch.zeros_like(cand_scores, dtype=torch.int64)
                keep = batched_nms(cand_boxes, cand_scores, cand_labels,
                                   self.nms_iou_threshold)
                keep = keep[:self.max_dets_per_patch]

                b_scores.append(cand_scores[keep])
                b_boxes.append(cand_boxes[keep])
                n_kept = len(keep)
                start, end = spans[t]
                # Every detection of a patch shares that patch's name and span.
                b_names += [names[t]] * n_kept
                b_start_id += [start] * n_kept
                b_end_id += [end] * n_kept

            results.append({'scores': torch.cat(b_scores).cpu(),
                            'boxes': torch.cat(b_boxes).cpu(),
                            'names': b_names,
                            'start_id': b_start_id,
                            'end_id': b_end_id})
        return results