|
|
|
|
|
import torch |
|
from fvcore.nn import giou_loss, smooth_l1_loss |
|
from torch import nn |
|
from torch.nn import functional as F |
|
import fvcore.nn.weight_init as weight_init |
|
from detectron2.config import configurable |
|
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple |
|
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers |
|
from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats |
|
|
|
|
|
# Public API of this module: only the output-layer class is exported.
__all__ = ["GRiTFastRCNNOutputLayers"]
|
|
|
|
|
class GRiTFastRCNNOutputLayers(FastRCNNOutputLayers):
    """Fast R-CNN output layers with a two-layer MLP box regressor.

    Builds on ``FastRCNNOutputLayers`` and replaces ``self.bbox_pred`` with a
    ``Linear -> ReLU -> Linear`` stack whose last layer emits 4 values per
    input row. The loss, probability and forward computations are overridden
    to work with that regressor.
    """

    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        **kwargs,
    ):
        """Construct the parent layers, then swap in the MLP box regressor.

        Args:
            input_shape: shape of the incoming feature. ``width`` and
                ``height`` count as 1 when they are unset/falsy.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(
            input_shape=input_shape,
            **kwargs,
        )

        # Size of one flattened input row (channels * width * height).
        input_size = (
            input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
        )

        # Overrides the regressor created by the parent constructor.
        self.bbox_pred = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.ReLU(inplace=True),
            nn.Linear(input_size, 4),
        )
        weight_init.c2_xavier_fill(self.bbox_pred[0])
        # Small-std init for the output layer so initial deltas start near zero.
        nn.init.normal_(self.bbox_pred[-1].weight, std=0.001)
        nn.init.constant_(self.bbox_pred[-1].bias, 0)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Return the constructor arguments produced by the parent class."""
        return super().from_config(cfg, input_shape)

    def losses(self, predictions, proposals):
        """Compute the classification and box-regression losses.

        Args:
            predictions: the ``(scores, proposal_deltas)`` pair returned by
                :meth:`forward`.
            proposals: sequence of objects exposing ``gt_classes``,
                ``proposal_boxes`` and, optionally, ``gt_boxes``.

        Returns:
            dict with the keys ``"loss_cls"`` and ``"loss_box_reg"``.
        """
        scores, proposal_deltas = predictions
        gt_classes = (
            cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
        )
        num_classes = self.num_classes
        _log_classification_stats(scores, gt_classes)

        if len(proposals):
            proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)
            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
            # Entries without "gt_boxes" fall back to their own proposal boxes
            # so the concatenated tensors stay aligned row-for-row.
            gt_boxes = cat(
                [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
                dim=0,
            )
        else:
            proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

        loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
        return {
            "loss_cls": loss_cls,
            "loss_box_reg": self.box_reg_loss(
                proposal_boxes,
                gt_boxes,
                proposal_deltas,
                gt_classes,
                num_classes=num_classes,
            ),
        }

    def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
        """Mean softmax cross-entropy of the logits against ``gt_classes``.

        Returns a scalar zero (same dtype/device as the logits) when the
        logits tensor is empty.
        """
        if pred_class_logits.numel() == 0:
            return pred_class_logits.new_zeros([1])[0]

        return F.cross_entropy(pred_class_logits, gt_classes, reduction="mean")

    def box_reg_loss(
        self, proposal_boxes, gt_boxes, pred_deltas, gt_classes, num_classes=-1
    ):
        """Box-regression loss over the rows labelled with a valid class.

        Args:
            proposal_boxes: tensor of proposal boxes, one row per proposal.
            gt_boxes: tensor of target boxes aligned with ``proposal_boxes``.
            pred_deltas: predicted deltas; either one set per row or one set
                per class per row.
            gt_classes: class label of every row.
            num_classes: rows with a label in ``[0, num_classes)`` are used;
                a non-positive value means ``self.num_classes``.

        Returns:
            The summed loss over the selected rows divided by the number of
            entries in ``gt_classes`` (with a floor of 1).

        Raises:
            ValueError: if ``self.box_reg_loss_type`` is neither
                ``"smooth_l1"`` nor ``"giou"``.
        """
        num_classes = num_classes if num_classes > 0 else self.num_classes
        box_dim = proposal_boxes.shape[1]
        fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < num_classes))[0]
        if pred_deltas.shape[1] == box_dim:
            # One set of deltas per row, shared by all classes.
            fg_pred_deltas = pred_deltas[fg_inds]
        else:
            # One set of deltas per class: pick the set for each row's label.
            fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
                fg_inds, gt_classes[fg_inds]
            ]

        if self.box_reg_loss_type == "smooth_l1":
            gt_pred_deltas = self.box2box_transform.get_deltas(
                proposal_boxes[fg_inds],
                gt_boxes[fg_inds],
            )
            loss_box_reg = smooth_l1_loss(
                fg_pred_deltas, gt_pred_deltas, self.smooth_l1_beta, reduction="sum"
            )
        elif self.box_reg_loss_type == "giou":
            fg_pred_boxes = self.box2box_transform.apply_deltas(
                fg_pred_deltas, proposal_boxes[fg_inds]
            )
            loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")
        # Normalise by every row (not only the selected ones); the floor of 1
        # avoids a division by zero on an empty batch.
        return loss_box_reg / max(gt_classes.numel(), 1.0)

    def predict_probs(self, predictions, proposals):
        """Softmax the class scores and split them per entry of ``proposals``.

        Args:
            predictions: sequence whose first element is the score tensor.
            proposals: sequence of sized objects; ``len(p)`` gives the number
                of rows belonging to each entry.

        Returns:
            Tuple of probability tensors, one per entry of ``proposals``.
        """
        scores = predictions[0]
        num_inst_per_image = [len(p) for p in proposals]
        probs = F.softmax(scores, dim=-1)
        return probs.split(num_inst_per_image, dim=0)

    def forward(self, x):
        """Run the classification and box-regression heads.

        Args:
            x: input features; anything with more than 2 dimensions is
                flattened from dimension 1 onward.

        Returns:
            ``(scores, proposal_deltas)``: outputs of ``self.cls_score`` and
            ``self.bbox_pred`` on the (flattened) input.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)

        # The classifier output is used directly; concatenating a one-element
        # list would only produce a redundant copy of the same tensor.
        scores = self.cls_score(x)
        proposal_deltas = self.bbox_pred(x)
        return scores, proposal_deltas