"""
ReferFormer model class.
Modified from DETR (https://github.com/facebookresearch/detr)
"""
import torch
import torch.nn.functional as F
from torch import nn
import os
import math
from util import box_ops
from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
nested_tensor_from_videos_list,
accuracy, get_world_size, interpolate,
is_dist_avail_and_initialized, inverse_sigmoid)
from .position_encoding import PositionEmbeddingSine1D
from .backbone import build_backbone
from .deformable_transformer import build_deforamble_transformer
from .segmentation import CrossModalFPNDecoder, VisionLanguageFusionModule
from .matcher import build_matcher
from .criterion import SetCriterion
from .postprocessors import build_postprocessors
from transformers import BertTokenizer, BertModel, RobertaModel, RobertaTokenizerFast
import copy
from einops import rearrange, repeat
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
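# A hedged usage sketch (comment only, not executed): _get_clones makes independent copies of a
# prediction head, one per decoder layer, e.g. _get_clones(nn.Linear(256, 4), 6) returns an
# nn.ModuleList of 6 unshared linear layers, as used below for class_embed / bbox_embed
# when with_box_refine is enabled.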
os.environ["TOKENIZERS_PARALLELISM"] = "false" # this disables a huggingface tokenizer warning (printed every epoch)
class ReferFormer(nn.Module):
""" This is the ReferFormer module that performs referring video object detection """
def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
num_frames, mask_dim, dim_feedforward,
controller_layers, dynamic_mask_channels,
aux_loss=False, with_box_refine=False, two_stage=False,
freeze_text_encoder=False, rel_coord=True):
""" Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See deformable_transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects
                         ReferFormer can detect in a video. For ytvos, we recommend 5 queries for each frame.
            num_frames: number of clip frames
            mask_dim: channel number of the mask features produced by the pixel decoder (input to the dynamic conv).
            dim_feedforward: FFN channel number of the vision-language fusion module.
            dynamic_mask_channels: hidden channel number of the dynamic conv layers.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
hidden_dim = transformer.d_model
self.hidden_dim = hidden_dim
self.class_embed = nn.Linear(hidden_dim, num_classes)
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
self.num_feature_levels = num_feature_levels
# Build Transformer
        # NOTE: unlike Deformable DETR, the out channels of query_embed is
        # hidden_dim instead of hidden_dim * 2, because the input to the
        # decoder is the text embedding feature
self.query_embed = nn.Embedding(num_queries, hidden_dim)
        # Following Deformable DETR, we use the last three stages of the backbone
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides[-3:])
            input_proj_list = []
            for i in range(num_backbone_outs):
                in_channels = backbone.num_channels[-3:][i]
input_proj_list.append(nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
))
for _ in range(num_feature_levels - num_backbone_outs): # downsample 2x
input_proj_list.append(nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, hidden_dim),
))
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
self.input_proj = nn.ModuleList([
nn.Sequential(
nn.Conv2d(backbone.num_channels[-3:][0], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)])
self.num_frames = num_frames
self.mask_dim = mask_dim
self.backbone = backbone
self.aux_loss = aux_loss
self.with_box_refine = with_box_refine
assert two_stage == False, "args.two_stage must be false!"
# initialization
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
self.class_embed.bias.data = torch.ones(num_classes) * bias_value
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
num_pred = transformer.decoder.num_layers
if with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
# hack implementation for iterative bounding box refinement
self.transformer.decoder.bbox_embed = self.bbox_embed
else:
nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
self.transformer.decoder.bbox_embed = None
# Build Text Encoder
# self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# self.text_encoder = BertModel.from_pretrained('bert-base-cased')
self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
self.text_encoder = RobertaModel.from_pretrained('roberta-base')
if freeze_text_encoder:
for p in self.text_encoder.parameters():
p.requires_grad_(False)
        # resize the text encoder output channels (768) to the transformer d_model
self.resizer = FeatureResizer(
input_feat_size=768,
output_feat_size=hidden_dim,
dropout=0.1,
)
self.fusion_module = VisionLanguageFusionModule(d_model=hidden_dim, nhead=8)
self.text_pos = PositionEmbeddingSine1D(hidden_dim, normalize=True)
# Build FPN Decoder
self.rel_coord = rel_coord
feature_channels = [self.backbone.num_channels[0]] + 3 * [hidden_dim]
self.pixel_decoder = CrossModalFPNDecoder(feature_channels=feature_channels, conv_dim=hidden_dim,
mask_dim=mask_dim, dim_feedforward=dim_feedforward, norm="GN")
# Build Dynamic Conv
self.controller_layers = controller_layers
self.in_channels = mask_dim
self.dynamic_mask_channels = dynamic_mask_channels
self.mask_out_stride = 4
self.mask_feat_stride = 4
weight_nums, bias_nums = [], []
for l in range(self.controller_layers):
if l == 0:
if self.rel_coord:
weight_nums.append((self.in_channels + 2) * self.dynamic_mask_channels)
else:
weight_nums.append(self.in_channels * self.dynamic_mask_channels)
bias_nums.append(self.dynamic_mask_channels)
elif l == self.controller_layers - 1:
weight_nums.append(self.dynamic_mask_channels * 1) # output layer c -> 1
bias_nums.append(1)
else:
weight_nums.append(self.dynamic_mask_channels * self.dynamic_mask_channels)
bias_nums.append(self.dynamic_mask_channels)
self.weight_nums = weight_nums
self.bias_nums = bias_nums
self.num_gen_params = sum(weight_nums) + sum(bias_nums)
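        # Worked example (illustrative only; the channel sizes here are hypothetical, not the
        # configured defaults): with rel_coord=True, in_channels = mask_dim = 16,
        # dynamic_mask_channels = 8 and controller_layers = 3, the loop above yields
        # weight_nums = [(16 + 2) * 8, 8 * 8, 8 * 1] = [144, 64, 8] and bias_nums = [8, 8, 1],
        # so each query's controller vector has num_gen_params = 144 + 64 + 8 + 8 + 8 + 1 = 233 entries.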
self.controller = MLP(hidden_dim, hidden_dim, self.num_gen_params, 3)
for layer in self.controller.layers:
nn.init.zeros_(layer.bias)
nn.init.xavier_uniform_(layer.weight)
def forward(self, samples: NestedTensor, captions, targets):
""" The forward expects a NestedTensor, which consists of:
- samples.tensors: image sequences, of shape [num_frames x 3 x H x W]
- samples.mask: a binary mask of shape [num_frames x H x W], containing 1 on padded pixels
- captions: list[str]
- targets: list[dict]
It returns a dict with the following elements:
- "pred_masks": Shape = [batch_size x num_queries x out_h x out_w]
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, height, width). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
# Backbone
if not isinstance(samples, NestedTensor):
samples = nested_tensor_from_videos_list(samples)
# features (list[NestedTensor]): res2 -> res5, shape of tensors is [B*T, Ci, Hi, Wi]
# pos (list[Tensor]): shape of [B*T, C, Hi, Wi]
features, pos = self.backbone(samples)
b = len(captions)
t = pos[0].shape[0] // b
        # For the A2D-Sentences and JHMDB-Sentences datasets, only one frame of each clip is annotated
if 'valid_indices' in targets[0]:
valid_indices = torch.tensor([i * t + target['valid_indices'] for i, target in enumerate(targets)]).to(pos[0].device)
for feature in features:
feature.tensors = feature.tensors.index_select(0, valid_indices)
feature.mask = feature.mask.index_select(0, valid_indices)
for i, p in enumerate(pos):
pos[i] = p.index_select(0, valid_indices)
samples.mask = samples.mask.index_select(0, valid_indices)
# t: num_frames -> 1
t = 1
text_features, text_sentence_features = self.forward_text(captions, device=pos[0].device)
# prepare vision and text features for transformer
srcs = []
masks = []
poses = []
text_pos = self.text_pos(text_features).permute(2, 0, 1) # [length, batch_size, c]
text_word_features, text_word_masks = text_features.decompose()
text_word_features = text_word_features.permute(1, 0, 2) # [length, batch_size, c]
        # Following Deformable DETR, we use the outputs of the last three backbone stages
for l, (feat, pos_l) in enumerate(zip(features[-3:], pos[-3:])):
src, mask = feat.decompose()
src_proj_l = self.input_proj[l](src)
n, c, h, w = src_proj_l.shape
# vision language early-fusion
src_proj_l = rearrange(src_proj_l, '(b t) c h w -> (t h w) b c', b=b, t=t)
src_proj_l = self.fusion_module(tgt=src_proj_l,
memory=text_word_features,
memory_key_padding_mask=text_word_masks,
pos=text_pos,
query_pos=None
)
src_proj_l = rearrange(src_proj_l, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
srcs.append(src_proj_l)
masks.append(mask)
poses.append(pos_l)
assert mask is not None
if self.num_feature_levels > (len(features) - 1):
_len_srcs = len(features) - 1 # fpn level
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
n, c, h, w = src.shape
# vision language early-fusion
src = rearrange(src, '(b t) c h w -> (t h w) b c', b=b, t=t)
src = self.fusion_module(tgt=src,
memory=text_word_features,
memory_key_padding_mask=text_word_masks,
pos=text_pos,
query_pos=None
)
src = rearrange(src, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
srcs.append(src)
masks.append(mask)
poses.append(pos_l)
# Transformer
query_embeds = self.query_embed.weight # [num_queries, c]
text_embed = repeat(text_sentence_features, 'b c -> b t q c', t=t, q=self.num_queries)
hs, memory, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, inter_samples = \
self.transformer(srcs, text_embed, masks, poses, query_embeds)
# hs: [l, batch_size*time, num_queries_per_frame, c]
# memory: list[Tensor], shape of tensor is [batch_size*time, c, hi, wi]
# init_reference: [batch_size*time, num_queries_per_frame, 2]
# inter_references: [l, batch_size*time, num_queries_per_frame, 4]
out = {}
# prediction
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.class_embed[lvl](hs[lvl])
tmp = self.bbox_embed[lvl](hs[lvl])
if reference.shape[-1] == 4:
tmp += reference
else:
assert reference.shape[-1] == 2
tmp[..., :2] += reference
outputs_coord = tmp.sigmoid() # cxcywh, range in [0,1]
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
outputs_class = torch.stack(outputs_classes)
outputs_coord = torch.stack(outputs_coords)
# rearrange
outputs_class = rearrange(outputs_class, 'l (b t) q k -> l b t q k', b=b, t=t)
outputs_coord = rearrange(outputs_coord, 'l (b t) q n -> l b t q n', b=b, t=t)
out['pred_logits'] = outputs_class[-1] # [batch_size, time, num_queries_per_frame, num_classes]
out['pred_boxes'] = outputs_coord[-1] # [batch_size, time, num_queries_per_frame, 4]
# Segmentation
mask_features = self.pixel_decoder(features, text_features, pos, memory, nf=t) # [batch_size*time, c, out_h, out_w]
mask_features = rearrange(mask_features, '(b t) c h w -> b t c h w', b=b, t=t)
# dynamic conv
outputs_seg_masks = []
for lvl in range(hs.shape[0]):
dynamic_mask_head_params = self.controller(hs[lvl]) # [batch_size*time, num_queries_per_frame, num_params]
dynamic_mask_head_params = rearrange(dynamic_mask_head_params, '(b t) q n -> b (t q) n', b=b, t=t)
lvl_references = inter_references[lvl, ..., :2]
lvl_references = rearrange(lvl_references, '(b t) q n -> b (t q) n', b=b, t=t)
outputs_seg_mask = self.dynamic_mask_with_coords(mask_features, dynamic_mask_head_params, lvl_references, targets)
outputs_seg_mask = rearrange(outputs_seg_mask, 'b (t q) h w -> b t q h w', t=t)
outputs_seg_masks.append(outputs_seg_mask)
out['pred_masks'] = outputs_seg_masks[-1] # [batch_size, time, num_queries_per_frame, out_h, out_w]
if self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, outputs_seg_masks)
if not self.training:
# for visualization
inter_references = inter_references[-2, :, :, :2] # [batch_size*time, num_queries_per_frame, 2]
inter_references = rearrange(inter_references, '(b t) q n -> b t q n', b=b, t=t)
out['reference_points'] = inter_references # the reference points of last layer input
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord, outputs_seg_masks):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{"pred_logits": a, "pred_boxes": b, "pred_masks": c}
for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1], outputs_seg_masks[:-1])]
def forward_text(self, captions, device):
if isinstance(captions[0], str):
tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt").to(device)
encoded_text = self.text_encoder(**tokenized)
# encoded_text.last_hidden_state: [batch_size, length, 768]
# encoded_text.pooler_output: [batch_size, 768]
text_attention_mask = tokenized.attention_mask.ne(1).bool()
# text_attention_mask: [batch_size, length]
text_features = encoded_text.last_hidden_state
text_features = self.resizer(text_features)
text_masks = text_attention_mask
text_features = NestedTensor(text_features, text_masks) # NestedTensor
text_sentence_features = encoded_text.pooler_output
text_sentence_features = self.resizer(text_sentence_features)
else:
            raise ValueError("Please make sure the captions are a list of strings")
return text_features, text_sentence_features
def dynamic_mask_with_coords(self, mask_features, mask_head_params, reference_points, targets):
"""
Add the relative coordinates to the mask_features channel dimension,
and perform dynamic mask conv.
Args:
mask_features: [batch_size, time, c, h, w]
mask_head_params: [batch_size, time * num_queries_per_frame, num_params]
reference_points: [batch_size, time * num_queries_per_frame, 2], cxcy
            targets (list[dict]): length is batch_size;
                                  we need the key 'size' for computing locations.
Return:
outputs_seg_mask: [batch_size, time * num_queries_per_frame, h, w]
"""
device = mask_features.device
b, t, c, h, w = mask_features.shape
# this is the total query number in all frames
_, num_queries = reference_points.shape[:2]
q = num_queries // t # num_queries_per_frame
        # scale the normalized reference points to the input image size (the size fed to the model)
new_reference_points = []
for i in range(b):
img_h, img_w = targets[i]['size']
scale_f = torch.stack([img_w, img_h], dim=0)
tmp_reference_points = reference_points[i] * scale_f[None, :]
new_reference_points.append(tmp_reference_points)
new_reference_points = torch.stack(new_reference_points, dim=0)
# [batch_size, time * num_queries_per_frame, 2], in image size
reference_points = new_reference_points
# prepare the mask features
if self.rel_coord:
reference_points = rearrange(reference_points, 'b (t q) n -> b t q n', t=t, q=q)
locations = compute_locations(h, w, device=device, stride=self.mask_feat_stride)
relative_coords = reference_points.reshape(b, t, q, 1, 1, 2) - \
locations.reshape(1, 1, 1, h, w, 2) # [batch_size, time, num_queries_per_frame, h, w, 2]
relative_coords = relative_coords.permute(0, 1, 2, 5, 3, 4) # [batch_size, time, num_queries_per_frame, 2, h, w]
# concat features
mask_features = repeat(mask_features, 'b t c h w -> b t q c h w', q=q) # [batch_size, time, num_queries_per_frame, c, h, w]
mask_features = torch.cat([mask_features, relative_coords], dim=3)
else:
mask_features = repeat(mask_features, 'b t c h w -> b t q c h w', q=q) # [batch_size, time, num_queries_per_frame, c, h, w]
mask_features = mask_features.reshape(1, -1, h, w)
# parse dynamic params
mask_head_params = mask_head_params.flatten(0, 1)
weights, biases = parse_dynamic_params(
mask_head_params, self.dynamic_mask_channels,
self.weight_nums, self.bias_nums
)
# dynamic mask conv
mask_logits = self.mask_heads_forward(mask_features, weights, biases, mask_head_params.shape[0])
mask_logits = mask_logits.reshape(-1, 1, h, w)
# upsample predicted masks
assert self.mask_feat_stride >= self.mask_out_stride
assert self.mask_feat_stride % self.mask_out_stride == 0
mask_logits = aligned_bilinear(mask_logits, int(self.mask_feat_stride / self.mask_out_stride))
mask_logits = mask_logits.reshape(b, num_queries, mask_logits.shape[-2], mask_logits.shape[-1])
return mask_logits # [batch_size, time * num_queries_per_frame, h, w]
def mask_heads_forward(self, features, weights, biases, num_insts):
        '''
        :param features: [1, num_insts * in_channels, h, w]
        :param weights: [w0, w1, ...]
        :param biases: [b0, b1, ...]
        :param num_insts: number of instances (the groups of the grouped convolution)
        :return: mask logits of shape [1, num_insts, h, w]
        '''
assert features.dim() == 4
n_layers = len(weights)
x = features
for i, (w, b) in enumerate(zip(weights, biases)):
x = F.conv2d(
x, w, bias=b,
stride=1, padding=0,
groups=num_insts
)
if i < n_layers - 1:
x = F.relu(x)
return x
def parse_dynamic_params(params, channels, weight_nums, bias_nums):
assert params.dim() == 2
assert len(weight_nums) == len(bias_nums)
assert params.size(1) == sum(weight_nums) + sum(bias_nums)
num_insts = params.size(0)
num_layers = len(weight_nums)
params_splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
weight_splits = params_splits[:num_layers]
bias_splits = params_splits[num_layers:]
for l in range(num_layers):
if l < num_layers - 1:
# out_channels x in_channels x 1 x 1
weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1)
bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
else:
# out_channels x in_channels x 1 x 1
weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1)
bias_splits[l] = bias_splits[l].reshape(num_insts)
return weight_splits, bias_splits
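# A minimal, hedged sketch of how parse_dynamic_params is used (not called by the model; the
# channel sizes and instance count below are hypothetical and chosen only for illustration).
def _demo_parse_dynamic_params():
    channels = 8                                                        # hypothetical dynamic_mask_channels
    weight_nums = [10 * channels, channels * channels, channels * 1]    # rel_coord input: 8 + 2 = 10 channels
    bias_nums = [channels, channels, 1]
    num_insts = 5                                                       # e.g. time * num_queries_per_frame
    params = torch.randn(num_insts, sum(weight_nums) + sum(bias_nums))
    weights, biases = parse_dynamic_params(params, channels, weight_nums, bias_nums)
    # intermediate layers become grouped 1x1 convs with num_insts groups ...
    assert weights[0].shape == (num_insts * channels, 10, 1, 1)
    # ... and the last layer maps each instance's channels to a single mask logit channel
    assert weights[-1].shape == (num_insts, channels, 1, 1) and biases[-1].shape == (num_insts,)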
def aligned_bilinear(tensor, factor):
assert tensor.dim() == 4
assert factor >= 1
assert int(factor) == factor
if factor == 1:
return tensor
h, w = tensor.size()[2:]
tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
oh = factor * h + 1
ow = factor * w + 1
tensor = F.interpolate(
tensor, size=(oh, ow),
mode='bilinear',
align_corners=True
)
tensor = F.pad(
tensor, pad=(factor // 2, 0, factor // 2, 0),
mode="replicate"
)
return tensor[:, :, :oh - 1, :ow - 1]
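# A small hedged sketch (illustration only): aligned_bilinear upsamples an [N, C, H, W] map by an
# integer factor while keeping sample positions aligned with the stride-based location grid.
def _demo_aligned_bilinear():
    x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
    y = aligned_bilinear(x, factor=2)
    assert y.shape == (1, 1, 8, 8)  # spatial size grows by exactly the factor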
def compute_locations(h, w, device, stride=1):
shifts_x = torch.arange(
0, w * stride, step=stride,
dtype=torch.float32, device=device)
shifts_y = torch.arange(
0, h * stride, step=stride,
dtype=torch.float32, device=device)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
return locations
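# A hedged usage sketch (illustration only): compute_locations returns one (x, y) coordinate per
# feature-map cell, offset to the pixel center implied by the stride (mask_feat_stride = 4 above).
def _demo_compute_locations():
    locs = compute_locations(h=2, w=3, device='cpu', stride=4)
    assert locs.shape == (6, 2)                 # h * w rows of (x, y)
    assert locs[0].tolist() == [2.0, 2.0]       # first cell center at stride // 2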
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
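# A hedged usage sketch (illustration only): the same 3-layer MLP shape used above for bbox_embed,
# mapping hidden_dim query features to 4 box parameters; hidden_dim = 256 is just an example value.
def _demo_mlp():
    box_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    queries = torch.randn(2, 5, 256)            # [batch, num_queries, hidden_dim]
    assert box_head(queries).shape == (2, 5, 4)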
class FeatureResizer(nn.Module):
"""
    This class takes as input a set of embeddings of dimension C1 and outputs a set of
    embeddings of dimension C2, after a linear transformation, dropout and normalization (LN).
"""
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
super().__init__()
self.do_ln = do_ln
# Object feature encoding
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
self.dropout = nn.Dropout(dropout)
def forward(self, encoder_features):
x = self.fc(encoder_features)
if self.do_ln:
x = self.layer_norm(x)
output = self.dropout(x)
return output
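# A hedged usage sketch (illustration only): project 768-d RoBERTa token features down to the
# transformer hidden size, mirroring how self.resizer is used in forward_text; sizes are examples.
def _demo_feature_resizer():
    resizer = FeatureResizer(input_feat_size=768, output_feat_size=256, dropout=0.1)
    text_feats = torch.randn(2, 20, 768)        # [batch_size, length, 768]
    assert resizer(text_feats).shape == (2, 20, 256)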
def build(args):
if args.binary:
num_classes = 1
else:
if args.dataset_file == 'ytvos':
num_classes = 65
elif args.dataset_file == 'davis':
num_classes = 78
elif args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb':
num_classes = 1
else:
num_classes = 91 # for coco
device = torch.device(args.device)
# backbone
if 'video_swin' in args.backbone:
from .video_swin_transformer import build_video_swin_backbone
backbone = build_video_swin_backbone(args)
elif 'swin' in args.backbone:
from .swin_transformer import build_swin_backbone
backbone = build_swin_backbone(args)
else:
backbone = build_backbone(args)
transformer = build_deforamble_transformer(args)
model = ReferFormer(
backbone,
transformer,
num_classes=num_classes,
num_queries=args.num_queries,
num_feature_levels=args.num_feature_levels,
num_frames=args.num_frames,
mask_dim=args.mask_dim,
dim_feedforward=args.dim_feedforward,
controller_layers=args.controller_layers,
dynamic_mask_channels=args.dynamic_mask_channels,
aux_loss=args.aux_loss,
with_box_refine=args.with_box_refine,
two_stage=args.two_stage,
freeze_text_encoder=args.freeze_text_encoder,
rel_coord=args.rel_coord
)
matcher = build_matcher(args)
weight_dict = {}
weight_dict['loss_ce'] = args.cls_loss_coef
weight_dict['loss_bbox'] = args.bbox_loss_coef
weight_dict['loss_giou'] = args.giou_loss_coef
if args.masks: # always true
weight_dict['loss_mask'] = args.mask_loss_coef
weight_dict['loss_dice'] = args.dice_loss_coef
# TODO this is a hack
if args.aux_loss:
aux_weight_dict = {}
for i in range(args.dec_layers - 1):
aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
losses = ['labels', 'boxes']
if args.masks:
losses += ['masks']
criterion = SetCriterion(
num_classes,
matcher=matcher,
weight_dict=weight_dict,
eos_coef=args.eos_coef,
losses=losses,
focal_alpha=args.focal_alpha)
criterion.to(device)
    # postprocessors: used for COCO pretraining, but not for RVOS
postprocessors = build_postprocessors(args, args.dataset_file)
return model, criterion, postprocessors