import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn

from detectron2.config import configurable
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.layers import move_device_like
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage
from detectron2.utils.logger import log_first_n

from ..backbone import Backbone, build_backbone
from ..postprocessing import detector_postprocess
from ..proposal_generator import build_proposal_generator
from ..roi_heads import build_roi_heads
from .build import META_ARCH_REGISTRY

__all__ = ["GeneralizedRCNN", "ProposalNetwork"]


@META_ARCH_REGISTRY.register()
class GeneralizedRCNN(nn.Module):
    """
    Generalized R-CNN. Any model that contains the following three components:

    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        proposal_generator: nn.Module,
        roi_heads: nn.Module,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        input_format: Optional[str] = None,
        vis_period: int = 0,
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            proposal_generator: a module that generates proposals using backbone features
            roi_heads: a ROI head that performs per-region computation
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std to be used to normalize the input image
            input_format: describes the meaning of the input channels. Needed by visualization
            vis_period: the period (in training iterations) to run visualization.
                Set to 0 to disable.
        """
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.roi_heads = roi_heads

        self.input_format = input_format
        self.vis_period = vis_period
        if vis_period > 0:
            assert input_format is not None, "input_format is required for visualization!"

        # Buffers, not parameters: they follow the model's device but are not trained.
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "roi_heads": build_roi_heads(cfg, backbone.output_shape()),
            "input_format": cfg.INPUT.FORMAT,
            "vis_period": cfg.VIS_PERIOD,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }
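
    # @configurable lets this class be constructed either from a CfgNode
    # (GeneralizedRCNN(cfg), routed through from_config above) or with
    # explicit keyword arguments.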

    @property
    def device(self):
        return self.pixel_mean.device

    def _move_to_current_device(self, x):
        return move_device_like(x, self.pixel_mean)

    def visualize_training(self, batched_inputs, proposals):
        """
        A function used to visualize images and proposals. It shows ground truth
        bounding boxes on the original image and up to 20 top-scoring predicted
        object proposals on the original image. Users can implement different
        visualization functions for different models.

        Args:
            batched_inputs (list): a list that contains input to the model.
            proposals (list): a list that contains predicted proposals. Both
                batched_inputs and proposals should have the same length.
        """
        from detectron2.utils.visualizer import Visualizer

        storage = get_event_storage()
        max_vis_prop = 20

        for input, prop in zip(batched_inputs, proposals):
            img = input["image"]
            img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
            v_gt = Visualizer(img, None)
            v_gt = v_gt.overlay_instances(boxes=input["instances"].gt_boxes)
            anno_img = v_gt.get_image()
            box_size = min(len(prop.proposal_boxes), max_vis_prop)
            v_pred = Visualizer(img, None)
            v_pred = v_pred.overlay_instances(
                boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy()
            )
            prop_img = v_pred.get_image()
            vis_img = np.concatenate((anno_img, prop_img), axis=1)
            vis_img = vis_img.transpose(2, 0, 1)
            vis_name = "Left: GT bounding boxes; Right: Predicted proposals"
            storage.put_image(vis_name, vis_img)
            break  # only visualize one image in a batch

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following fields:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        features = self.backbone(images.tensor)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            # Precomputed proposals were provided by the dataset instead.
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
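
    # Training-loop contract (sketch): the loss dict returned by forward() is
    # typically consumed as
    #   loss_dict = model(batched_inputs)
    #   losses = sum(loss_dict.values())
    #   losses.backward()
    # which is what detectron2's default trainers do.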

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        return results

    def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Normalize, pad and batch the input images.
        """
        images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(
            images,
            self.backbone.size_divisibility,
            padding_constraints=self.backbone.padding_constraints,
        )
        return images
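
    # Padding illustration: with size_divisibility=32, two inputs of shape
    # (3, 480, 640) and (3, 500, 600) are padded into a single (2, 3, 512, 640)
    # batch tensor, while images.image_sizes keeps each original (H, W).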

    @staticmethod
    def _postprocess(instances, batched_inputs: List[Dict[str, torch.Tensor]], image_sizes):
        """
        Rescale the output instances to the target size.
        """
        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            instances, batched_inputs, image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results
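

def _example_generalized_rcnn_inference():
    """
    Usage sketch (illustrative, not part of detectron2's public API): build a
    GeneralizedRCNN from a model-zoo config and run inference on a dummy image.
    The config file name is one real example; any R-CNN config works the same way.
    """
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.modeling import build_model

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.DEVICE = "cpu"  # keep the sketch runnable without a GPU
    model = build_model(cfg)  # META_ARCHITECTURE is "GeneralizedRCNN" for this config
    model.eval()
    with torch.no_grad():
        image = torch.zeros(3, 480, 640)  # dummy image in cfg.INPUT.FORMAT (BGR by default)
        outputs = model([{"image": image, "height": 480, "width": 640}])
    # outputs[0]["instances"] holds pred_boxes, scores, pred_classes, ...
    return outputs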


@META_ARCH_REGISTRY.register()
class ProposalNetwork(nn.Module):
    """
    A meta architecture that only predicts object proposals.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        proposal_generator: nn.Module,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            proposal_generator: a module that generates proposals using backbone features
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std to be used to normalize the input image
        """
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        return self.pixel_mean.device

    def _move_to_current_device(self, x):
        return move_device_like(x, self.pixel_mean)

    def forward(self, batched_inputs):
        """
        Args:
            Same as in :class:`GeneralizedRCNN.forward`

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "proposals" whose value is a
                :class:`Instances` with keys "proposal_boxes" and "objectness_logits".
        """
        images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(
            images,
            self.backbone.size_divisibility,
            padding_constraints=self.backbone.padding_constraints,
        )
        features = self.backbone(images.tensor)

        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        elif "targets" in batched_inputs[0]:
            log_first_n(
                logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
            )
            gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None
        proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)

        # In training, return only the losses; proposals are post-processed in inference.
        if self.training:
            return proposal_losses

        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            proposals, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"proposals": r})
        return processed_results
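

def _example_proposal_network():
    """
    Usage sketch (illustrative, not part of detectron2's public API): the same
    R-CNN config can be repointed at this meta architecture by overriding
    cfg.MODEL.META_ARCHITECTURE, yielding a model that outputs only proposals.
    """
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.modeling import build_model

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.META_ARCHITECTURE = "ProposalNetwork"
    cfg.MODEL.DEVICE = "cpu"  # keep the sketch runnable without a GPU
    model = build_model(cfg)
    model.eval()
    with torch.no_grad():
        image = torch.zeros(3, 480, 640)
        outputs = model([{"image": image, "height": 480, "width": 640}])
    # outputs[0]["proposals"] holds proposal_boxes and objectness_logits
    return outputs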