|
|
|
import logging |
|
import math |
|
import numpy as np |
|
from typing import Dict, List, Tuple |
|
import fvcore.nn.weight_init as weight_init |
|
import torch |
|
from torch import Tensor, nn |
|
from torch.nn import functional as F |
|
|
|
from detectron2.config import configurable |
|
from detectron2.layers import Conv2d, ShapeSpec, cat, interpolate |
|
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY |
|
from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference, mask_rcnn_loss |
|
from detectron2.structures import Boxes |
|
|
|
from .point_features import ( |
|
generate_regular_grid_point_coords, |
|
get_point_coords_wrt_image, |
|
get_uncertain_point_coords_on_grid, |
|
get_uncertain_point_coords_with_randomness, |
|
point_sample, |
|
point_sample_fine_grained_features, |
|
sample_point_labels, |
|
) |
|
from .point_head import build_point_head, roi_mask_point_loss |
|
|
|
|
|
def calculate_uncertainty(logits, classes):
    """
    We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
    foreground class in `classes`.

    Args:
        logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
            class-agnostic, where R is the total number of predicted masks in all images and C is
            the number of foreground classes. The values are logits.
        classes (list): A list of length R that contains either predicted or ground truth class
            for each predicted mask.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    if logits.shape[1] == 1:
        # Class-agnostic: the single channel already holds the foreground logit.
        fg_logits = logits.clone()
    else:
        # Pick, for each mask, the logit channel of its (predicted or GT) class.
        row_idx = torch.arange(logits.shape[0], device=logits.device)
        fg_logits = logits[row_idx, classes].unsqueeze(1)
    # Logits near zero (the decision boundary) receive the highest (least negative) score.
    return -torch.abs(fg_logits)
|
|
|
|
|
class ConvFCHead(nn.Module):
    """
    A mask head with fully connected layers. Given pooled features it first reduces channels and
    spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously
    to the standard box head.
    """

    # Checkpoint format version. Bumped to 2 when FC weights were renamed from
    # "coarse_mask_fc*" to "fc*"; see _load_from_state_dict for the conversion.
    _version = 2

    @configurable
    def __init__(
        self, input_shape: ShapeSpec, *, conv_dim: int, fc_dims: List[int], output_shape: Tuple[int]
    ):
        """
        Args:
            input_shape: channels/height/width of the pooled per-box input features
            conv_dim: the output dimension of the conv layers
            fc_dims: a list of N>0 integers representing the output dimensions of N FC layers
            output_shape: shape of the output mask prediction
        """
        super().__init__()

        input_channels = input_shape.channels
        input_h = input_shape.height
        input_w = input_shape.width
        self.output_shape = output_shape

        # Plain python list for iteration (forward / weight init); the modules are
        # registered with nn.Module via the attribute assignments below.
        self.conv_layers = []
        if input_channels > conv_dim:
            # 1x1 conv to bring the channel count down to conv_dim first.
            self.reduce_channel_dim_conv = Conv2d(
                input_channels,
                conv_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                activation=F.relu,
            )
            self.conv_layers.append(self.reduce_channel_dim_conv)

        # 2x2 stride-2 conv halves each spatial dimension (4x fewer locations).
        self.reduce_spatial_dim_conv = Conv2d(
            conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu
        )
        self.conv_layers.append(self.reduce_spatial_dim_conv)

        # Flattened size after the spatial reduction above (hence the //= 4).
        input_dim = conv_dim * input_h * input_w
        input_dim //= 4

        self.fcs = []
        for k, fc_dim in enumerate(fc_dims):
            # Registered as "fc1", "fc2", ... (1-based, matching the checkpoint naming).
            fc = nn.Linear(input_dim, fc_dim)
            self.add_module("fc{}".format(k + 1), fc)
            self.fcs.append(fc)
            input_dim = fc_dim

        # One flat prediction vector, reshaped to output_shape in forward().
        output_dim = int(np.prod(self.output_shape))

        self.prediction = nn.Linear(fc_dims[-1], output_dim)

        # Small-std init for the final layer; conv/FC layers get the standard
        # Caffe2-style MSRA / Xavier fills.
        nn.init.normal_(self.prediction.weight, std=0.001)
        nn.init.constant_(self.prediction.bias, 0)

        for layer in self.conv_layers:
            weight_init.c2_msra_fill(layer)
        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Build the `__init__` kwargs from a detectron2 config (used by @configurable)."""
        # Default output: one (side x side) coarse mask per class. Callers may override
        # output_shape with an explicit kwarg (see ImplicitPointRendMaskHead).
        output_shape = (
            cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION,
            cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION,
        )
        fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
        num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
        ret = dict(
            input_shape=input_shape,
            conv_dim=cfg.MODEL.ROI_MASK_HEAD.CONV_DIM,
            fc_dims=[fc_dim] * num_fc,
            output_shape=output_shape,
        )
        return ret

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: pooled features of shape (N, C, H, W) matching `input_shape`.

        Returns:
            Tensor of shape (N, *output_shape) with coarse mask predictions (logits).
        """
        N = x.shape[0]
        for layer in self.conv_layers:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        for layer in self.fcs:
            x = F.relu(layer(x))
        output_shape = [N] + list(self.output_shape)
        return self.prediction(x).view(*output_shape)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            logger = logging.getLogger(__name__)
            logger.warning(
                "Weight format of PointRend models have changed! "
                "Applying automatic conversion now ..."
            )
            # Rename legacy keys "coarse_mask_fc*" -> "fc*" in place.
            for k in list(state_dict.keys()):
                newk = k
                if k.startswith(prefix + "coarse_mask_fc"):
                    newk = k.replace(prefix + "coarse_mask_fc", prefix + "fc")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
        # NOTE(review): super()._load_from_state_dict is not chained here; this module
        # appears to own no direct parameters (all weights live in child modules, which
        # are loaded via their own hooks) — confirm if direct params are ever added.
|
|
|
@ROI_MASK_HEAD_REGISTRY.register()
class PointRendMaskHead(nn.Module):
    """
    Mask head implementing PointRend.

    A coarse head (`ConvFCHead`) predicts a low-resolution mask from RoI-pooled
    features. During training a point head additionally refines the prediction at
    adaptively sampled uncertain points; during inference the coarse mask is
    repeatedly upsampled and re-evaluated at the most uncertain grid points
    (subdivision inference).
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Args:
            cfg: detectron2 config.
            input_shape: mapping from feature map name (e.g. "p2") to its shape;
                only `.stride` and `.channels` are read here.
        """
        super().__init__()
        # feature name -> scale (1 / stride), used to map image coords to feature coords.
        # (The original code assigned this twice; the redundant duplicate was removed.)
        self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()}
        # point head (must run before _init_roi_head: subclasses rely on this order)
        self._init_point_head(cfg, input_shape)
        # coarse mask head
        self.roi_pooler_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES
        self.roi_pooler_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        # int() for consistency with _init_point_head (np.sum returns a numpy scalar).
        in_channels = int(np.sum([input_shape[f].channels for f in self.roi_pooler_in_features]))
        self._init_roi_head(
            cfg,
            ShapeSpec(
                channels=in_channels,
                width=self.roi_pooler_size,
                height=self.roi_pooler_size,
            ),
        )

    def _init_roi_head(self, cfg, input_shape):
        # Coarse mask prediction head; overridden by ImplicitPointRendMaskHead.
        self.coarse_head = ConvFCHead(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape):
        self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON
        if not self.mask_point_on:
            return
        assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
        # Training-time adaptive point sampling parameters.
        self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.mask_point_oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        # Inference-time subdivision schedule.
        self.mask_point_subdivision_init_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
        self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS

        in_channels = int(np.sum([input_shape[f].channels for f in self.mask_point_in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

        # An optimization: if the per-step point budget already covers a 2x finer regular
        # grid, start subdivision from that finer grid and skip the equivalent number of
        # subdivision steps.
        while (
            4 * self.mask_point_subdivision_init_resolution ** 2
            <= self.mask_point_subdivision_num_points
        ):
            self.mask_point_subdivision_init_resolution *= 2
            self.mask_point_subdivision_steps -= 1

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference

        Returns:
            dict[str, Tensor] of losses in training; list[Instances] with
            `pred_masks` filled in during inference.
        """
        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            coarse_mask = self.coarse_head(self._roi_pooler(features, proposal_boxes))
            losses = {"loss_mask": mask_rcnn_loss(coarse_mask, instances)}
            if not self.mask_point_on:
                return losses

            # Refine at uncertain points and add the point-wise loss.
            point_coords, point_labels = self._sample_train_points(coarse_mask, instances)
            point_fine_grained_features = self._point_pooler(features, proposal_boxes, point_coords)
            point_logits = self._get_point_logits(
                point_fine_grained_features, point_coords, coarse_mask
            )
            losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
            return losses
        else:
            pred_boxes = [x.pred_boxes for x in instances]
            coarse_mask = self.coarse_head(self._roi_pooler(features, pred_boxes))
            return self._subdivision_inference(features, coarse_mask, instances)

    def _roi_pooler(self, features: Dict[str, Tensor], boxes: List[Boxes]):
        """
        Extract per-box feature. This is similar to RoIAlign(sampling_ratio=1) except:
        1. It's implemented by point_sample
        2. It pools features across all levels and concat them, while typically
           RoIAlign select one level for every box. However in the config we only use
           one level (p2) so there is no difference.

        Returns:
            Tensor of shape (R, C, pooler_size, pooler_size) where R is the total number of boxes
        """
        features_list = [features[k] for k in self.roi_pooler_in_features]
        features_scales = [self._feature_scales[k] for k in self.roi_pooler_in_features]

        num_boxes = sum(x.tensor.size(0) for x in boxes)
        output_size = self.roi_pooler_size
        # Regular grid of output_size x output_size points per box, in box coordinates.
        point_coords = generate_regular_grid_point_coords(num_boxes, output_size, boxes[0].device)

        roi_features, _ = point_sample_fine_grained_features(
            features_list, features_scales, boxes, point_coords
        )
        return roi_features.view(num_boxes, roi_features.shape[1], output_size, output_size)

    def _sample_train_points(self, coarse_mask, instances):
        """
        Sample training points where the coarse prediction is uncertain and fetch
        their ground-truth labels.

        Returns:
            point_coords: (R, P, 2) normalized box coordinates of sampled points.
            point_labels: (R, P) ground-truth labels at those points.
        """
        assert self.training
        gt_classes = cat([x.gt_classes for x in instances])
        with torch.no_grad():
            # Sampling is not differentiated through; only the point head's
            # predictions at these coords receive gradient.
            point_coords = get_uncertain_point_coords_with_randomness(
                coarse_mask,
                lambda logits: calculate_uncertainty(logits, gt_classes),
                self.mask_point_train_num_points,
                self.mask_point_oversample_ratio,
                self.mask_point_importance_sample_ratio,
            )

            proposal_boxes = [x.proposal_boxes for x in instances]
            cat_boxes = Boxes.cat(proposal_boxes)
            # Labels are sampled from GT masks in image coordinates.
            point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
            point_labels = sample_point_labels(instances, point_coords_wrt_image)
        return point_coords, point_labels

    def _point_pooler(self, features, proposal_boxes, point_coords):
        """Sample fine-grained (image-level) features at the given per-box points."""
        point_features_list = [features[k] for k in self.mask_point_in_features]
        point_features_scales = [self._feature_scales[k] for k in self.mask_point_in_features]

        point_fine_grained_features, _ = point_sample_fine_grained_features(
            point_features_list, point_features_scales, proposal_boxes, point_coords
        )
        return point_fine_grained_features

    def _get_point_logits(self, point_fine_grained_features, point_coords, coarse_mask):
        """Run the point head on fine-grained features + coarse mask values at the points."""
        coarse_features = point_sample(coarse_mask, point_coords, align_corners=False)
        point_logits = self.point_head(point_fine_grained_features, coarse_features)
        return point_logits

    def _subdivision_inference(self, features, mask_representations, instances):
        """
        Produce full-resolution masks by iterative subdivision.

        Args:
            features (dict[str, Tensor]): image-level features.
            mask_representations (Tensor): per-instance conditioning for the point head
                (coarse mask logits here; implicit parameters in the subclass).
            instances (list[Instances]): detections with `pred_boxes`/`pred_classes`.

        Returns:
            list[Instances]: `instances` with `pred_masks` filled in.
        """
        assert not self.training

        pred_boxes = [x.pred_boxes for x in instances]
        pred_classes = cat([x.pred_classes for x in instances])

        mask_logits = None
        # +1: the first iteration evaluates the initial regular grid; each following
        # iteration upsamples 2x and re-evaluates only the most uncertain points.
        for _ in range(self.mask_point_subdivision_steps + 1):
            if mask_logits is None:
                point_coords = generate_regular_grid_point_coords(
                    pred_classes.size(0),
                    self.mask_point_subdivision_init_resolution,
                    pred_boxes[0].device,
                )
            else:
                mask_logits = interpolate(
                    mask_logits, scale_factor=2, mode="bilinear", align_corners=False
                )
                uncertainty_map = calculate_uncertainty(mask_logits, pred_classes)
                point_indices, point_coords = get_uncertain_point_coords_on_grid(
                    uncertainty_map, self.mask_point_subdivision_num_points
                )

            fine_grained_features = self._point_pooler(features, pred_boxes, point_coords)
            point_logits = self._get_point_logits(
                fine_grained_features, point_coords, mask_representations
            )

            if mask_logits is None:
                # First iteration: the regular-grid predictions form the initial mask.
                R, C, _ = point_logits.shape
                mask_logits = point_logits.reshape(
                    R,
                    C,
                    self.mask_point_subdivision_init_resolution,
                    self.mask_point_subdivision_init_resolution,
                )
                # With no instances there is nothing to refine; return immediately.
                if len(pred_classes) == 0:
                    mask_rcnn_inference(mask_logits, instances)
                    return instances
            else:
                # Scatter refined logits back into the upsampled mask at the
                # selected (flattened) grid positions.
                R, C, H, W = mask_logits.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                mask_logits = (
                    mask_logits.reshape(R, C, H * W)
                    .scatter_(2, point_indices, point_logits)
                    .view(R, C, H, W)
                )
        mask_rcnn_inference(mask_logits, instances)
        return instances
