# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
"""Postprocessors class to transform MDETR output according to the downstream task"""
from typing import Dict
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import pycocotools.mask as mask_util
from util import box_ops


class A2DSentencesPostProcess(nn.Module):
"""
This module converts the model's output into the format expected by the coco api for the given task
"""
def __init__(self, threshold=0.5):
super().__init__()
self.threshold = threshold

    @torch.no_grad()
def forward(self, outputs, orig_target_sizes, max_target_sizes):
""" Perform the computation
Parameters:
outputs: raw outputs of the model
orig_target_sizes: original size of the samples (no augmentations or padding)
            max_target_sizes: size of the samples (as input to the model) after size augmentation,
                              but before padding. NOTE: the padded input size is 4x out_masks.shape[-2:]
"""
assert len(orig_target_sizes) == len(max_target_sizes)
        # there is only one valid frame, thus T=1
out_logits = outputs['pred_logits'][:, 0, :, 0] # [B, T, N, 1] -> [B, N]
out_masks = outputs['pred_masks'][:, 0, :, :, :] # [B, T, N, out_h, out_w] -> [B, N, out_h, out_w]
out_h, out_w = out_masks.shape[-2:]
scores = out_logits.sigmoid()
pred_masks = F.interpolate(out_masks, size=(out_h*4, out_w*4), mode="bilinear", align_corners=False) # [B, N, H, W]
        pred_masks = pred_masks.sigmoid() > self.threshold  # binarize, [B, N, H, W]
processed_pred_masks, rle_masks = [], []
# for each batch
for f_pred_masks, resized_size, orig_size in zip(pred_masks, max_target_sizes, orig_target_sizes):
f_mask_h, f_mask_w = resized_size # resized shape without padding
            f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, :f_mask_w].unsqueeze(1)  # remove the samples' padding, [N, 1, h, w]
            # resize the samples back to their original dataset (target) size for evaluation
            f_pred_masks_processed = F.interpolate(f_pred_masks_no_pad.float(), size=tuple(orig_size.tolist()), mode="nearest")  # original size, [N, 1, h, w]
f_pred_rle_masks = [mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
for mask in f_pred_masks_processed.cpu()]
processed_pred_masks.append(f_pred_masks_processed)
rle_masks.append(f_pred_rle_masks)
predictions = [{'scores': s, 'masks': m, 'rle_masks': rle}
for s, m, rle in zip(scores, processed_pred_masks, rle_masks)]
return predictions
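
# A minimal smoke-test sketch for A2DSentencesPostProcess. All shapes and sizes
# below are assumptions inferred from the shape comments in forward() ([B, T, N, 1]
# logits, [B, T, N, out_h, out_w] masks, T=1); the concrete numbers are made up.
def _demo_a2d_postprocess():
    B, N, out_h, out_w = 2, 5, 8, 8
    outputs = {
        'pred_logits': torch.randn(B, 1, N, 1),
        'pred_masks': torch.randn(B, 1, N, out_h, out_w),
    }
    max_sizes = torch.tensor([[30, 32], [28, 30]])       # resized input sizes, without padding
    orig_sizes = torch.tensor([[240, 320], [180, 240]])  # original frame sizes
    preds = A2DSentencesPostProcess(threshold=0.5)(outputs, orig_sizes, max_sizes)
    assert len(preds) == B
    assert preds[0]['scores'].shape == (N,)  # one confidence score per query
    assert len(preds[0]['rle_masks']) == N   # one COCO RLE dict per query
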
# PostProcess for pretraining
class PostProcess(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""

    @torch.no_grad()
def forward(self, outputs, target_sizes):
"""Perform the computation
Parameters:
outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image in the batch.
                          For evaluation, this must be the original image size (before any data augmentation).
                          For visualization, this should be the image size after data augmentation, but before padding.
        Returns:
            a list of dicts, one per batch element, each with 'scores', 'labels' and 'boxes' keys
        """
        out_logits = outputs["pred_logits"]
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2
        # coco pretraining, num_frames=1: fold the time dimension into the query dimension
out_logits = outputs["pred_logits"].flatten(1, 2)
out_boxes = outputs["pred_boxes"].flatten(1, 2)
bs, num_queries = out_logits.shape[:2]
prob = out_logits.sigmoid() # [bs, num_queries, num_classes]
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True)
scores = topk_values # [bs, num_queries]
topk_boxes = topk_indexes // out_logits.shape[2] # [bs, num_queries]
labels = topk_indexes % out_logits.shape[2] # [bs, num_queries]
boxes = box_ops.box_cxcywh_to_xyxy(out_boxes) # [bs, num_queries, 4]
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
# and from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :] # [bs, num_queries, 4]
assert len(scores) == len(labels) == len(boxes)
        # binary labels for pretraining: all kept detections get label 1
        results = [{"scores": s, "labels": torch.ones_like(lab), "boxes": b} for s, lab, b in zip(scores, labels, boxes)]
return results
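
# A minimal smoke-test sketch for PostProcess, assuming the pretraining output
# layout [B, T, N, C] for logits and [B, T, N, 4] for relative cxcywh boxes with
# T=1. The dummy sizes are illustrative, not taken from any real model config.
def _demo_postprocess():
    B, N, C = 2, 10, 5
    outputs = {
        "pred_logits": torch.randn(B, 1, N, C),
        "pred_boxes": torch.rand(B, 1, N, 4),  # relative cxcywh in [0, 1]
    }
    target_sizes = torch.tensor([[480, 640], [360, 480]])  # (h, w) per image
    results = PostProcess()(outputs, target_sizes)
    assert len(results) == B
    assert results[0]["boxes"].shape == (N, 4)   # absolute xyxy coordinates
    assert torch.all(results[0]["labels"] == 1)  # binary labels for pretraining
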
class PostProcessSegm(nn.Module):
"""Similar to PostProcess but for segmentation masks.
This processor is to be called sequentially after PostProcess.
Args:
threshold: threshold that will be applied to binarize the segmentation masks.
"""
def __init__(self, threshold=0.5):
super().__init__()
        self.threshold = threshold

    @torch.no_grad()
def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
"""Perform the computation
Parameters:
            results: already post-processed boxes (the output of PostProcess); masks are added to these dicts in place
outputs: raw outputs of the model
            orig_target_sizes: tensor of dimension [batch_size x 2] containing the size of each image in the batch.
                               For evaluation, this must be the original image size (before any data augmentation).
                               For visualization, this should be the image size after data augmentation, but before padding.
            max_target_sizes: tensor of dimension [batch_size x 2] containing the size of each image in the batch
                              after data augmentation.
"""
assert len(orig_target_sizes) == len(max_target_sizes)
out_logits = outputs["pred_logits"].flatten(1, 2)
out_masks = outputs["pred_masks"].flatten(1, 2)
bs, num_queries = out_logits.shape[:2]
prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True)
scores = topk_values # [bs, num_queries]
topk_boxes = topk_indexes // out_logits.shape[2] # [bs, num_queries]
labels = topk_indexes % out_logits.shape[2] # [bs, num_queries]
        outputs_masks = [out_m[topk_boxes[i]].unsqueeze(0) for i, out_m in enumerate(out_masks)]  # list[Tensor]
outputs_masks = torch.cat(outputs_masks, dim=0) # [bs, num_queries, H, W]
out_h, out_w = outputs_masks.shape[-2:]
        # alternative (disabled): resize to the batch-wise max padded size instead of a fixed 4x upsampling
        # max_h, max_w = max_target_sizes.max(0)[0].tolist()
        # outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
        # upsample by 4x to recover the (stride-4) padded input resolution
        outputs_masks = F.interpolate(outputs_masks, size=(out_h * 4, out_w * 4), mode="bilinear", align_corners=False)
outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
        for i, (cur_mask, padded_size, orig_size) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
            img_h, img_w = padded_size[0], padded_size[1]
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)  # strip the padding, [N, 1, h, w]
            results[i]["masks"] = F.interpolate(
                results[i]["masks"].float(), size=tuple(orig_size.tolist()), mode="nearest"
            ).byte()
results[i]["rle_masks"] = [mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
for mask in results[i]["masks"].cpu()]
return results
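
# A hypothetical end-to-end sketch chaining PostProcess and PostProcessSegm, in the
# sequential order the class docstring prescribes. Shapes follow the pretraining
# layout with T=1 and stride-4 masks; every concrete number is made up.
def _demo_postprocess_segm():
    B, N, C, out_h, out_w = 2, 10, 5, 8, 8
    outputs = {
        "pred_logits": torch.randn(B, 1, N, C),
        "pred_boxes": torch.rand(B, 1, N, 4),
        "pred_masks": torch.randn(B, 1, N, out_h, out_w),
    }
    orig_sizes = torch.tensor([[480, 640], [360, 480]])  # original image sizes
    max_sizes = torch.tensor([[30, 32], [28, 30]])       # augmented sizes before padding (<= 4x mask resolution)
    results = PostProcess()(outputs, orig_sizes)
    results = PostProcessSegm()(results, outputs, orig_sizes, max_sizes)
    assert results[0]["masks"].shape == (N, 1, 480, 640)  # binarized masks at original resolution
    assert len(results[0]["rle_masks"]) == N
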
def build_postprocessors(args, dataset_name):
    if dataset_name in ('a2d', 'jhmdb'):
        # NOTE: this branch returns a single module, while the coco branch below returns a dict
        postprocessors = A2DSentencesPostProcess(threshold=args.threshold)
else:
# for coco pretrain postprocessor
postprocessors: Dict[str, nn.Module] = {"bbox": PostProcess()}
if args.masks:
postprocessors["segm"] = PostProcessSegm(threshold=args.threshold)
return postprocessors
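

if __name__ == "__main__":
    # Smoke-test the sketches above on random inputs. The argparse-style Namespace
    # is a stand-in for the real training args (assumed to carry `threshold` and
    # `masks`, as build_postprocessors requires); it is not the repo's actual CLI.
    from argparse import Namespace

    _demo_a2d_postprocess()
    _demo_postprocess()
    _demo_postprocess_segm()

    args = Namespace(threshold=0.5, masks=True)
    assert isinstance(build_postprocessors(args, 'a2d'), A2DSentencesPostProcess)
    assert "segm" in build_postprocessors(args, 'coco')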