File size: 20,614 Bytes
9b855a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 |
"""
Segmentation Part
Modified from DETR (https://github.com/facebookresearch/detr)
"""
from collections import defaultdict
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from PIL import Image
from einops import rearrange, repeat
try:
from panopticapi.utils import id2rgb, rgb2id
except ImportError:
pass
import fvcore.nn.weight_init as weight_init
from .position_encoding import PositionEmbeddingSine1D
# Presumably the momentum intended for batch-norm layers (inferred from the
# name); it is not referenced anywhere in the visible part of this file.
BN_MOMENTUM = 0.1
def get_norm(norm, out_channels):  # only support GN or LN
    """
    Build the normalization layer described by ``norm``.

    Args:
        norm (str or callable or None): either one of "GN", "LN";
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
            ``None`` and the empty string both mean "no normalization".
        out_channels (int): channel number handed to the layer factory.

    Returns:
        nn.Module or None: the normalization layer

    Raises:
        KeyError: if ``norm`` is a non-empty string other than "GN" / "LN".
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        factories = {
            # GroupNorm uses a fixed 8 groups, so `channels` must be divisible by 8.
            "GN": lambda channels: nn.GroupNorm(8, channels),
            # NOTE(review): nn.LayerNorm(channels) normalizes over the trailing
            # dimension only - confirm this is what is wanted before using "LN"
            # on NCHW convolution outputs.
            "LN": lambda channels: nn.LayerNorm(channels),
        }
        if norm not in factories:
            # Still a KeyError (as before), but now it says what is supported.
            raise KeyError(
                "Unsupported norm {!r}; supported names are: {}".format(
                    norm, ", ".join(sorted(factories))
                )
            )
        norm = factories[norm]
    return norm(out_channels)
class Conv2d(torch.nn.Conv2d):
    """
    :class:`torch.nn.Conv2d` extended with an optional normalization layer and
    an optional activation that are applied, in that order, to the conv output.
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts everything `torch.nn.Conv2d` accepts, plus two keyword-only extras:

        Args:
            norm (nn.Module, optional): normalization applied to the conv output
            activation (callable(Tensor) -> Tensor): applied after `norm`
        """
        # Strip the extras before delegating: the base class rejects unknown kwargs.
        norm_layer = kwargs.pop("norm", None)
        act_fn = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm_layer
        self.activation = act_fn

    def forward(self, x):
        # The sanity check below is skipped under torchscript:
        #   - torchscript does not support SyncBatchNorm yet
        #     (https://github.com/pytorch/pytorch/issues/40507)
        #   - scripting is only supported in evaluation mode, and the PyTorch
        #     versions that can export this module (1.6+) already handle
        #     empty inputs in `Conv2d`.
        if not torch.jit.is_scripting():
            if self.training and x.numel() == 0:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"

        out = F.conv2d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        # Normalization comes first, activation second.
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
# FPN structure
class CrossModalFPNDecoder(nn.Module):
    """FPN-style top-down decoder that fuses every pyramid level with text features.

    Each level is projected by a 1x1 lateral conv, passed through a
    VisionLanguageBlock together with the text features, added to the
    upsampled coarser level, and smoothed by a 3x3 output conv.
    """

    def __init__(self, feature_channels: List, conv_dim: int, mask_dim: int, dim_feedforward: int = 2048, norm=None):
        """
        Args:
            feature_channels: list of fpn feature channel numbers.
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            dim_feedforward: number of vision-language fusion module ffn channel numbers.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__()

        self.feature_channels = feature_channels

        lateral_convs = []
        output_convs = []

        # NOTE(review): only the empty string enables the conv bias; the default
        # norm=None gives bias=False *and* no norm layer - confirm this is intended.
        use_bias = norm == ""
        for idx, in_channels in enumerate(feature_channels):
            # in_channels: 4x -> 32x
            lateral_norm = get_norm(norm, conv_dim)
            output_norm = get_norm(norm, conv_dim)

            lateral_conv = Conv2d(
                in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
            )
            output_conv = Conv2d(
                conv_dim,
                conv_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
                activation=F.relu,
            )
            weight_init.c2_xavier_fill(lateral_conv)
            weight_init.c2_xavier_fill(output_conv)
            stage = idx + 1
            # The plain Python lists below are not tracked by nn.Module, so every
            # conv is registered explicitly under a stable name.
            self.add_module("adapter_{}".format(stage), lateral_conv)
            self.add_module("layer_{}".format(stage), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_dim = mask_dim
        self.mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        weight_init.c2_xavier_fill(self.mask_features)

        # vision-language cross-modal fusion
        self.text_pos = PositionEmbeddingSine1D(conv_dim, normalize=True)
        # One spatial-reduction ratio per level, so at most four levels are
        # supported (a fifth entry in `feature_channels` raises IndexError below).
        sr_ratios = [8, 4, 2, 1]
        cross_attns = []
        for idx in range(len(feature_channels)):  # res2 -> res5
            cross_attn = VisionLanguageBlock(conv_dim, dim_feedforward=dim_feedforward,
                                             nhead=8, sr_ratio=sr_ratios[idx])
            for p in cross_attn.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
            stage = int(idx + 1)
            self.add_module("cross_attn_{}".format(stage), cross_attn)
            cross_attns.append(cross_attn)
        # place cross-attn in top-down order (from low to high resolution)
        self.cross_attns = cross_attns[::-1]

    def _fuse_level(self, idx, vision_input, x_mask, pos, text_features, text_masks, text_pos, nf):
        """Project one pyramid level and fuse it with the text features.

        `idx` indexes the top-down ordered module lists (-1 is the 4x level).
        Returns the fused feature map rearranged back to [b*t, c, h, w].
        """
        # NOTE: here the (h, w) is the size for current fpn layer
        n, _, h, w = pos.shape
        b = n // nf
        t = nf

        vision_features = self.lateral_convs[idx](vision_input)  # [b*t, c, h, w]
        vision_features = rearrange(vision_features, '(b t) c h w -> (t h w) b c', b=b, t=t)
        vision_pos = rearrange(pos, '(b t) c h w -> (t h w) b c', b=b, t=t)
        vision_masks = rearrange(x_mask, '(b t) h w -> b (t h w)', b=b, t=t)

        cur_fpn = self.cross_attns[idx](tgt=vision_features,
                                        memory=text_features,
                                        t=t, h=h, w=w,
                                        tgt_key_padding_mask=vision_masks,
                                        memory_key_padding_mask=text_masks,
                                        pos=text_pos,
                                        query_pos=vision_pos
                                        )  # [t*h*w, b, c]
        return rearrange(cur_fpn, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)

    def _merge_top_down(self, y, cur_fpn, output_conv):
        """Add the upsampled coarser map `y` (if any) to `cur_fpn`, then apply `output_conv`."""
        if y is not None:
            # Following FPN implementation, we use nearest upsampling here
            cur_fpn = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
        return output_conv(cur_fpn)

    def forward_features(self, features, text_features, poses, memory, nf):
        """Run the top-down pathway and return the fused 4x feature map.

        Arguments are the same as for `forward`; `nf` is the number of frames.
        """
        # nf: num_frames
        text_pos = self.text_pos(text_features).permute(2, 0, 1)  # [length, batch_size, c]
        text_features, text_masks = text_features.decompose()
        text_features = text_features.permute(1, 0, 2)

        # `y` stays None until the first (coarsest) level has been processed; this
        # also keeps a single-level pyramid from hitting an unbound `y` below.
        y = None

        # 32x -> 8x: these levels take their visual input from the encoder memory.
        coarse_levels = zip(memory[::-1], features[1:][::-1], poses[1:][::-1])
        for idx, (mem, f, pos) in enumerate(coarse_levels):
            _, x_mask = f.decompose()
            cur_fpn = self._fuse_level(idx, mem, x_mask, pos,
                                       text_features, text_masks, text_pos, nf)
            y = self._merge_top_down(y, cur_fpn, self.output_convs[idx])

        # 4x level: takes its visual input directly from the backbone feature.
        x, x_mask = features[0].decompose()
        cur_fpn = self._fuse_level(-1, x, x_mask, poses[0],
                                   text_features, text_masks, text_pos, nf)
        y = self._merge_top_down(y, cur_fpn, self.output_convs[-1])
        return y  # [b*t, c, h, w], the spatial stride is 4x

    def forward(self, features, text_features, pos, memory, nf):
        """The forward function receives the vision and language features,
            and outputs the mask features with the spatial stride of 4x.

        Args:
            features (list[NestedTensor]): backbone features (vision), length is number of FPN layers
                tensors: [b*t, ci, hi, wi], mask: [b*t, hi, wi]
            text_features (NestedTensor): text features (language)
                tensors: [b, length, c], mask: [b, length]
            pos (list[Tensor]): position encoding of vision features, length is number of FPN layers
                tensors: [b*t, c, hi, wi]
            memory (list[Tensor]): features from encoder output. from 8x -> 32x
            nf (int): number of frames per clip (the `t` in `b*t`).
            NOTE: the layer orders of both features and pos are res2 -> res5

        Returns:
            mask_features (Tensor): [b*t, mask_dim, h, w], with the spatial stride of 4x.
        """
        y = self.forward_features(features, text_features, pos, memory, nf)
        return self.mask_features(y)
class VisionLanguageBlock(nn.Module):
    """Transformer decoder-style block: visual self-attention, vision-to-text
    cross-attention and a feed-forward network.

    When `sr_ratio > 1` the self-attention runs on a spatially downsampled copy
    of the visual tokens and its output is bilinearly upsampled back.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, sr_ratio=1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        # for downsample
        self.sr_ratio = sr_ratio

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return `tensor + pos`, or `tensor` itself when `pos` is None."""
        return tensor if pos is None else tensor + pos

    def _self_attention(self, qk, v, t, h, w,
                        tgt_key_padding_mask: Optional[Tensor] = None):
        """Self-attention over the visual tokens, optionally at reduced resolution.

        Args:
            qk: tensor used as both query and key, [t*h*w, b, c].
            v: value tensor, [t*h*w, b, c].
            t, h, w: temporal / spatial extent encoded in the token axis.
            tgt_key_padding_mask: optional padding mask, [b, t*h*w].

        Returns:
            Attention output, [t*h*w, b, c].
        """
        if self.sr_ratio <= 1:
            return self.self_attn(qk, qk, value=v, attn_mask=None,
                                  key_padding_mask=tgt_key_padding_mask)[0]

        b = qk.size(1)
        # Clamp to 1 so that maps smaller than sr_ratio do not produce a
        # zero-sized interpolation target.
        new_h = max(int(h * 1. / self.sr_ratio), 1)
        new_w = max(int(w * 1. / self.sr_ratio), 1)
        size = (new_h, new_w)

        def downsample(tokens):
            # tokens -> feature maps -> nearest downsample -> tokens
            maps = rearrange(tokens, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
            maps = F.interpolate(maps, size=size, mode='nearest')
            return rearrange(maps, '(b t) c h w -> (t h w) b c', t=t)

        qk_small = downsample(qk)
        v_small = downsample(v)

        # downsample mask (the mask is optional, so only touch it when present)
        mask = tgt_key_padding_mask
        if mask is not None:
            mask = mask.reshape(b * t, h, w)
            mask = F.interpolate(mask[None].float(), size=size, mode='nearest').bool()[0]
            mask = mask.reshape(b, t, new_h, new_w).flatten(1)

        out = self.self_attn(qk_small, qk_small, value=v_small, attn_mask=None,
                             key_padding_mask=mask)[0]  # [t*new_h*new_w, b, c]

        # recover to origin size
        out = rearrange(out, '(t h w) b c -> (b t) c h w', t=t, h=new_h, w=new_w)
        out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=False)  # [B*T, C, H, W]
        return rearrange(out, '(b t) c h w -> (t h w) b c', t=t)

    def forward_post(self, tgt, memory, t, h, w,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: LayerNorm is applied after each residual addition."""
        # self attn
        tgt2 = self._self_attention(self.with_pos_embed(tgt, query_pos), tgt,
                                    t, h, w, tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # cross attn
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=None,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # ffn
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory, t, h, w,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm is applied before each sub-layer."""
        # self attn
        tgt2 = self.norm1(tgt)
        # The normalized tensor is used as the value on both the full-resolution
        # and the downsampled path (the downsampled path previously fed the
        # un-normalized `tgt`, inconsistent with the other branch).
        tgt2 = self._self_attention(self.with_pos_embed(tgt2, query_pos), tgt2,
                                    t, h, w, tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)

        # cross attn
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=None,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        # ffn
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory, t, h, w,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm implementation.

        Args:
            tgt: visual tokens, [t*h*w, b, c].
            memory: text tokens used as key/value of the cross-attention.
            t, h, w: temporal / spatial extent encoded in the token axis of `tgt`.
            tgt_key_padding_mask: optional padding mask for `tgt`, [b, t*h*w].
            memory_key_padding_mask: optional padding mask for `memory`.
            pos: optional positional encoding added to `memory` for the keys.
            query_pos: optional positional encoding added to `tgt` for queries/keys.

        Returns:
            Tensor with the same layout as `tgt`.
        """
        if self.normalize_before:
            return self.forward_pre(tgt, memory, t, h, w,
                                    tgt_key_padding_mask, memory_key_padding_mask,
                                    pos, query_pos)
        return self.forward_post(tgt, memory, t, h, w,
                                 tgt_key_padding_mask, memory_key_padding_mask,
                                 pos, query_pos)
class VisionLanguageFusionModule(nn.Module):
    """Modulate `tgt` element-wise by the result of attending from `tgt` to `memory`."""

    def __init__(self, d_model, nhead, dropout=0.0):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return `tensor + pos`, or `tensor` itself when `pos` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt, memory,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """
        Args:
            tgt: query-side tokens; also the tensor that gets modulated.
            memory: key/value-side tokens.
            memory_key_padding_mask: optional padding mask for `memory`.
            pos: optional positional encoding added to `memory` for the keys.
            query_pos: optional positional encoding added to `tgt` for the queries.

        Returns:
            `tgt` multiplied element-wise by the attention output.
        """
        queries = self.with_pos_embed(tgt, query_pos)
        keys = self.with_pos_embed(memory, pos)
        # MultiheadAttention returns (output, attention weights); only the
        # output is needed here.
        attended, _ = self.multihead_attn(
            query=queries,
            key=keys,
            value=memory,
            attn_mask=None,
            key_padding_mask=memory_key_padding_mask,
        )
        return tgt * attended
def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape (at least 2 dims, the first
                one indexing the examples). The predictions (logits) for each
                example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_boxes: Normalizer the summed per-example loss is divided by.
    Returns:
        Scalar loss tensor.
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    # Bring targets to the same [N, -1] layout as inputs; without this, targets
    # with more than two dims fail to broadcast against the flattened inputs.
    # Targets that are already <= 2-D are left untouched.
    if targets.dim() > 2:
        targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # The +1 smoothing keeps the ratio defined when both masks are empty.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Focal loss on logits, as used in RetinaNet for dense detection:
    https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of logits with at least 2 dims.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_boxes: Normalizer the summed loss is divided by.
        alpha: Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = 0.25; any negative
                value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    probs = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the probability the model assigns to the true class of each element.
    p_t = probs * targets + (1 - probs) * (1 - targets)
    focal = bce * ((1 - p_t) ** gamma)

    use_alpha_weighting = alpha >= 0
    if use_alpha_weighting:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    # Average over dim 1, sum what is left, then normalize by num_boxes.
    per_example = focal.mean(1)
    return per_example.sum() / num_boxes
def _get_activation_fn(activation):
    """Return an activation function given a string

    Args:
        activation: one of "relu", "gelu" or "glu".

    Returns:
        The matching function from `torch.nn.functional`.

    Raises:
        RuntimeError: if `activation` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # List every supported name: "glu" is accepted above as well.
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|