import logging
import fvcore.nn.weight_init as weight_init
from typing import Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from math import ceil

from detectron2.config import configurable
from detectron2.layers import Conv2d

from .position_encoding import PositionEmbeddingSine
from mask2former.modeling.transformer_decoder.maskformer_transformer_decoder import TRANSFORMER_DECODER_REGISTRY


class SelfAttentionLayer(nn.Module):
    """Pre-/post-norm self-attention block with a residual connection."""

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)


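# Minimal shape sketch (illustrative values only, not part of the model):
# nn.MultiheadAttention is used without batch_first, so inputs are sequence-first.
#
#   layer = SelfAttentionLayer(d_model=256, nhead=8)
#   tgt = torch.randn(100, 2, 256)   # (num_queries, batch, d_model)
#   out = layer(tgt)                 # same shape: (100, 2, 256)

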
class CrossAttentionLayer(nn.Module):
    """Pre-/post-norm cross-attention block with a residual connection."""

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


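# CrossAttentionLayer is used twice below: by the decoder layers (queries attend to
# multi-scale pixel features) and by self.av_sf for audio-visual fusion. In the
# fusion case tgt is a single per-frame audio token of shape (1, batch, d_model)
# while query_pos is the (num_queries, batch, d_model) query embedding, so the
# residual broadcasts the fused audio feature across all queries.

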
class FFNLayer(nn.Module):
    """Pre-/post-norm feed-forward block with a residual connection."""

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


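# Minimal shape sketch (illustrative sizes only): the mask head built below is
# MLP(hidden_dim, hidden_dim, mask_dim, 3), i.e. two ReLU-activated hidden layers
# followed by a linear projection.
#
#   mlp = MLP(256, 256, 256, 3)
#   y = mlp(torch.randn(2, 100, 256))   # (batch, num_queries, mask_dim) -> (2, 100, 256)

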
class PositionalEncoding(nn.Module):
    """
    compute sinusoid encoding.
    """

    def __init__(self, d_model, max_len, device):
        """
        constructor of sinusoid encoding class

        :param d_model: dimension of model
        :param max_len: max sequence length
        :param device: hardware device setting
        """
        super(PositionalEncoding, self).__init__()

        # fixed lookup table of shape (max_len, d_model); no gradient is needed
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False

        # positions 0..max_len-1 as a column vector
        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)

        # even indices of d_model (2i)
        _2i = torch.arange(0, d_model, step=2, device=device).float()

        # sin on even dimensions, cos on odd dimensions
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))

    def forward(self, x):
        # x: (batch_size, seq_len) token indices; only the length is used
        batch_size, seq_len = x.size()

        # (seq_len, d_model) encoding, shared across the batch
        return self.encoding[:seq_len, :]


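# The table built above is the standard Transformer sinusoidal encoding:
#   PE[pos, 2i]   = sin(pos / 10000**(2i / d_model))
#   PE[pos, 2i+1] = cos(pos / 10000**(2i / d_model))

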
@TRANSFORMER_DECODER_REGISTRY.register()
class AVISMMultiScaleMaskedTransformerDecoder(nn.Module):

    _version = 2

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # rename legacy keys; do not warn when training from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "static_query" in k:
                    newk = k.replace("static_query", "query_feat")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} has changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        in_channels,
        mask_classification=True,
        *,
        num_classes: int,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        enforce_input_project: bool,
        avism_last_layer_num: int,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim are identical
            avism_last_layer_num: number of final decoder layers whose frame
                queries are returned during training
        """
        super().__init__()

        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding for the multi-scale pixel features
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # audio-visual fusion: project audio features to hidden_dim, fuse them with
        # each feature level via cross-attention, then merge the three levels
        self.av_pre_proj = nn.Linear(128, hidden_dim)
        self.av_sf = nn.ModuleList()
        for _ in range(3):
            self.av_sf.append(
                CrossAttentionLayer(d_model=hidden_dim, nhead=nheads, dropout=0.0, normalize_before=pre_norm))
        self.av_post_proj = nn.Linear(hidden_dim * 3, hidden_dim)

        # Transformer decoder layers
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_av_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query positional embedding
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (always 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                self.input_proj.append(nn.Sequential())

        # output FFNs
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

        self.avism_last_layer_num = avism_last_layer_num

    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        ret = {}
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES

        # Transformer parameters
        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD

        # the first prediction is made on the learnable query features before any
        # decoder layer runs, so DEC_LAYERS - 1 layers are built to keep the number
        # of auxiliary losses equal to DEC_LAYERS
        assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ

        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
        ret["avism_last_layer_num"] = cfg.MODEL.AVISM.LAST_LAYER_NUM

        return ret

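    # Config keys read above (values shown are illustrative only, not mandated here):
    #
    #   MODEL:
    #     MASK_FORMER:
    #       HIDDEN_DIM: 256
    #       NUM_OBJECT_QUERIES: 100
    #       NHEADS: 8
    #       DIM_FEEDFORWARD: 2048
    #       DEC_LAYERS: 10          # builds DEC_LAYERS - 1 = 9 decoder layers
    #       PRE_NORM: False
    #       ENFORCE_INPUT_PROJ: False
    #     SEM_SEG_HEAD:
    #       NUM_CLASSES: ...
    #       MASK_DIM: 256
    #     AVISM:
    #       LAST_LAYER_NUM: ...
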
    def forward(self, x, mask_features, clip_mask_features, audio_features, mask=None):
        # x is a list of multi-scale features
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        # disable mask, it does not affect performance
        del mask

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        frame_queries = []
        predictions_class = []
        predictions_mask = []

        # audio-visual fusion: attend the per-frame audio token to every feature
        # level, then merge the levels and add the result to the query features
        audio_feat = self.av_pre_proj(audio_features)
        audio_feat = audio_feat[None, :, :]
        av_feats = []
        for l in range(len(src)):
            av_feat = self.av_sf[l](audio_feat, src[l], query_pos=query_embed)
            av_feats.append(av_feat)
        audio_feats_ml = self.av_post_proj(torch.cat((av_feats[0], av_feats[1], av_feats[2]), dim=-1))
        output = output + audio_feats_ml

        # prediction heads on learnable query features
        outputs_class, outputs_mask, attn_mask, frame_query = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # if a query's mask is empty at this scale, let it attend everywhere
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,  # no masking on padded regions
                pos=pos[level_index], query_pos=query_embed
            )

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )

            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

            outputs_class, outputs_mask, attn_mask, frame_query = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
            frame_queries.append(frame_query)
            predictions_class.append(outputs_class)
            predictions_mask.append(outputs_mask)

        assert len(predictions_class) == self.num_layers + 1

        out = {
            'pred_logits': predictions_class[-1],
            'pred_masks': predictions_mask[-1],
            'aux_outputs': self._set_aux_loss(
                predictions_class if self.mask_classification else None, predictions_mask
            )
        }

        # keep the frame queries of the last `avism_last_layer_num` layers during
        # training, and of the final layer only at inference
        num_layer = self.avism_last_layer_num if self.training else 1
        frame_queries = torch.stack(frame_queries[-num_layer:])

        return out, frame_queries, clip_mask_features

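    # Shape summary for forward() (B = batch size, Q = num_queries, C = hidden_dim,
    # L = avism_last_layer_num during training, else 1):
    #   output / query_embed: (Q, B, C)
    #   audio_features: (B, 128) -> audio_feat (1, B, C), broadcast to (Q, B, C)
    #   pred_logits: (B, Q, num_classes + 1), pred_masks: (B, Q, H, W)
    #   frame_queries: (L, B, Q, C)
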
    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)
        outputs_class = self.class_embed(decoder_output)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        # resize the predicted masks to the resolution of the next feature level
        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)

        # [B, Q, H, W] -> [B, Q, HW] -> [B, h, Q, HW] -> [B*h, Q, HW];
        # True marks positions a query is NOT allowed to attend to
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        return outputs_class, outputs_mask, attn_mask, decoder_output

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionaries with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]