# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d
from mmcv.ops import point_sample
from mmengine.model import ModuleList
from mmengine.model.weight_init import caffe2_xavier_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.dense_heads import AnchorFreeHead, MaskFormerHead
from mmdet.models.utils import get_uncertain_point_coords_with_randomness
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import TrackDataSample, TrackSampleList
from mmdet.structures.mask import mask2bbox
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptMultiConfig, reduce_mean)

from ..layers import Mask2FormerTransformerDecoder


@MODELS.register_module()
class Mask2FormerTrackHead(MaskFormerHead):
    """Implements the Mask2Former head for video instance segmentation.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature
            maps.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_classes (int): Number of VIS classes.
        num_frames (int): Number of frames sampled per video during
            training. Defaults to 2.
        num_queries (int): Number of queries in the Transformer decoder.
            Defaults to 100.
        num_transformer_feat_level (int): Number of feature levels fed to
            the Transformer decoder. Defaults to 3.
        pixel_decoder (:obj:`ConfigDict` or dict): Config for the pixel
            decoder.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in
            the pixel decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`ConfigDict` or dict): Config for the
            transformer decoder.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer decoder positional encoding.
            Defaults to `SinePositionalEncoding3D`.
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss. Defaults to `CrossEntropyLoss`.
        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
            Defaults to `CrossEntropyLoss`.
        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
            Defaults to `DiceLoss`.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the Mask2Former head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the Mask2Former head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 feat_channels: int,
                 out_channels: int,
                 num_classes: int,
                 num_frames: int = 2,
                 num_queries: int = 100,
                 num_transformer_feat_level: int = 3,
                 pixel_decoder: ConfigType = ...,
                 enforce_decoder_input_project: bool = False,
                 transformer_decoder: ConfigType = ...,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=2.0,
                     reduction='mean',
                     class_weight=[1.0] * 133 + [0.1]),
                 loss_mask: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=5.0),
                 loss_dice: ConfigType = dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     reduction='mean',
                     naive_dice=True,
                     eps=1.0,
                     loss_weight=5.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.num_frames = num_frames
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.layer_cfg. \
            self_attn_cfg.num_levels == num_transformer_feat_level
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder_)
        self.transformer_decoder = Mask2FormerTransformerDecoder(
            **transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if (self.decoder_embed_dims != feat_channels
                    or enforce_decoder_input_project):
                self.decoder_input_projs.append(
                    Conv2d(
                        feat_channels, self.decoder_embed_dims,
                        kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = MODELS.build(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        feat_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)
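            # An illustrative train_cfg sketch (hypothetical values, not
            # taken from a released config) showing the keys consumed above:
            # train_cfg = dict(
            #     assigner=dict(type='HungarianAssigner', match_costs=[...]),
            #     sampler=dict(type='MaskPseudoSampler'),
            #     num_points=12544,
            #     oversample_ratio=3.0,
            #     importance_sample_ratio=0.75)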

        self.class_weight = loss_cls.class_weight
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def init_weights(self) -> None:
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def preprocess_gt(self, batch_gt_instances: InstanceList) -> InstanceList:
        """Preprocess the ground truth for all images.

        It reorganizes the ground truth. For example, the masks in
        `batch_data_sample.gt_instances.mask` have shape
        `(all_num_gts, h, w)`, without recording which image each ground
        truth belongs to (assume `num_frames` is 2). This function therefore
        reshapes `gt_mask` to `(num_gts_per_img, num_frames, h, w)`. Since
        the frames are not guaranteed to contain the same instances, `-1`
        marks an instance that is absent from a frame.

        Args:
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``labels``, each is
                ground truth labels of each bbox, with shape (num_gts, )
                and ``masks``, each is ground truth masks of each instances
                of an image, shape (num_gts, h, w).

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Ground truth class indices\
                    for an image, with shape (n, ), n is the sum of\
                    number of stuff type and number of instance in an image.
                - masks (Tensor): Ground truth mask for a\
                    image, with shape (n, t, h, w).
        """
        final_batch_gt_instances = []
        batch_size = len(batch_gt_instances) // self.num_frames
        for batch_idx in range(batch_size):
            pair_gt_instances = batch_gt_instances[batch_idx *
                                                   self.num_frames:batch_idx *
                                                   self.num_frames +
                                                   self.num_frames]

            assert len(
                pair_gt_instances
            ) > 1, 'mask2former for vis needs multiple frames to train, ' \
                   f'but you only use {len(pair_gt_instances)} frames'

            _device = pair_gt_instances[0].labels.device

            for gt_instances in pair_gt_instances:
                gt_instances.masks = gt_instances.masks.to_tensor(
                    dtype=torch.bool, device=_device)
            all_ins_id = torch.cat([
                gt_instances.instances_ids
                for gt_instances in pair_gt_instances
            ])
            all_ins_id = all_ins_id.unique().tolist()
            map_ins_id = dict()
            for i, ins_id in enumerate(all_ins_id):
                map_ins_id[ins_id] = i

            num_instances = len(all_ins_id)
            mask_shape = [
                num_instances, self.num_frames,
                pair_gt_instances[0].masks.shape[1],
                pair_gt_instances[0].masks.shape[2]
            ]
            gt_masks_per_video = torch.zeros(
                mask_shape, dtype=torch.bool, device=_device)
            gt_ids_per_video = torch.full((num_instances, self.num_frames),
                                          -1,
                                          dtype=torch.long,
                                          device=_device)
            gt_labels_per_video = torch.full((num_instances, ),
                                             -1,
                                             dtype=torch.long,
                                             device=_device)

            for frame_id in range(self.num_frames):
                cur_frame_gts = pair_gt_instances[frame_id]
                ins_ids = cur_frame_gts.instances_ids.tolist()
                for i, ins_id in enumerate(ins_ids):
                    gt_masks_per_video[
                        map_ins_id[ins_id],
                        frame_id, :, :] = cur_frame_gts.masks[i]
                    gt_ids_per_video[
                        map_ins_id[ins_id],
                        frame_id] = cur_frame_gts.instances_ids[i]
                    gt_labels_per_video[
                        map_ins_id[ins_id]] = cur_frame_gts.labels[i]
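            # Example of the id merge above: with num_frames=2, frame-0 ids
            # [5, 9] and frame-1 ids [9, 12] give all_ins_id [5, 9, 12];
            # instance 5 is absent from frame 1, so its gt_ids_per_video
            # entry for that frame stays -1 and its mask row stays zeros.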
            tmp_instances = InstanceData(
                labels=gt_labels_per_video,
                masks=gt_masks_per_video.long(),
                instances_id=gt_ids_per_video)
            final_batch_gt_instances.append(tmp_instances)

        return final_batch_gt_instances

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, num_frames, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, num_frames, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each \
                    image.
                - neg_inds (Tensor): Sampled negative indices for each \
                    image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        # (num_gts, )
        gt_labels = gt_instances.labels
        # (num_gts, num_frames, h, w)
        gt_masks = gt_instances.masks

        # sample the same set of random points for predictions and gt
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_frames * num_points)
        mask_points_pred = point_sample(mask_pred,
                                        point_coords.repeat(num_queries, 1,
                                                            1)).flatten(1)
        # shape (num_gts, num_frames * num_points)
        gt_points_masks = point_sample(gt_masks.float(),
                                       point_coords.repeat(num_gts, 1,
                                                           1)).flatten(1)
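        # Sharing one random point set between predictions and ground truth
        # means the assignment costs below compare logits and targets at
        # aligned locations instead of on full-resolution masks.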
        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample
        assign_result = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances,
            img_meta=img_meta)
        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries, ))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)

    def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
                             batch_gt_instances: List[InstanceData],
                             batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, num_frames, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single \
                decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        (labels_list, label_weights_list, mask_targets_list,
         mask_weights_list,
         avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
                                        batch_gt_instances, batch_img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, num_frames, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, num_frames, h, w)
        # -> (num_total_gts, num_frames, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice
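        # Note on the zero-match branch above: summing the empty mask_preds
        # tensor yields a zero-valued loss that stays connected to the
        # autograd graph, so distributed training does not see the mask
        # branch as unused parameters.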

        with torch.no_grad():
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_preds.flatten(0, 1).unsqueeze(1), None, self.num_points,
                self.oversample_ratio, self.importance_sample_ratio)
            # shape (num_total_gts * num_frames, h, w) ->
            # (num_total_gts * num_frames, num_points)
            mask_point_targets = point_sample(
                mask_targets.flatten(0, 1).unsqueeze(1).float(),
                points_coords).squeeze(1)
        # shape (num_total_gts * num_frames, num_points)
        mask_point_preds = point_sample(
            mask_preds.flatten(0, 1).unsqueeze(1), points_coords).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_total_gts * num_frames, num_points) ->
        # (num_total_gts * num_frames * num_points, )
        mask_point_preds = mask_point_preds.reshape(-1)
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(
            mask_point_preds,
            mask_point_targets,
            avg_factor=num_total_masks * self.num_points / self.num_frames)

        return loss_cls, loss_mask, loss_dice

    def _forward_head(
            self, decoder_out: Tensor, mask_feature: Tensor,
            attn_mask_target_size: Tuple[int, int]
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward for head part which is called after every decoder layer.

        Args:
            decoder_out (Tensor): in shape (batch_size, num_queries, c).
            mask_feature (Tensor): in shape (batch_size, t, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple contain three elements.

                - cls_pred (Tensor): Classification scores in shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred (Tensor): Mask scores in shape \
                    (batch_size, num_queries, t, h, w).
                - attn_mask (Tensor): Attention mask in shape \
                    (batch_size * num_heads, num_queries, t * h * w).
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        cls_pred = self.cls_embed(decoder_out)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, t, h, w)
        mask_pred = torch.einsum('bqc,btchw->bqthw', mask_embed, mask_feature)
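        # In the einsum: b=batch, q=queries, c=embed channels, t=frames,
        # h/w=spatial, i.e. each query's per-frame mask is the dot product of
        # its mask embedding with every pixel of that frame's mask features.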
        b, q, t, _, _ = mask_pred.shape

        attn_mask = F.interpolate(
            mask_pred.flatten(0, 1),
            attn_mask_target_size,
            mode='bilinear',
            align_corners=False).view(b, q, t, attn_mask_target_size[0],
                                      attn_mask_target_size[1])
        # shape (batch_size, num_queries, t, h, w) ->
        # (batch_size, num_queries, t*h*w) ->
        # (batch_size, num_head, num_queries, t*h*w) ->
        # (batch_size*num_head, num_queries, t*h*w)
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
            (1, self.num_heads, 1, 1)).flatten(0, 1)
        attn_mask = attn_mask.sigmoid() < 0.5
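        # True entries are masked out: each query may only attend to
        # locations where its current mask prediction is confident
        # (sigmoid >= 0.5), which is Mask2Former's "masked attention".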
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask

    def forward(
        self, x: List[Tensor], data_samples: TrackDataSample
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Forward function.

        Args:
            x (list[Tensor]): Multi scale Features from the
                upstream network, each is a 4D-tensor.
            data_samples (List[:obj:`TrackDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`.

        Returns:
            tuple[list[Tensor]]: A tuple contains two elements.

                - cls_pred_list (list[Tensor]): Classification logits \
                    for each decoder layer. Each is a 3D-tensor with shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred_list (list[Tensor]): Mask logits for each \
                    decoder layer. Each with shape (batch_size, num_queries, \
                    num_frames, h, w).
        """
        mask_features, multi_scale_memorys = self.pixel_decoder(x)
        bt, c_m, h_m, w_m = mask_features.shape
        batch_size = bt // self.num_frames if self.training else 1
        t = bt // batch_size
        mask_features = mask_features.view(batch_size, t, c_m, h_m, w_m)
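        # At test time batch_size is fixed to 1, so every frame of the video
        # is folded into the temporal dimension t.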
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            decoder_input = decoder_input.flatten(2)
            level_embed = self.level_embed.weight[i][None, :, None]
            decoder_input = decoder_input + level_embed
            _, c, hw = decoder_input.shape
            # shape (batch_size*t, c, h, w) ->
            # (batch_size, t, c, hw) ->
            # (batch_size, t*h*w, c)
            decoder_input = decoder_input.view(batch_size, t, c,
                                               hw).permute(0, 1, 3,
                                                           2).flatten(1, 2)
            # padding mask: shape (batch_size, t, h, w), all False (no pad)
            mask = decoder_input.new_zeros(
                (batch_size, t) + multi_scale_memorys[i].shape[-2:],
                dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(
                mask)
            decoder_positional_encoding = decoder_positional_encoding.flatten(
                3).permute(0, 1, 3, 2).flatten(1, 2)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)

        # shape (num_queries, c) -> (batch_size, num_queries, c)
        query_feat = self.query_feat.weight.unsqueeze(0).repeat(
            (batch_size, 1, 1))
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(
            (batch_size, 1, 1))

        cls_pred_list = []
        mask_pred_list = []
        cls_pred, mask_pred, attn_mask = self._forward_head(
            query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

        for i in range(self.num_transformer_decoder_layers):
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True (all background), then set it all False
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                cross_attn_mask=attn_mask,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None)
            cls_pred, mask_pred, attn_mask = self._forward_head(
                query_feat, mask_features, multi_scale_memorys[
                    (i + 1) % self.num_transformer_feat_level].shape[-2:])

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)
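        # Predictions are collected before the first layer and after every
        # decoder layer, giving num_transformer_decoder_layers + 1 outputs so
        # that each layer can receive auxiliary supervision.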
        return cls_pred_list, mask_pred_list

    def loss(
        self,
        x: Tuple[Tensor],
        data_samples: TrackSampleList,
    ) -> Dict[str, Tensor]:
        """Perform forward propagation and loss calculation of the track head
        on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            data_samples (List[:obj:`TrackDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        batch_img_metas = []
        batch_gt_instances = []

        for data_sample in data_samples:
            video_img_metas = defaultdict(list)
            for image_idx in range(len(data_sample)):
                batch_gt_instances.append(data_sample[image_idx].gt_instances)
                for key, value in data_sample[image_idx].metainfo.items():
                    video_img_metas[key].append(value)
            batch_img_metas.append(video_img_metas)

        # forward
        all_cls_scores, all_mask_preds = self(x, data_samples)

        # preprocess ground truth
        batch_gt_instances = self.preprocess_gt(batch_gt_instances)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self,
                x: Tuple[Tensor],
                data_samples: TrackDataSample,
                rescale: bool = True) -> InstanceList:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            data_samples (List[:obj:`TrackDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`.
            rescale (bool, Optional): If False, then returned bboxes and masks
                will fit the scale of img, otherwise, returned bboxes and
                masks will fit the scale of original image shape.
                Defaults to True.

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Prediction class indices\
                    for an image, with shape (n, ), n is the sum of\
                    number of stuff type and number of instance in an image.
                - masks (Tensor): Prediction mask for a\
                    image, with shape (n, t, h, w).
        """
        batch_img_metas = [
            data_samples[img_idx].metainfo
            for img_idx in range(len(data_samples))
        ]
        all_cls_scores, all_mask_preds = self(x, data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        mask_cls_results = mask_cls_results[0]
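        # batch_size is 1 at test time (one whole video per sample), so
        # index 0 selects the only video in the batch.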
        # upsample masks
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results[0],
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        results = self.predict_by_feat(mask_cls_results, mask_pred_results,
                                       batch_img_metas)
        return results

    def predict_by_feat(self,
                        mask_cls_results: Tensor,
                        mask_pred_results: Tensor,
                        batch_img_metas: List[dict],
                        rescale: bool = True) -> InstanceList:
        """Get top-10 predictions.

        Args:
            mask_cls_results (Tensor): Mask classification logits,
                shape (num_queries, cls_out_channels).
                Note `cls_out_channels` should include background.
            mask_pred_results (Tensor): Mask logits, shape
                (num_queries, num_frames, h, w).
            batch_img_metas (list[dict]): List of image meta information.
            rescale (bool, Optional): If False, then returned bboxes and masks
                will fit the scale of img, otherwise, returned bboxes and
                masks will fit the scale of original image shape.
                Defaults to True.

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Prediction class indices\
                    for an image, with shape (n, ), n is the sum of\
                    number of stuff type and number of instance in an image.
                - masks (Tensor): Prediction mask for a\
                    image, with shape (n, t, h, w).
        """
        results = []
        if len(mask_cls_results) > 0:
            scores = F.softmax(mask_cls_results, dim=-1)[:, :-1]
            labels = torch.arange(self.num_classes).unsqueeze(0).repeat(
                self.num_queries, 1).flatten(0, 1).to(scores.device)
            # keep top-10 predictions
            scores_per_image, topk_indices = scores.flatten(0, 1).topk(
                10, sorted=False)
            labels_per_image = labels[topk_indices]
            topk_indices = topk_indices // self.num_classes
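            # topk runs over the flattened (query, class) score matrix:
            # labels[topk_indices] reads off the class, then integer division
            # by num_classes recovers the originating query index.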
            mask_pred_results = mask_pred_results[topk_indices]

            img_shape = batch_img_metas[0]['img_shape']
            mask_pred_results = \
                mask_pred_results[:, :, :img_shape[0], :img_shape[1]]
            if rescale:
                # return result in original resolution
                ori_height, ori_width = batch_img_metas[0]['ori_shape'][:2]
                mask_pred_results = F.interpolate(
                    mask_pred_results,
                    size=(ori_height, ori_width),
                    mode='bilinear',
                    align_corners=False)

            masks = mask_pred_results > 0.

            # format top-10 predictions
            for img_idx in range(len(batch_img_metas)):
                pred_track_instances = InstanceData()
                pred_track_instances.masks = masks[:, img_idx]
                pred_track_instances.bboxes = mask2bbox(masks[:, img_idx])
                pred_track_instances.labels = labels_per_image
                pred_track_instances.scores = scores_per_image
                pred_track_instances.instances_id = torch.arange(10)
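                # The query index doubles as the track id: the same query
                # predicts the mask in every frame of the clip, so identity
                # is preserved across time without an explicit association
                # step.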
                results.append(pred_track_instances)

        return results