import math
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detectron2.utils.memory import retry_if_cuda_oom

from mask2former.modeling.criterion import SetCriterion
from mask2former.modeling.matcher import HungarianMatcher

from .modeling.avism_criterion import AvismSetCriterion
from .modeling.avism_matcher import AvismHungarianMatcher
from .modeling.transformer_decoder.avism_coco import Avism_COCO


@META_ARCH_REGISTRY.register()
class AVISM_COCO(nn.Module):
    """
    Main class for the AVISM mask-classification architecture for audio-visual
    instance segmentation (COCO variant).
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        num_queries: int,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        size_divisibility: int,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        # inference
        test_topk_per_image: int,
        # AVISM
        avism_module: nn.Module,
        avism_criterion: nn.Module,
        num_frames: int,
        num_classes: int,
        is_multi_cls: bool,
        apply_cls_thres: float,
        freeze_detector: bool,
        test_run_chunk_size: int,
        test_interpolate_chunk_size: int,
        is_coco: bool,
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            num_queries: int, number of queries
            object_mask_threshold: float, threshold to filter queries based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible
                by a specific integer. We can use this to override that requirement.
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std to be used to normalize the input image
            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        # non-persistent buffers: moved with the module but not saved in checkpoints
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

        # additional args
        self.test_topk_per_image = test_topk_per_image

        # AVISM clip-level module and criterion; note that the attributes keep the
        # "vita_module"/"vita_criterion" names, which the `freeze_detector` check below relies on
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.vita_module = avism_module
        self.vita_criterion = avism_criterion
        self.is_multi_cls = is_multi_cls
        self.apply_cls_thres = apply_cls_thres

        if freeze_detector:
            for name, p in self.named_parameters():
                if "vita_module" not in name:
                    p.requires_grad_(False)
        self.test_run_chunk_size = test_run_chunk_size
        self.test_interpolate_chunk_size = test_interpolate_chunk_size

        self.is_coco = is_coco

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
        avism_deep_supervision = cfg.MODEL.AVISM.DEEP_SUPERVISION

        class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
        sim_weight = cfg.MODEL.AVISM.SIM_WEIGHT

        matcher = HungarianMatcher(
            cost_class=class_weight,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
        )
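
        # With deep supervision, each intermediate decoder layer contributes the same set of
        # losses, so every loss weight is duplicated with a per-layer suffix
        # (e.g. "loss_ce_0", "loss_mask_0", ...) in the weight dict built below.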
        weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}

        if deep_supervision:
            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)

        losses = ["labels", "masks"]

        criterion = SetCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
            oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
            importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
            avism_last_layer_num=cfg.MODEL.AVISM.LAST_LAYER_NUM,
        )

        num_classes = sem_seg_head.num_classes
        hidden_dim = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        avism_module = Avism_COCO(cfg=cfg, in_channels=hidden_dim, aux_loss=avism_deep_supervision)

        avism_matcher = AvismHungarianMatcher(
            cost_class=class_weight,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
        )
        avism_weight_dict = {
            "loss_avism_ce": class_weight, "loss_avism_mask": mask_weight, "loss_avism_dice": dice_weight
        }
        if sim_weight > 0.0:
            avism_weight_dict["loss_avism_sim"] = sim_weight

        if avism_deep_supervision:
            avism_dec_layers = cfg.MODEL.AVISM.DEC_LAYERS
            aux_weight_dict = {}
            for i in range(avism_dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in avism_weight_dict.items()})
            avism_weight_dict.update(aux_weight_dict)
        avism_losses = ["avism_labels", "avism_masks"]
        if sim_weight > 0.0:
            avism_losses.append("fg_sim")

        avism_criterion = AvismSetCriterion(
            num_classes,
            matcher=avism_matcher,
            weight_dict=avism_weight_dict,
            eos_coef=cfg.MODEL.AVISM.NO_OBJECT_WEIGHT,
            losses=avism_losses,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
            oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
            importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
            sim_use_clip=cfg.MODEL.AVISM.SIM_USE_CLIP,
        )

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "num_queries": cfg.MODEL.AVISM.NUM_OBJECT_QUERIES,
            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
            "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            # inference
            "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
            # AVISM
            "avism_module": avism_module,
            "avism_criterion": avism_criterion,
            "num_frames": cfg.INPUT.SAMPLING_FRAME_NUM,
            "num_classes": num_classes,
            "is_multi_cls": cfg.MODEL.AVISM.MULTI_CLS_ON,
            "apply_cls_thres": cfg.MODEL.AVISM.APPLY_CLS_THRES,
            "freeze_detector": cfg.MODEL.AVISM.FREEZE_DETECTOR,
            "test_run_chunk_size": cfg.MODEL.AVISM.TEST_RUN_CHUNK_SIZE,
            "test_interpolate_chunk_size": cfg.MODEL.AVISM.TEST_INTERPOLATE_CHUNK_SIZE,
            "is_coco": cfg.DATASETS.TEST[0].startswith("coco"),
        }
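
    # This class is registered in META_ARCH_REGISTRY, so it is typically instantiated through
    # detectron2's `build_model(cfg)` with cfg.MODEL.META_ARCHITECTURE = "AVISM_COCO";
    # `from_config` above translates the config node into the constructor arguments.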

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one video.
                For now, each item in the list is a dict that contains:

                * "image": a list of Tensors, one per frame, each in (C, H, W) format.
                * "audio": per-frame audio features, aligned with "image".
                * "instances": per-frame ground truth instances (training only).
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model (may be different
                  from input resolution), used in inference.

        Returns:
            dict:
                In training, a dict of scalar losses (frame-level and AVISM clip-level terms).
                In inference, a dict describing the predictions for the whole video:

                * "image_size": the (height, width) of the output masks.
                * "pred_scores": list of per-instance confidence scores.
                * "pred_labels": list of per-instance class labels.
                * "pred_masks": bool Tensor of per-frame instance masks.
        """
        if self.training:
            return self.train_model(batched_inputs)
        else:
            # NOTE: inference processes one video at a time
            return self.inference(batched_inputs[0])

    def train_model(self, batched_inputs):
        images = []
        audio_features = []
        for video in batched_inputs:
            for frame in video["image"]:
                images.append(frame.to(self.device))
            for audio_feat in video["audio"]:
                audio_features.append(torch.tensor(audio_feat).to(self.device))

        audio_features = torch.stack(audio_features)
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        image_features = self.backbone(images.tensor)

        BT = len(images)
        T = self.num_frames if self.training else BT
        B = BT // T

        outputs, frame_queries, mask_features = self.sem_seg_head(image_features, audio_features)

        mask_features = self.vita_module.vita_mask_features(mask_features)
        mask_features = mask_features.view(B, self.num_frames, *mask_features.shape[-3:])

        frame_targets, clip_targets = self.prepare_targets(batched_inputs, images)

        losses, fg_indices = self.criterion(outputs, frame_targets)
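
        # pred_mask_embed has shape (L, B, Q, C) per AVISM decoder layer and mask_features has
        # shape (B, T, C, H, W), so the einsum below yields clip-level masks of shape
        # (L, B, Q, T, H, W).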
        avism_outputs = self.vita_module(frame_queries, audio_features)
        avism_outputs["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", avism_outputs["pred_mask_embed"], mask_features)
        for out in avism_outputs["aux_outputs"]:
            out["pred_masks"] = torch.einsum("lbqc,btchw->lbqthw", out["pred_mask_embed"], mask_features)

        for k in list(losses.keys()):
            if k in self.criterion.weight_dict:
                losses[k] *= self.criterion.weight_dict[k]
            else:
                # remove this loss if not specified in `weight_dict`
                losses.pop(k)
        avism_loss_dict = self.vita_criterion(avism_outputs, clip_targets, frame_targets, fg_indices)
        avism_weight_dict = self.vita_criterion.weight_dict

        for k in avism_loss_dict.keys():
            if k in avism_weight_dict:
                avism_loss_dict[k] *= avism_weight_dict[k]
        losses.update(avism_loss_dict)
        return losses
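
    # prepare_targets() builds two sets of ground truth from the per-frame annotations:
    #   - frame-level targets (one dict per frame) consumed by the frame-level criterion, and
    #   - clip-level targets (one dict per video, with per-frame instance ids) consumed by the
    #     AVISM criterion; instances whose mask is empty in a frame get id -1 for that frame.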
    def prepare_targets(self, targets, images):
        h_pad, w_pad = images.tensor.shape[-2:]
        frame_gt_instances = []
        clip_gt_instances = []
        for targets_per_video in targets:
            _num_instance = len(targets_per_video["instances"][0])
            mask_shape = [_num_instance, self.num_frames, h_pad, w_pad]
            gt_masks_per_video = torch.zeros(mask_shape, dtype=torch.bool, device=self.device)

            gt_classes_per_video = targets_per_video["instances"][0].gt_classes.to(self.device)
            gt_ids_per_video = []
            for f_i, targets_per_frame in enumerate(targets_per_video["instances"]):
                targets_per_frame = targets_per_frame.to(self.device)
                h, w = targets_per_frame.image_size

                _update_cls = gt_classes_per_video == -1
                gt_classes_per_video[_update_cls] = targets_per_frame.gt_classes[_update_cls]
                gt_ids_per_video.append(targets_per_frame.gt_ids)
                if isinstance(targets_per_frame.gt_masks, BitMasks):
                    gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks.tensor
                else:
                    gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks

            gt_ids_per_video = torch.stack(gt_ids_per_video, dim=1)
            gt_ids_per_video[gt_masks_per_video.sum(dim=(2, 3)) == 0] = -1
            valid_bool_frame = (gt_ids_per_video != -1)
            valid_bool_clip = valid_bool_frame.any(dim=-1)

            gt_classes_per_video = gt_classes_per_video[valid_bool_clip].long()
            gt_ids_per_video = gt_ids_per_video[valid_bool_clip].long()
            gt_masks_per_video = gt_masks_per_video[valid_bool_clip].float()
            valid_bool_frame = valid_bool_frame[valid_bool_clip]
            if len(gt_ids_per_video) > 0:
                min_id = max(gt_ids_per_video[valid_bool_frame].min(), 0)
                gt_ids_per_video[valid_bool_frame] -= min_id

            clip_gt_instances.append(
                {
                    "labels": gt_classes_per_video, "ids": gt_ids_per_video, "masks": gt_masks_per_video,
                    "video_len": targets_per_video["length"], "frame_idx": targets_per_video["frame_idx"],
                }
            )

            for f_i in range(self.num_frames):
                _cls = gt_classes_per_video.clone()
                _ids = gt_ids_per_video[:, f_i].clone()
                _mask = gt_masks_per_video[:, f_i].clone()

                valid = _ids != -1
                frame_gt_instances.append({
                    "labels": _cls[valid],
                    "ids": _ids[valid],
                    "masks": _mask[valid],
                })

        return frame_gt_instances, clip_gt_instances

    def inference(self, batched_inputs):
        frame_queries, mask_features = [], []
        num_frames = len(batched_inputs["image"])
        to_store = self.device if num_frames <= 36 else "cpu"

        audio_features = torch.tensor(batched_inputs["audio"]).to(self.device)

        with torch.no_grad():
            for i in range(math.ceil(num_frames / self.test_run_chunk_size)):
                images = batched_inputs["image"][i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]
                images = [(x.to(self.device) - self.pixel_mean) / self.pixel_std for x in images]
                images = ImageList.from_tensors(images, self.size_divisibility)

                audio_features_chunk = audio_features[i*self.test_run_chunk_size : (i+1)*self.test_run_chunk_size]

                features = self.backbone(images.tensor)
                outputs, _frame_queries, _mask_features = self.sem_seg_head(features, audio_features_chunk)

                _mask_features = self.vita_module.vita_mask_features(_mask_features)

                frame_queries.append(_frame_queries[-1])
                mask_features.append(_mask_features.to(to_store))

            interim_size = images.tensor.shape[-2:]                    # padded input size
            image_size = images.image_sizes[0]                         # input size without padding

            out_height = batched_inputs.get("height", image_size[0])   # desired output resolution
            out_width = batched_inputs.get("width", image_size[1])

            del outputs, images, batched_inputs
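
            # Concatenate the per-chunk frame queries into a single clip-level tensor and let
            # the AVISM module aggregate them together with the audio features.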
            frame_queries = torch.cat(frame_queries)[None]
            mask_features = torch.cat(mask_features)

            avism_outputs = self.vita_module(frame_queries, audio_features)

            mask_cls = avism_outputs["pred_logits"][-1, 0]
            mask_embed = avism_outputs["pred_mask_embed"][-1, 0]

            del avism_outputs

            scores = F.softmax(mask_cls, dim=-1)[:, :-1]
            labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)

            num_topk = self.test_topk_per_image
            scores_per_video, topk_indices = scores.flatten(0, 1).topk(num_topk, sorted=False)
            labels_per_video = labels[topk_indices]

            topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='floor')
            mask_embed = mask_embed[topk_indices]
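
            # Render the final masks chunk-by-chunk to bound GPU memory. Each instance score is
            # rescaled below by its mean foreground probability inside the binarized mask
            # (numerator / denominator), a mask-quality term.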
            masks_per_video = []
            numerator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
            denominator = torch.zeros(len(mask_embed), dtype=torch.float, device=self.device)
            for i in range(math.ceil(len(mask_features) / self.test_interpolate_chunk_size)):
                m_f = mask_features[i*self.test_interpolate_chunk_size : (i+1)*self.test_interpolate_chunk_size].to(self.device)

                mask_pred = torch.einsum("qc,tchw->qthw", mask_embed, m_f)

                # upsample to the padded input resolution
                mask_pred = retry_if_cuda_oom(F.interpolate)(
                    mask_pred,
                    size=interim_size,
                    mode="bilinear",
                    align_corners=False,
                )

                # crop away the padding
                mask_pred = mask_pred[:, :, : image_size[0], : image_size[1]]

                interim_mask_soft = mask_pred.sigmoid()
                interim_mask_hard = interim_mask_soft > 0.5

                numerator += (interim_mask_soft.flatten(1) * interim_mask_hard.flatten(1)).sum(1)
                denominator += interim_mask_hard.flatten(1).sum(1)

                # resize to the requested output resolution and binarize
                mask_pred = F.interpolate(
                    mask_pred, size=(out_height, out_width), mode="bilinear", align_corners=False
                ) > 0.
                masks_per_video.append(mask_pred.to(to_store))
            masks_per_video = torch.cat(masks_per_video, dim=1)
            scores_per_video *= (numerator / (denominator + 1e-6))
            # keep only predictions above a fixed confidence threshold
            confidence = 0.3
            indices = torch.nonzero(scores_per_video > confidence).squeeze(-1)
            scores_per_video = scores_per_video[indices]
            labels_per_video = labels_per_video[indices]
            masks_per_video = masks_per_video[indices]

            processed_results = {
                "image_size": (out_height, out_width),
                "pred_scores": scores_per_video.tolist(),
                "pred_labels": labels_per_video.tolist(),
                "pred_masks": masks_per_video.cpu(),
            }

        return processed_results