# grit/modeling/roi_heads/grit_roi_heads.py
import logging
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from transformers import BertTokenizer

from detectron2.config import configurable
from detectron2.structures import Boxes, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
from detectron2.modeling.poolers import ROIPooler
from detectron2.layers import batched_nms

from grit.data.custom_dataset_mapper import ObjDescription
from .grit_fast_rcnn import GRiTFastRCNNOutputLayers
from ..text.text_decoder import (
    TransformerDecoderTextualHead,
    GRiTTextDecoder,
    AutoRegressiveBeamSearch,
)
from ..text.load_text_token import LoadTextTokens
from ..soft_nms import batched_soft_nms

logger = logging.getLogger(__name__)

@ROI_HEADS_REGISTRY.register()
class GRiTROIHeadsAndTextDecoder(CascadeROIHeads):
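    """
    ROI heads for GRiT: a cascade R-CNN box head whose detected boxes are passed
    to an autoregressive text decoder that generates an object description for
    every prediction (for example, a class name when the task is "ObjectDet").
    """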
@configurable
def __init__(
self,
*,
text_decoder_transformer,
train_task: list,
test_task: str,
mult_proposal_score: bool = False,
mask_weight: float = 1.0,
object_feat_pooler=None,
soft_nms_enabled=False,
beam_size=1,
**kwargs,
):
super().__init__(**kwargs)
self.mult_proposal_score = mult_proposal_score
self.mask_weight = mask_weight
self.object_feat_pooler = object_feat_pooler
self.soft_nms_enabled = soft_nms_enabled
self.test_task = test_task
self.beam_size = beam_size
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
self.tokenizer = tokenizer
        assert test_task in train_task, (
            'GRiT has not been trained on the {} task; please verify the task name '
            'or train a new GRiT model on the {} task'.format(test_task, test_task)
        )
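        # One begin-of-sequence token per task: the first task reuses BERT's
        # [CLS] token, and each additional task gets the id 103 + i, i.e. ids
        # just above [MASK] (103) that fall in BERT's unused vocabulary range.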
task_begin_tokens = {}
for i, task in enumerate(train_task):
if i == 0:
task_begin_tokens[task] = tokenizer.cls_token_id
else:
task_begin_tokens[task] = 103 + i
self.task_begin_tokens = task_begin_tokens
beamsearch_decode = AutoRegressiveBeamSearch(
end_token_id=tokenizer.sep_token_id,
max_steps=40,
beam_size=beam_size,
objectdet=test_task == "ObjectDet",
per_node_beam_size=1,
)
self.text_decoder = GRiTTextDecoder(
text_decoder_transformer,
beamsearch_decode=beamsearch_decode,
begin_token_id=task_begin_tokens[test_task],
loss_type='smooth',
tokenizer=tokenizer,
)
self.get_target_text_tokens = LoadTextTokens(tokenizer, max_text_len=40, padding='do_not_pad')
@classmethod
def from_config(cls, cfg, input_shape):
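        """
        Build the text decoder transformer from the cfg.TEXT_DECODER options and
        add the GRiT-specific entries (train/test tasks, proposal-score
        multiplication, mask weight, Soft-NMS flag, beam size) on top of
        CascadeROIHeads.from_config.
        """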
ret = super().from_config(cfg, input_shape)
text_decoder_transformer = TransformerDecoderTextualHead(
object_feature_size=cfg.MODEL.FPN.OUT_CHANNELS,
vocab_size=cfg.TEXT_DECODER.VOCAB_SIZE,
hidden_size=cfg.TEXT_DECODER.HIDDEN_SIZE,
num_layers=cfg.TEXT_DECODER.NUM_LAYERS,
attention_heads=cfg.TEXT_DECODER.ATTENTION_HEADS,
feedforward_size=cfg.TEXT_DECODER.FEEDFORWARD_SIZE,
mask_future_positions=True,
padding_idx=0,
decoder_type='bert_en',
use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
)
ret.update({
'text_decoder_transformer': text_decoder_transformer,
'train_task': cfg.MODEL.TRAIN_TASK,
'test_task': cfg.MODEL.TEST_TASK,
'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
'mask_weight': cfg.MODEL.ROI_HEADS.MASK_WEIGHT,
'soft_nms_enabled': cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED,
'beam_size': cfg.MODEL.BEAM_SIZE,
})
return ret
@classmethod
    def _init_box_head(cls, cfg, input_shape):
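        """
        Build the cascade box heads as in CascadeROIHeads, but swap the box
        predictors for GRiTFastRCNNOutputLayers and add a separate ROIPooler
        (object_feat_pooler) for the per-object features consumed by the text
        decoder.
        """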
ret = super()._init_box_head(cfg, input_shape)
del ret['box_predictors']
cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS
box_predictors = []
        for box_head, bbox_reg_weights in zip(ret['box_heads'], cascade_bbox_reg_weights):
box_predictors.append(
GRiTFastRCNNOutputLayers(
cfg, box_head.output_shape,
box2box_transform=Box2BoxTransform(weights=bbox_reg_weights)
))
ret['box_predictors'] = box_predictors
in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
object_feat_pooler = ROIPooler(
output_size=cfg.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
pooler_type=pooler_type,
)
ret['object_feat_pooler'] = object_feat_pooler
return ret
def check_if_all_background(self, proposals, targets, stage):
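        """
        If every proposal in the batch is labeled background at this cascade stage,
        overwrite the first proposal with the first ground-truth box so that the
        downstream losses (in particular the text decoder) always see at least one
        foreground example.
        """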
all_background = True
for proposals_per_image in proposals:
if not (proposals_per_image.gt_classes == self.num_classes).all():
all_background = False
if all_background:
logger.info('all proposals are background at stage {}'.format(stage))
proposals[0].proposal_boxes.tensor[0, :] = targets[0].gt_boxes.tensor[0, :]
proposals[0].gt_boxes.tensor[0, :] = targets[0].gt_boxes.tensor[0, :]
proposals[0].objectness_logits[0] = math.log((1.0 - 1e-10) / (1 - (1.0 - 1e-10)))
proposals[0].gt_classes[0] = targets[0].gt_classes[0]
proposals[0].gt_object_descriptions.data[0] = targets[0].gt_object_descriptions.data[0]
if 'foreground' in proposals[0].get_fields().keys():
proposals[0].foreground[0] = 1
return proposals
def _forward_box(self, features, proposals, targets=None, task="ObjectDet"):
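        """
        Run the cascade box stages and the text decoder.

        Training: returns per-stage box losses plus 'text_decoder_loss' computed on
        foreground proposals and their ground-truth descriptions.
        Inference: returns Instances with predicted boxes, scores, classes, and
        pred_object_descriptions decoded by beam search.
        """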
if self.training:
proposals = self.check_if_all_background(proposals, targets, 0)
if (not self.training) and self.mult_proposal_score:
if len(proposals) > 0 and proposals[0].has('scores'):
proposal_scores = [p.get('scores') for p in proposals]
else:
proposal_scores = [p.get('objectness_logits') for p in proposals]
features = [features[f] for f in self.box_in_features]
head_outputs = []
prev_pred_boxes = None
image_sizes = [x.image_size for x in proposals]
for k in range(self.num_cascade_stages):
if k > 0:
proposals = self._create_proposals_from_boxes(
prev_pred_boxes, image_sizes,
logits=[p.objectness_logits for p in proposals])
if self.training:
proposals = self._match_and_label_boxes_GRiT(
proposals, k, targets)
proposals = self.check_if_all_background(proposals, targets, k)
predictions = self._run_stage(features, proposals, k)
prev_pred_boxes = self.box_predictor[k].predict_boxes(
(predictions[0], predictions[1]), proposals)
head_outputs.append((self.box_predictor[k], predictions, proposals))
if self.training:
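            # Pool per-object features for the text decoder, keep only the
            # foreground proposals, and supervise the decoder with their
            # ground-truth object descriptions.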
object_features = self.object_feat_pooler(features, [x.proposal_boxes for x in proposals])
object_features = _ScaleGradient.apply(object_features, 1.0 / self.num_cascade_stages)
foreground = torch.cat([x.foreground for x in proposals])
object_features = object_features[foreground > 0]
object_descriptions = []
for x in proposals:
object_descriptions += x.gt_object_descriptions[x.foreground > 0].data
object_descriptions = ObjDescription(object_descriptions)
object_descriptions = object_descriptions.data
if len(object_descriptions) > 0:
begin_token = self.task_begin_tokens[task]
text_decoder_inputs = self.get_target_text_tokens(object_descriptions, object_features, begin_token)
object_features = object_features.view(
object_features.shape[0], object_features.shape[1], -1).permute(0, 2, 1).contiguous()
text_decoder_inputs.update({'object_features': object_features})
text_decoder_loss = self.text_decoder(text_decoder_inputs)
else:
text_decoder_loss = head_outputs[0][1][0].new_zeros([1])[0]
losses = {}
storage = get_event_storage()
            # RoI head losses (the proposal generator loss is computed in grit.py)
for stage, (predictor, predictions, proposals) in enumerate(head_outputs):
with storage.name_scope("stage{}".format(stage)):
stage_losses = predictor.losses(
(predictions[0], predictions[1]), proposals)
losses.update({k + "_stage{}".format(stage): v for k, v in stage_losses.items()})
# Text Decoder loss
losses.update({'text_decoder_loss': text_decoder_loss})
return losses
else:
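            # Inference: average the class scores and logits over the cascade
            # stages, run (Soft-)NMS, then decode a description for every kept box.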
scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs]
logits_per_stage = [(h[1][0],) for h in head_outputs]
scores = [
sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
for scores_per_image in zip(*scores_per_stage)
]
logits = [
sum(list(logits_per_image)) * (1.0 / self.num_cascade_stages)
for logits_per_image in zip(*logits_per_stage)
]
if self.mult_proposal_score:
scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores)]
predictor, predictions, proposals = head_outputs[-1]
boxes = predictor.predict_boxes(
(predictions[0], predictions[1]), proposals)
assert len(boxes) == 1
pred_instances, _ = self.fast_rcnn_inference_GRiT(
boxes,
scores,
logits,
image_sizes,
predictor.test_score_thresh,
predictor.test_nms_thresh,
predictor.test_topk_per_image,
self.soft_nms_enabled,
)
assert len(pred_instances) == 1, "Only support one image"
for i, pred_instance in enumerate(pred_instances):
if len(pred_instance.pred_boxes) > 0:
object_features = self.object_feat_pooler(features, [pred_instance.pred_boxes])
object_features = object_features.view(
object_features.shape[0], object_features.shape[1], -1).permute(0, 2, 1).contiguous()
text_decoder_output = self.text_decoder({'object_features': object_features})
if self.beam_size > 1 and self.test_task == "ObjectDet":
pred_boxes = []
pred_scores = []
pred_classes = []
pred_object_descriptions = []
for beam_id in range(self.beam_size):
pred_boxes.append(pred_instance.pred_boxes.tensor)
# object score = sqrt(objectness score x description score)
pred_scores.append((pred_instance.scores *
torch.exp(text_decoder_output['logprobs'])[:, beam_id]) ** 0.5)
pred_classes.append(pred_instance.pred_classes)
for prediction in text_decoder_output['predictions'][:, beam_id, :]:
# convert text tokens to words
description = self.tokenizer.decode(prediction.tolist()[1:], skip_special_tokens=True)
pred_object_descriptions.append(description)
merged_instances = Instances(image_sizes[0])
if torch.cat(pred_scores, dim=0).shape[0] <= predictor.test_topk_per_image:
merged_instances.scores = torch.cat(pred_scores, dim=0)
merged_instances.pred_boxes = Boxes(torch.cat(pred_boxes, dim=0))
merged_instances.pred_classes = torch.cat(pred_classes, dim=0)
merged_instances.pred_object_descriptions = ObjDescription(pred_object_descriptions)
else:
pred_scores, top_idx = torch.topk(
torch.cat(pred_scores, dim=0), predictor.test_topk_per_image)
merged_instances.scores = pred_scores
merged_instances.pred_boxes = Boxes(torch.cat(pred_boxes, dim=0)[top_idx, :])
merged_instances.pred_classes = torch.cat(pred_classes, dim=0)[top_idx]
merged_instances.pred_object_descriptions = \
ObjDescription(ObjDescription(pred_object_descriptions)[top_idx].data)
pred_instances[i] = merged_instances
else:
# object score = sqrt(objectness score x description score)
pred_instance.scores = (pred_instance.scores *
torch.exp(text_decoder_output['logprobs'])) ** 0.5
pred_object_descriptions = []
for prediction in text_decoder_output['predictions']:
# convert text tokens to words
description = self.tokenizer.decode(prediction.tolist()[1:], skip_special_tokens=True)
pred_object_descriptions.append(description)
pred_instance.pred_object_descriptions = ObjDescription(pred_object_descriptions)
else:
pred_instance.pred_object_descriptions = ObjDescription([])
return pred_instances
def forward(self, features, proposals, targets=None, targets_task="ObjectDet"):
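        """
        Standard ROIHeads entry point. During training, proposals are matched and
        sampled against the targets and the box, (optional) mask, and text-decoder
        losses are returned; during inference, predicted instances with object
        descriptions are returned.
        """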
if self.training:
proposals = self.label_and_sample_proposals(
proposals, targets)
losses = self._forward_box(features, proposals, targets, task=targets_task)
if targets[0].has('gt_masks'):
mask_losses = self._forward_mask(features, proposals)
                losses.update({k: v * self.mask_weight for k, v in mask_losses.items()})
else:
losses.update(self._get_empty_mask_loss(device=proposals[0].objectness_logits.device))
return proposals, losses
else:
pred_instances = self._forward_box(features, proposals, task=self.test_task)
pred_instances = self.forward_with_given_boxes(features, pred_instances)
return pred_instances, {}
@torch.no_grad()
def _match_and_label_boxes_GRiT(self, proposals, stage, targets):
"""
Add "gt_object_description" and "foreground" to detectron2's _match_and_label_boxes
"""
num_fg_samples, num_bg_samples = [], []
for proposals_per_image, targets_per_image in zip(proposals, targets):
match_quality_matrix = pairwise_iou(
targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
)
# proposal_labels are 0 or 1
matched_idxs, proposal_labels = self.proposal_matchers[stage](match_quality_matrix)
if len(targets_per_image) > 0:
gt_classes = targets_per_image.gt_classes[matched_idxs]
# Label unmatched proposals (0 label from matcher) as background (label=num_classes)
gt_classes[proposal_labels == 0] = self.num_classes
foreground = torch.ones_like(gt_classes)
foreground[proposal_labels == 0] = 0
gt_boxes = targets_per_image.gt_boxes[matched_idxs]
gt_object_descriptions = targets_per_image.gt_object_descriptions[matched_idxs]
else:
gt_classes = torch.zeros_like(matched_idxs) + self.num_classes
foreground = torch.zeros_like(gt_classes)
gt_boxes = Boxes(
targets_per_image.gt_boxes.tensor.new_zeros((len(proposals_per_image), 4))
)
gt_object_descriptions = ObjDescription(['None' for i in range(len(proposals_per_image))])
proposals_per_image.gt_classes = gt_classes
proposals_per_image.gt_boxes = gt_boxes
proposals_per_image.gt_object_descriptions = gt_object_descriptions
proposals_per_image.foreground = foreground
num_fg_samples.append((proposal_labels == 1).sum().item())
num_bg_samples.append(proposal_labels.numel() - num_fg_samples[-1])
# Log the number of fg/bg samples in each stage
storage = get_event_storage()
storage.put_scalar(
"stage{}/roi_head/num_fg_samples".format(stage),
sum(num_fg_samples) / len(num_fg_samples),
)
storage.put_scalar(
"stage{}/roi_head/num_bg_samples".format(stage),
sum(num_bg_samples) / len(num_bg_samples),
)
return proposals
def fast_rcnn_inference_GRiT(
self,
boxes: List[torch.Tensor],
scores: List[torch.Tensor],
logits: List[torch.Tensor],
image_shapes: List[Tuple[int, int]],
score_thresh: float,
nms_thresh: float,
topk_per_image: int,
soft_nms_enabled: bool,
):
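        """
        Per-image wrapper around fast_rcnn_inference_single_image_GRiT. It mirrors
        detectron2's fast_rcnn_inference, except that per-class logits are carried
        through alongside the scores and Soft-NMS can be enabled.
        """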
result_per_image = [
self.fast_rcnn_inference_single_image_GRiT(
boxes_per_image, scores_per_image, logits_per_image, image_shape,
score_thresh, nms_thresh, topk_per_image, soft_nms_enabled
)
for scores_per_image, boxes_per_image, image_shape, logits_per_image \
in zip(scores, boxes, image_shapes, logits)
]
return [x[0] for x in result_per_image], [x[1] for x in result_per_image]
def fast_rcnn_inference_single_image_GRiT(
self,
boxes,
scores,
logits,
image_shape: Tuple[int, int],
score_thresh: float,
nms_thresh: float,
topk_per_image: int,
soft_nms_enabled,
):
"""
Add soft NMS to detectron2's fast_rcnn_inference_single_image
"""
valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
if not valid_mask.all():
boxes = boxes[valid_mask]
scores = scores[valid_mask]
logits = logits[valid_mask]
scores = scores[:, :-1]
logits = logits[:, :-1]
num_bbox_reg_classes = boxes.shape[1] // 4
# Convert to Boxes to use the `clip` function ...
boxes = Boxes(boxes.reshape(-1, 4))
boxes.clip(image_shape)
boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4) # R x C x 4
# 1. Filter results based on detection scores. It can make NMS more efficient
# by filtering out low-confidence detections.
filter_mask = scores > score_thresh # R x K
# R' x 2. First column contains indices of the R predictions;
# Second column contains indices of classes.
filter_inds = filter_mask.nonzero()
if num_bbox_reg_classes == 1:
boxes = boxes[filter_inds[:, 0], 0]
else:
boxes = boxes[filter_mask]
scores = scores[filter_mask]
logits = logits[filter_mask]
# 2. Apply NMS for each class independently.
if not soft_nms_enabled:
keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
else:
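            # Soft-NMS via ..soft_nms.batched_soft_nms; the positional arguments
            # after the class indices are, by assumption here, the Soft-NMS
            # method ("linear"), the Gaussian sigma, the linear/IoU threshold,
            # and the prune threshold below which boxes are dropped.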
keep, soft_nms_scores = batched_soft_nms(
boxes,
scores,
filter_inds[:, 1],
"linear",
0.5,
nms_thresh,
0.001,
)
scores[keep] = soft_nms_scores
if topk_per_image >= 0:
keep = keep[:topk_per_image]
boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
logits = logits[keep]
result = Instances(image_shape)
result.pred_boxes = Boxes(boxes)
result.scores = scores
result.pred_classes = filter_inds[:, 1]
result.logits = logits
return result, filter_inds[:, 0]
def _get_empty_mask_loss(self, device):
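        """
        Return a zero-valued mask loss when a batch has no ground-truth masks so
        that the loss dict keeps a consistent set of keys.
        """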
if self.mask_on:
return {'loss_mask': torch.zeros(
(1, ), device=device, dtype=torch.float32)[0]}
else:
return {}
def _create_proposals_from_boxes(self, boxes, image_sizes, logits):
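        """
        Wrap the boxes predicted by the previous cascade stage as proposals for
        the next stage, carrying the objectness logits along; empty boxes are
        dropped during training.
        """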
boxes = [Boxes(b.detach()) for b in boxes]
proposals = []
for boxes_per_image, image_size, logit in zip(
boxes, image_sizes, logits):
boxes_per_image.clip(image_size)
if self.training:
inds = boxes_per_image.nonempty()
boxes_per_image = boxes_per_image[inds]
logit = logit[inds]
prop = Instances(image_size)
prop.proposal_boxes = boxes_per_image
prop.objectness_logits = logit
proposals.append(prop)
return proposals
def _run_stage(self, features, proposals, stage):
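        """
        Pool box features for the given cascade stage, scale their gradient by
        1 / num_cascade_stages, and run that stage's box head and predictor.
        """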
pool_boxes = [x.proposal_boxes for x in proposals]
box_features = self.box_pooler(features, pool_boxes)
box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
box_features = self.box_head[stage](box_features)
return self.box_predictor[stage](box_features)
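

# Usage sketch (illustrative, not part of the GRiT codebase): in detectron2 the
# head above is selected by name from the config rather than instantiated
# directly. Assuming a GRiT config that sets
# MODEL.ROI_HEADS.NAME = "GRiTROIHeadsAndTextDecoder" and the TEXT_DECODER.*
# keys read in from_config(), something like the following would build and run
# the full model:
#
#     from detectron2.modeling import build_model
#
#     model = build_model(cfg)   # backbone + proposal generator + this ROI head
#     model.eval()
#     outputs = model(inputs)    # per-image results whose Instances include
#                                # pred_object_descriptions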