# VRIS_vip/models/segmentation.py
"""
Segmentation Part
Modified from DETR (https://github.com/facebookresearch/detr)
"""
from collections import defaultdict
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from PIL import Image
from einops import rearrange, repeat
try:
from panopticapi.utils import id2rgb, rgb2id
except ImportError:
pass
import fvcore.nn.weight_init as weight_init
from .position_encoding import PositionEmbeddingSine1D
BN_MOMENTUM = 0.1
def get_norm(norm, out_channels):  # only supports "GN" or "LN"
"""
Args:
        norm (str or callable): either "GN" or "LN";
            or a callable that takes a channel number and returns
            the normalization layer as an nn.Module.
Returns:
nn.Module or None: the normalization layer
"""
if norm is None:
return None
if isinstance(norm, str):
if len(norm) == 0:
return None
norm = {
"GN": lambda channels: nn.GroupNorm(8, channels),
"LN": lambda channels: nn.LayerNorm(channels)
}[norm]
return norm(out_channels)
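# Example (illustrative sketch, not executed at import): get_norm builds a fresh
# normalization module per call; an empty string or None means "no norm".
#
#   get_norm("GN", 256)   # -> nn.GroupNorm(8, 256)
#   get_norm("LN", 256)   # -> nn.LayerNorm(256)
#   get_norm("", 256)     # -> None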
class Conv2d(torch.nn.Conv2d):
"""
A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
"""
def __init__(self, *args, **kwargs):
"""
Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
Args:
norm (nn.Module, optional): a normalization layer
activation (callable(Tensor) -> Tensor): a callable activation function
        It assumes that the norm layer is applied before the activation.
"""
norm = kwargs.pop("norm", None)
activation = kwargs.pop("activation", None)
super().__init__(*args, **kwargs)
self.norm = norm
self.activation = activation
def forward(self, x):
# torchscript does not support SyncBatchNorm yet
# https://github.com/pytorch/pytorch/issues/40507
# and we skip these codes in torchscript since:
# 1. currently we only support torchscript in evaluation mode
# 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
# later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
if not torch.jit.is_scripting():
if x.numel() == 0 and self.training:
# https://github.com/pytorch/pytorch/issues/12013
assert not isinstance(
self.norm, torch.nn.SyncBatchNorm
), "SyncBatchNorm does not support empty inputs!"
x = F.conv2d(
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
if self.norm is not None:
x = self.norm(x)
if self.activation is not None:
x = self.activation(x)
return x
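# Example (illustrative sketch, not executed at import): composing get_norm with the
# Conv2d wrapper above; the channel sizes and input shape are made up for demonstration.
#
#   lateral = Conv2d(512, 256, kernel_size=1, bias=False, norm=get_norm("GN", 256))
#   output = Conv2d(256, 256, kernel_size=3, padding=1, bias=False,
#                   norm=get_norm("GN", 256), activation=F.relu)
#   feat = output(lateral(torch.randn(2, 512, 64, 64)))   # -> [2, 256, 64, 64]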
# FPN structure
class CrossModalFPNDecoder(nn.Module):
def __init__(self, feature_channels: List, conv_dim: int, mask_dim: int, dim_feedforward: int = 2048, norm=None):
"""
Args:
feature_channels: list of fpn feature channel numbers.
conv_dim: number of output channels for the intermediate conv layers.
mask_dim: number of output channels for the final conv layer.
            dim_feedforward: hidden dimension of the FFN inside the vision-language fusion module.
norm (str or callable): normalization for all conv layers
"""
super().__init__()
self.feature_channels = feature_channels
lateral_convs = []
output_convs = []
use_bias = norm == ""
for idx, in_channels in enumerate(feature_channels):
# in_channels: 4x -> 32x
lateral_norm = get_norm(norm, conv_dim)
output_norm = get_norm(norm, conv_dim)
lateral_conv = Conv2d(
in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
)
output_conv = Conv2d(
conv_dim,
conv_dim,
kernel_size=3,
stride=1,
padding=1,
bias=use_bias,
norm=output_norm,
activation=F.relu,
)
weight_init.c2_xavier_fill(lateral_conv)
weight_init.c2_xavier_fill(output_conv)
stage = idx+1
self.add_module("adapter_{}".format(stage), lateral_conv)
self.add_module("layer_{}".format(stage), output_conv)
lateral_convs.append(lateral_conv)
output_convs.append(output_conv)
# Place convs into top-down order (from low to high resolution)
# to make the top-down computation in forward clearer.
self.lateral_convs = lateral_convs[::-1]
self.output_convs = output_convs[::-1]
self.mask_dim = mask_dim
self.mask_features = Conv2d(
conv_dim,
mask_dim,
kernel_size=3,
stride=1,
padding=1,
)
weight_init.c2_xavier_fill(self.mask_features)
# vision-language cross-modal fusion
self.text_pos = PositionEmbeddingSine1D(conv_dim, normalize=True)
sr_ratios = [8, 4, 2, 1]
cross_attns = []
for idx in range(len(feature_channels)): # res2 -> res5
cross_attn = VisionLanguageBlock(conv_dim, dim_feedforward=dim_feedforward,
nhead=8, sr_ratio=sr_ratios[idx])
for p in cross_attn.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
stage = int(idx + 1)
self.add_module("cross_attn_{}".format(stage), cross_attn)
cross_attns.append(cross_attn)
# place cross-attn in top-down order (from low to high resolution)
self.cross_attns = cross_attns[::-1]
def forward_features(self, features, text_features, poses, memory, nf):
# nf: num_frames
text_pos = self.text_pos(text_features).permute(2, 0, 1) # [length, batch_size, c]
text_features, text_masks = text_features.decompose()
text_features = text_features.permute(1, 0, 2)
for idx, (mem, f, pos) in enumerate(zip(memory[::-1], features[1:][::-1], poses[1:][::-1])): # 32x -> 8x
lateral_conv = self.lateral_convs[idx]
output_conv = self.output_convs[idx]
cross_attn = self.cross_attns[idx]
_, x_mask = f.decompose()
n, c, h, w = pos.shape
b = n // nf
t = nf
# NOTE: here the (h, w) is the size for current fpn layer
vision_features = lateral_conv(mem) # [b*t, c, h, w]
vision_features = rearrange(vision_features, '(b t) c h w -> (t h w) b c', b=b, t=t)
vision_pos = rearrange(pos, '(b t) c h w -> (t h w) b c', b=b, t=t)
vision_masks = rearrange(x_mask, '(b t) h w -> b (t h w)', b=b, t=t)
cur_fpn = cross_attn(tgt=vision_features,
memory=text_features,
t=t, h=h, w=w,
tgt_key_padding_mask=vision_masks,
memory_key_padding_mask=text_masks,
pos=text_pos,
query_pos=vision_pos
) # [t*h*w, b, c]
cur_fpn = rearrange(cur_fpn, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
# upsample
if idx == 0: # top layer
y = output_conv(cur_fpn)
else:
# Following FPN implementation, we use nearest upsampling here
y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
y = output_conv(y)
# 4x level
lateral_conv = self.lateral_convs[-1]
output_conv = self.output_convs[-1]
cross_attn = self.cross_attns[-1]
x, x_mask = features[0].decompose()
pos = poses[0]
n, c, h, w = pos.shape
b = n // nf
t = nf
vision_features = lateral_conv(x) # [b*t, c, h, w]
vision_features = rearrange(vision_features, '(b t) c h w -> (t h w) b c', b=b, t=t)
vision_pos = rearrange(pos, '(b t) c h w -> (t h w) b c', b=b, t=t)
vision_masks = rearrange(x_mask, '(b t) h w -> b (t h w)', b=b, t=t)
cur_fpn = cross_attn(tgt=vision_features,
memory=text_features,
t=t, h=h, w=w,
tgt_key_padding_mask=vision_masks,
memory_key_padding_mask=text_masks,
pos=text_pos,
query_pos=vision_pos
) # [t*h*w, b, c]
cur_fpn = rearrange(cur_fpn, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
# Following FPN implementation, we use nearest upsampling here
y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
y = output_conv(y)
return y # [b*t, c, h, w], the spatial stride is 4x
def forward(self, features, text_features, pos, memory, nf):
"""The forward function receives the vision and language features,
and outputs the mask features with the spatial stride of 4x.
Args:
features (list[NestedTensor]): backbone features (vision), length is number of FPN layers
tensors: [b*t, ci, hi, wi], mask: [b*t, hi, wi]
text_features (NestedTensor): text features (language)
tensors: [b, length, c], mask: [b, length]
pos (list[Tensor]): position encoding of vision features, length is number of FPN layers
tensors: [b*t, c, hi, wi]
            memory (list[Tensor]): features from the encoder output, ordered from 8x to 32x
NOTE: the layer orders of both features and pos are res2 -> res5
Returns:
mask_features (Tensor): [b*t, mask_dim, h, w], with the spatial stride of 4x.
"""
y = self.forward_features(features, text_features, pos, memory, nf)
return self.mask_features(y)
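# Example (hedged sketch): expected input layout for CrossModalFPNDecoder.forward,
# assuming a DETR-style NestedTensor with .decompose() -> (tensors, mask). The channel
# list and spatial sizes below are illustrative, not the project configuration.
#
#   decoder = CrossModalFPNDecoder(feature_channels=[256, 512, 1024, 2048],
#                                  conv_dim=256, mask_dim=256, norm="GN")
#   # features: 4 NestedTensors (res2 -> res5), tensors [b*t, ci, hi, wi], mask [b*t, hi, wi]
#   # pos:      4 position encodings, each [b*t, 256, hi, wi]
#   # memory:   3 encoder outputs (8x -> 32x), channel dims matching feature_channels[1:]
#   # text_features: NestedTensor, tensors [b, length, 256], mask [b, length]
#   mask_features = decoder(features, text_features, pos, memory, nf=t)   # [b*t, 256, h/4, w/4]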
class VisionLanguageBlock(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False, sr_ratio=1):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
# for downsample
self.sr_ratio = sr_ratio
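        # NOTE: when sr_ratio > 1, the queries/keys/values are spatially downsampled
        # by this factor before self-attention (spatial-reduction attention), which
        # keeps the quadratic attention cost manageable on high-resolution FPN levels;
        # the attended output is upsampled back to (h, w) afterwards.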
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory, t, h, w,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
b = tgt.size(1)
# self attn
q = k = self.with_pos_embed(tgt, query_pos)
if self.sr_ratio > 1: # downsample
q = rearrange(q, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
k = rearrange(k, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
v = rearrange(tgt, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
# downsample
new_h = int(h * 1./self.sr_ratio)
new_w = int(w * 1./self.sr_ratio)
size = (new_h, new_w)
q = F.interpolate(q, size=size, mode='nearest')
k = F.interpolate(k, size=size, mode='nearest')
v = F.interpolate(v, size=size, mode='nearest')
# shape for transformer
q = rearrange(q, '(b t) c h w -> (t h w) b c', t=t)
k = rearrange(k, '(b t) c h w -> (t h w) b c', t=t)
v = rearrange(v, '(b t) c h w -> (t h w) b c', t=t)
# downsample mask
tgt_key_padding_mask = tgt_key_padding_mask.reshape(b*t, h, w)
tgt_key_padding_mask = F.interpolate(tgt_key_padding_mask[None].float(), size=(new_h, new_w), mode='nearest').bool()[0]
tgt_key_padding_mask = tgt_key_padding_mask.reshape(b, t, new_h, new_w).flatten(1)
else:
v = tgt
        tgt2 = self.self_attn(q, k, value=v, attn_mask=None,
                              key_padding_mask=tgt_key_padding_mask)[0]  # [T*H*W, B, C] (downsampled when sr_ratio > 1)
if self.sr_ratio > 1:
tgt2 = rearrange(tgt2, '(t h w) b c -> (b t) c h w', t=t, h=new_h, w=new_w)
            size = (h, w)  # recover the original size
tgt2 = F.interpolate(tgt2, size=size, mode='bilinear', align_corners=False) # [B*T, C, H, W]
tgt2 = rearrange(tgt2, '(b t) c h w -> (t h w) b c', t=t)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# cross attn
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=None,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# ffn
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(self, tgt, memory, t, h, w,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
b = tgt.size(1)
# self attn
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
if self.sr_ratio > 1: # downsample
q = rearrange(q, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
k = rearrange(k, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
v = rearrange(tgt, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
# downsample
new_h = int(h * 1./self.sr_ratio)
new_w = int(w * 1./self.sr_ratio)
size = (new_h, new_w)
q = F.interpolate(q, size=size, mode='nearest')
k = F.interpolate(k, size=size, mode='nearest')
v = F.interpolate(v, size=size, mode='nearest')
# shape for transformer
q = rearrange(q, '(b t) c h w -> (t h w) b c', t=t)
k = rearrange(k, '(b t) c h w -> (t h w) b c', t=t)
v = rearrange(v, '(b t) c h w -> (t h w) b c', t=t)
# downsample mask
tgt_key_padding_mask = tgt_key_padding_mask.reshape(b*t, h, w)
tgt_key_padding_mask = F.interpolate(tgt_key_padding_mask[None].float(), size=(new_h, new_w), mode='nearest').bool()[0]
tgt_key_padding_mask = tgt_key_padding_mask.reshape(b, t, new_h, new_w).flatten(1)
else:
v = tgt2
tgt2 = self.self_attn(q, k, value=v, attn_mask=None,
key_padding_mask=tgt_key_padding_mask)[0] # [T*H*W, B, C]
if self.sr_ratio > 1:
tgt2 = rearrange(tgt2, '(t h w) b c -> (b t) c h w', t=t, h=new_h, w=new_w)
            size = (h, w)  # recover the original size
tgt2 = F.interpolate(tgt2, size=size, mode='bilinear', align_corners=False) # [B*T, C, H, W]
tgt2 = rearrange(tgt2, '(b t) c h w -> (t h w) b c', t=t)
tgt = tgt + self.dropout1(tgt2)
# cross attn
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=None,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
# ffn
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(self, tgt, memory, t, h, w,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, t, h, w,
tgt_key_padding_mask, memory_key_padding_mask,
pos, query_pos)
return self.forward_post(tgt, memory, t, h, w,
tgt_key_padding_mask, memory_key_padding_mask,
pos, query_pos)
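# Example (hedged sketch, not executed at import): the tensor layout VisionLanguageBlock
# expects. All sizes below are illustrative; tgt carries the flattened video features and
# memory the text tokens.
#
#   t, h, w, b, length = 2, 16, 16, 1, 10
#   block = VisionLanguageBlock(d_model=256, nhead=8, sr_ratio=2)
#   tgt = torch.randn(t * h * w, b, 256)     # vision tokens, [t*h*w, b, c]
#   memory = torch.randn(length, b, 256)     # text tokens,   [length, b, c]
#   out = block(tgt, memory, t=t, h=h, w=w,
#               tgt_key_padding_mask=torch.zeros(b, t * h * w, dtype=torch.bool),
#               memory_key_padding_mask=torch.zeros(b, length, dtype=torch.bool))
#   # out: [t*h*w, b, 256]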
class VisionLanguageFusionModule(nn.Module):
def __init__(self, d_model, nhead, dropout=0.0):
super().__init__()
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt, memory,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=None,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt * tgt2
return tgt
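# Example (hedged sketch, not executed at import): VisionLanguageFusionModule gates the
# visual features with the text-attended response (element-wise product) instead of a
# residual sum. Sizes are illustrative.
#
#   fusion = VisionLanguageFusionModule(d_model=256, nhead=8)
#   vis = torch.randn(16 * 16, 1, 256)   # flattened visual tokens, [h*w, b, c]
#   txt = torch.randn(10, 1, 256)        # text tokens,             [length, b, c]
#   fused = fusion(vis, txt)             # [h*w, b, 256]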
def dice_loss(inputs, targets, num_boxes):
"""
Compute the DICE loss, similar to generalized IOU for masks
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
"""
inputs = inputs.sigmoid()
inputs = inputs.flatten(1)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
return loss.sum() / num_boxes
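# Example (illustrative, not executed at import): dice_loss takes raw logits; targets are
# binary masks flattened to [num_masks, h*w], and num_boxes normalizes over the batch.
#
#   logits = torch.randn(3, 64 * 64)
#   masks = torch.randint(0, 2, (3, 64 * 64)).float()
#   loss = dice_loss(logits, masks, num_boxes=3)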
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
            positive vs negative examples. Default = 0.25; set alpha = -1 to disable weighting.
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
return loss.mean(1).sum() / num_boxes
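# Example (illustrative, not executed at import): sigmoid_focal_loss also takes logits and
# binary targets of shape [num_masks, h*w]; the loss is averaged per mask, then normalized
# by num_boxes.
#
#   logits = torch.randn(3, 64 * 64)
#   masks = torch.randint(0, 2, (3, 64 * 64)).float()
#   loss = sigmoid_focal_loss(logits, masks, num_boxes=3, alpha=0.25, gamma=2)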
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")