import logging |
import math |
import numpy as np |
from typing import Dict, List, Tuple |
import fvcore.nn.weight_init as weight_init |
import torch |
from torch import Tensor, nn |
from torch.nn import functional as F |
from detectron2.config import configurable |
from detectron2.layers import Conv2d, ShapeSpec, cat, interpolate |
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY |
from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference, mask_rcnn_loss |
from detectron2.structures import Boxes |
from .point_features import ( |
generate_regular_grid_point_coords, |
get_point_coords_wrt_image, |
get_uncertain_point_coords_on_grid, |
get_uncertain_point_coords_with_randomness, |
point_sample, |
point_sample_fine_grained_features, |
sample_point_labels, |
) |
from .point_head import build_point_head, roi_mask_point_loss |
def calculate_uncertainty(logits, classes):
    """
    Score every location by how close its foreground logit is to zero. The uncertainty is the
    negated L1 distance between 0.0 and the logit predicted in `logits` for the foreground class
    given in `classes`.

    Args:
        logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
            class-agnostic predictions, where R is the total number of predicted masks in all
            images and C is the number of foreground classes. The values are logits.
        classes (list): A list of length R that contains either the predicted or the ground
            truth class for each predicted mask.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    if logits.shape[1] != 1:
        # Class-specific prediction: keep only the channel of each mask's own class.
        mask_index = torch.arange(logits.shape[0], device=logits.device)
        selected_logits = logits[mask_index, classes].unsqueeze(1)
    else:
        # Class-agnostic prediction: the single channel already is the foreground logit.
        selected_logits = logits.clone()
    return -selected_logits.abs()
class ConvFCHead(nn.Module):
    """
    A mask head with fully connected layers. Given pooled features it first reduces channels and
    spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously
    to the standard box head.
    """

    # Checkpoints whose metadata carries no version, or a version below 2, get their keys
    # converted in `_load_from_state_dict`.
    _version = 2

    @configurable
    def __init__(
        self, input_shape: ShapeSpec, *, conv_dim: int, fc_dims: List[int], output_shape: Tuple[int]
    ):
        """
        Args:
            input_shape: channels, height and width of the pooled input features
            conv_dim: the output dimension of the conv layers
            fc_dims: a list of N>0 integers representing the output dimensions of N FC layers
            output_shape: shape of the output mask prediction
        """
        super().__init__()

        input_channels = input_shape.channels
        input_h = input_shape.height
        input_w = input_shape.width
        self.output_shape = output_shape

        # Plain list used for ordered iteration; each conv is also assigned as an attribute,
        # which is what registers it as a submodule.
        self.conv_layers = []
        if input_channels > conv_dim:
            # 1x1 conv that brings the channel count down to conv_dim.
            self.reduce_channel_dim_conv = Conv2d(
                input_channels,
                conv_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                activation=F.relu,
            )
            self.conv_layers.append(self.reduce_channel_dim_conv)

        # 2x2 conv with stride 2: halves both spatial dimensions.
        self.reduce_spatial_dim_conv = Conv2d(
            conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu
        )
        self.conv_layers.append(self.reduce_spatial_dim_conv)

        # Flattened size after the stride-2 conv (height and width are each halved).
        input_dim = conv_dim * input_h * input_w
        input_dim //= 4

        self.fcs = []
        for k, fc_dim in enumerate(fc_dims):
            fc = nn.Linear(input_dim, fc_dim)
            self.add_module("fc{}".format(k + 1), fc)
            self.fcs.append(fc)
            input_dim = fc_dim

        output_dim = int(np.prod(self.output_shape))

        self.prediction = nn.Linear(fc_dims[-1], output_dim)
        # use normal distribution initialization for mask prediction layer
        nn.init.normal_(self.prediction.weight, std=0.001)
        nn.init.constant_(self.prediction.bias, 0)

        for layer in self.conv_layers:
            weight_init.c2_msra_fill(layer)
        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Translate a config into the keyword arguments of `__init__`.

        Args:
            cfg: config node; reads MODEL.ROI_HEADS.NUM_CLASSES and
                MODEL.ROI_MASK_HEAD.{OUTPUT_SIDE_RESOLUTION, FC_DIM, NUM_FC, CONV_DIM}
            input_shape (ShapeSpec): shape of the pooled input features

        Returns:
            dict: keyword arguments for `__init__`. Arguments a caller passes explicitly
                next to `cfg` (e.g. `output_shape=...`) are expected to take precedence.
        """
        # Default prediction: one square mask per class.
        output_shape = (
            cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION,
            cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION,
        )
        fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
        num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
        ret = dict(
            input_shape=input_shape,
            conv_dim=cfg.MODEL.ROI_MASK_HEAD.CONV_DIM,
            fc_dims=[fc_dim] * num_fc,
            output_shape=output_shape,
        )
        return ret

    def forward(self, x):
        """
        Args:
            x (Tensor): pooled features whose first dimension is the number of masks N

        Returns:
            Tensor: predictions of shape (N, *self.output_shape)
        """
        N = x.shape[0]
        for layer in self.conv_layers:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        for layer in self.fcs:
            x = F.relu(layer(x))
        output_shape = [N] + list(self.output_shape)
        return self.prediction(x).view(*output_shape)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Rename old-format keys in `state_dict` in place: for checkpoints with no version or a
        version below 2, keys starting with `prefix + "coarse_mask_fc"` become `prefix + "fc"`.
        """
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            logger = logging.getLogger(__name__)
            logger.warning(
                "Weight format of PointRend models have changed! "
                "Applying automatic conversion now ..."
            )
            # Iterate over a snapshot of the keys because the dict is mutated in the loop.
            for k in list(state_dict.keys()):
                newk = k
                if k.startswith(prefix + "coarse_mask_fc"):
                    newk = k.replace(prefix + "coarse_mask_fc", prefix + "fc")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
@ROI_MASK_HEAD_REGISTRY.register()
class PointRendMaskHead(nn.Module):
    """
    Mask head that predicts a coarse mask with a `ConvFCHead` and, when the point head is
    enabled, adds a point-based loss in training and refines the mask by iterative subdivision
    in inference.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Args:
            cfg: config node; see `_init_point_head` for the point-head keys. The coarse head
                reads MODEL.ROI_MASK_HEAD.{IN_FEATURES, POOLER_RESOLUTION}.
            input_shape: shape (channels, stride, ...) of every available feature map, by name
        """
        super().__init__()
        # Scale of each feature map relative to the input image.
        self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()}
        # The point head is built first: ImplicitPointRendMaskHead._init_roi_head needs
        # `num_params`, which its `_init_point_head` sets.
        self._init_point_head(cfg, input_shape)
        # coarse mask head
        self.roi_pooler_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES
        self.roi_pooler_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        # int(): np.sum returns a numpy integer; a plain int matches `_init_point_head`.
        in_channels = int(np.sum([input_shape[f].channels for f in self.roi_pooler_in_features]))
        self._init_roi_head(
            cfg,
            ShapeSpec(
                channels=in_channels,
                width=self.roi_pooler_size,
                height=self.roi_pooler_size,
            ),
        )

    def _init_roi_head(self, cfg, input_shape):
        """Build the head that turns pooled RoI features into a coarse mask."""
        self.coarse_head = ConvFCHead(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape):
        """
        Read the point-head settings from `cfg` and build the point head. When
        MODEL.ROI_MASK_HEAD.POINT_HEAD_ON is false, only `mask_point_on` is set.
        """
        self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON
        if not self.mask_point_on:
            return
        self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.mask_point_oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        # next three parameters are used in the adaptive subdivision inference procedure
        self.mask_point_subdivision_init_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
        self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS

        in_channels = int(np.sum([input_shape[f].channels for f in self.mask_point_in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

        # Skip subdivision steps that cannot help: while the 2x-upsampled grid has no more
        # pixels than the number of points re-predicted per step, every pixel would be
        # recomputed anyway, so start from the doubled resolution and drop one step.
        while (
            4 * self.mask_point_subdivision_init_resolution ** 2
            <= self.mask_point_subdivision_num_points
        ):
            self.mask_point_subdivision_init_resolution *= 2
            self.mask_point_subdivision_steps -= 1

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference

        Returns:
            In training, a dict of losses ("loss_mask", plus "loss_mask_point" when the point
            head is enabled). In inference, `instances` after `mask_rcnn_inference` has been
            applied to them.
        """
        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            coarse_mask = self.coarse_head(self._roi_pooler(features, proposal_boxes))
            losses = {"loss_mask": mask_rcnn_loss(coarse_mask, instances)}
            if not self.mask_point_on:
                return losses

            point_coords, point_labels = self._sample_train_points(coarse_mask, instances)
            point_fine_grained_features = self._point_pooler(features, proposal_boxes, point_coords)
            point_logits = self._get_point_logits(
                point_fine_grained_features, point_coords, coarse_mask
            )
            losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
            return losses

        pred_boxes = [x.pred_boxes for x in instances]
        coarse_mask = self.coarse_head(self._roi_pooler(features, pred_boxes))
        if not self.mask_point_on:
            # Without a point head none of the subdivision attributes exist, so the coarse
            # prediction is the final mask (mirrors the early return of the training branch).
            mask_rcnn_inference(coarse_mask, instances)
            return instances
        return self._subdivision_inference(features, coarse_mask, instances)

    def _roi_pooler(self, features: Dict[str, Tensor], boxes: List[Boxes]):
        """
        Extract per-box feature. This is similar to RoIAlign(sampling_ratio=1) except:
        1. It's implemented by point_sample
        2. It pools features across all levels and concat them, while typically
           RoIAlign select one level for every box. However in the config we only use
           one level (p2) so there is no difference.

        Returns:
            Tensor of shape (R, C, pooler_size, pooler_size) where R is the total number of boxes
        """
        features_list = [features[k] for k in self.roi_pooler_in_features]
        features_scales = [self._feature_scales[k] for k in self.roi_pooler_in_features]

        num_boxes = sum(x.tensor.size(0) for x in boxes)
        output_size = self.roi_pooler_size
        point_coords = generate_regular_grid_point_coords(num_boxes, output_size, boxes[0].device)
        # For regular grids of points, this function is equivalent to `len(features_list)' calls
        # of `ROIAlign` (with `SAMPLING_RATIO=1`), and concat the results.
        roi_features, _ = point_sample_fine_grained_features(
            features_list, features_scales, boxes, point_coords
        )
        return roi_features.view(num_boxes, roi_features.shape[1], output_size, output_size)

    def _sample_train_points(self, coarse_mask, instances):
        """
        Pick training points from the uncertainty of `coarse_mask` (measured against the
        ground-truth classes) and look up their labels.

        Returns:
            (point_coords, point_labels): box-relative point coordinates and the labels
            sampled for them at the corresponding image coordinates.
        """
        assert self.training
        gt_classes = cat([x.gt_classes for x in instances])
        # Point selection and label lookup are targets; no gradient is needed through them.
        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                coarse_mask,
                lambda logits: calculate_uncertainty(logits, gt_classes),
                self.mask_point_train_num_points,
                self.mask_point_oversample_ratio,
                self.mask_point_importance_sample_ratio,
            )
            # sample point_labels
            proposal_boxes = [x.proposal_boxes for x in instances]
            cat_boxes = Boxes.cat(proposal_boxes)
            point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
            point_labels = sample_point_labels(instances, point_coords_wrt_image)
        return point_coords, point_labels

    def _point_pooler(self, features, proposal_boxes, point_coords):
        """Sample the point head's input feature maps at `point_coords` inside each box."""
        point_features_list = [features[k] for k in self.mask_point_in_features]
        point_features_scales = [self._feature_scales[k] for k in self.mask_point_in_features]
        # sample image-level features
        point_fine_grained_features, _ = point_sample_fine_grained_features(
            point_features_list, point_features_scales, proposal_boxes, point_coords
        )
        return point_fine_grained_features

    def _get_point_logits(self, point_fine_grained_features, point_coords, coarse_mask):
        """Run the point head on fine-grained features plus the coarse mask sampled at the points."""
        coarse_features = point_sample(coarse_mask, point_coords, align_corners=False)
        point_logits = self.point_head(point_fine_grained_features, coarse_features)
        return point_logits

    def _subdivision_inference(self, features, mask_representations, instances):
        """
        Predict masks coarse-to-fine: start from point predictions on a regular grid, then
        repeatedly upsample by 2 and re-predict only the most uncertain points.

        Args:
            features (dict[str, Tensor]): image-level features
            mask_representations: per-instance input forwarded to `_get_point_logits`
                (the coarse mask for this class)
            instances (list[Instances]): detected instances

        Returns:
            `instances`, after `mask_rcnn_inference` has been applied with the final logits.
        """
        assert not self.training

        pred_boxes = [x.pred_boxes for x in instances]
        pred_classes = cat([x.pred_classes for x in instances])

        mask_logits = None
        # +1 here to include an initial step that generates the coarsest mask prediction at
        # init_resolution (the iteration where mask_logits is still None).
        for _ in range(self.mask_point_subdivision_steps + 1):
            if mask_logits is None:
                point_coords = generate_regular_grid_point_coords(
                    pred_classes.size(0),
                    self.mask_point_subdivision_init_resolution,
                    pred_boxes[0].device,
                )
            else:
                mask_logits = interpolate(
                    mask_logits, scale_factor=2, mode="bilinear", align_corners=False
                )
                uncertainty_map = calculate_uncertainty(mask_logits, pred_classes)
                point_indices, point_coords = get_uncertain_point_coords_on_grid(
                    uncertainty_map, self.mask_point_subdivision_num_points
                )

            # Run the point head for every point in point_coords
            fine_grained_features = self._point_pooler(features, pred_boxes, point_coords)
            point_logits = self._get_point_logits(
                fine_grained_features, point_coords, mask_representations
            )

            if mask_logits is None:
                # Create initial mask_logits using point_logits on this regular grid
                R, C, _ = point_logits.shape
                mask_logits = point_logits.reshape(
                    R,
                    C,
                    self.mask_point_subdivision_init_resolution,
                    self.mask_point_subdivision_init_resolution,
                )
                # With no detections there is nothing to subdivide; stop after the first step.
                if len(pred_classes) == 0:
                    mask_rcnn_inference(mask_logits, instances)
                    return instances
            else:
                # Put point predictions to the right places on the upsampled grid.
                R, C, H, W = mask_logits.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                mask_logits = (
                    mask_logits.reshape(R, C, H * W)
                    .scatter_(2, point_indices, point_logits)
                    .view(R, C, H, W)
                )
        mask_rcnn_inference(mask_logits, instances)
        return instances