|
|
|
|
|
@ROI_MASK_HEAD_REGISTRY.register()
class ImplicitPointRendMaskHead(PointRendMaskHead):
    """
    Implicit-PointRend variant: instead of predicting a coarse mask, a parameter head
    regresses per-instance parameters for the point head, which then predicts mask
    values as an implicit function of point coordinates and those parameters.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        # All setup is shared with the parent; only the overridden _init_* hooks differ.
        super().__init__(cfg, input_shape)

    def _init_roi_head(self, cfg, input_shape):
        # The point head must exist first so the number of parameters to regress is known.
        assert hasattr(self, "num_params"), "Please initialize point_head first!"
        self.parameter_head = ConvFCHead(cfg, input_shape, output_shape=(self.num_params,))
        self.regularizer = cfg.MODEL.IMPLICIT_POINTREND.PARAMS_L2_REGULARIZER

    def _init_point_head(self, cfg, input_shape):
        # This head has no coarse-mask fallback, so the point head is always on.
        self.mask_point_on = True
        assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        # Inference-time subdivision schedule.
        self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS

        in_channels = int(np.sum([input_shape[f].channels for f in self.mask_point_in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))
        self.num_params = self.point_head.num_params

        # The first inference step evaluates a regular grid that spends the whole point
        # budget, so the budget must be a perfect square.
        self.mask_point_subdivision_init_resolution = int(
            math.sqrt(self.mask_point_subdivision_num_points)
        )
        assert (
            self.mask_point_subdivision_init_resolution
            * self.mask_point_subdivision_init_resolution
            == self.mask_point_subdivision_num_points
        )

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference
        """
        if not self.training:
            pred_boxes = [x.pred_boxes for x in instances]
            params = self.parameter_head(self._roi_pooler(features, pred_boxes))
            return self._subdivision_inference(features, params, instances)

        proposal_boxes = [x.proposal_boxes for x in instances]
        params = self.parameter_head(self._roi_pooler(features, proposal_boxes))
        # L2 penalty on the regressed implicit-function parameters.
        losses = {"loss_l2": self.regularizer * (params ** 2).mean()}

        # Points are sampled uniformly (no coarse mask to estimate uncertainty from).
        coords, labels = self._uniform_sample_train_points(instances)
        fine_grained = self._point_pooler(features, proposal_boxes, coords)
        logits = self._get_point_logits(fine_grained, coords, params)
        losses["loss_mask_point"] = roi_mask_point_loss(logits, instances, labels)
        return losses

    def _uniform_sample_train_points(self, instances):
        """Sample uniformly random in-box points and their ground-truth labels."""
        assert self.training
        boxes_cat = Boxes.cat([x.proposal_boxes for x in instances])
        # Uniform random points in normalized [0, 1] x [0, 1] box coordinates.
        coords = torch.rand(
            len(boxes_cat), self.mask_point_train_num_points, 2, device=boxes_cat.tensor.device
        )
        # Labels are sampled from GT masks in image coordinates.
        coords_wrt_image = get_point_coords_wrt_image(boxes_cat.tensor, coords)
        labels = sample_point_labels(instances, coords_wrt_image)
        return coords, labels

    def _get_point_logits(self, fine_grained_features, point_coords, parameters):
        # The implicit point head consumes coordinates and per-instance parameters
        # directly (no coarse-mask features to sample).
        return self.point_head(fine_grained_features, point_coords, parameters)
|
|