@ROI_MASK_HEAD_REGISTRY.register()
class ImplicitPointRendMaskHead(PointRendMaskHead):
    """
    PointRend variant without a coarse mask: a `ConvFCHead` predicts `num_params` values per
    instance, and the point head receives them together with the point coordinates.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__(cfg, input_shape)

    def _init_roi_head(self, cfg, input_shape):
        """Build the parameter head and read the weight of the L2 penalty on its output."""
        assert hasattr(self, "num_params"), "Please initialize point_head first!"
        self.parameter_head = ConvFCHead(cfg, input_shape, output_shape=(self.num_params,))
        # Weight of "loss_l2" in `forward`.
        self.regularizer = cfg.MODEL.IMPLICIT_POINTREND.PARAMS_L2_REGULARIZER

    def _init_point_head(self, cfg, input_shape):
        """Read the point-head settings from `cfg` and build the point head (always enabled)."""
        self.mask_point_on = True  # always on
        self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        # next two parameters are used in the adaptive subdivision inference procedure
        self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS

        in_channels = int(np.sum([input_shape[f].channels for f in self.mask_point_in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))
        self.num_params = self.point_head.num_params

        # The initial regular grid must contain exactly SUBDIVISION_NUM_POINTS points, so that
        # number has to be a perfect square.
        self.mask_point_subdivision_init_resolution = int(
            math.sqrt(self.mask_point_subdivision_num_points)
        )
        assert (
            self.mask_point_subdivision_init_resolution
            * self.mask_point_subdivision_init_resolution
            == self.mask_point_subdivision_num_points
        ), "MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS must be a perfect square"

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference

        Returns:
            In training, a dict with "loss_l2" and "loss_mask_point". In inference, the
            result of `_subdivision_inference`.
        """
        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            parameters = self.parameter_head(self._roi_pooler(features, proposal_boxes))
            losses = {"loss_l2": self.regularizer * (parameters ** 2).mean()}

            point_coords, point_labels = self._uniform_sample_train_points(instances)
            point_fine_grained_features = self._point_pooler(features, proposal_boxes, point_coords)
            point_logits = self._get_point_logits(
                point_fine_grained_features, point_coords, parameters
            )
            losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
            return losses

        pred_boxes = [x.pred_boxes for x in instances]
        parameters = self.parameter_head(self._roi_pooler(features, pred_boxes))
        return self._subdivision_inference(features, parameters, instances)

    def _uniform_sample_train_points(self, instances):
        """
        Draw `mask_point_train_num_points` uniformly random points in [0, 1) x [0, 1) for
        every proposal box and look up their labels.

        Returns:
            (point_coords, point_labels): box-relative coordinates of shape
            (num_boxes, num_points, 2) and the labels sampled for them.
        """
        assert self.training
        proposal_boxes = [x.proposal_boxes for x in instances]
        cat_boxes = Boxes.cat(proposal_boxes)
        # uniform sample
        point_coords = torch.rand(
            len(cat_boxes), self.mask_point_train_num_points, 2, device=cat_boxes.tensor.device
        )
        # sample point_labels
        point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
        point_labels = sample_point_labels(instances, point_coords_wrt_image)
        return point_coords, point_labels

    def _get_point_logits(self, fine_grained_features, point_coords, parameters):
        """Run the point head on the fine-grained features, point coordinates and parameters."""
        return self.point_head(fine_grained_features, point_coords, parameters)