import torch
from torchvision.ops.boxes import batched_nms

from util.box_ops import box_cxcywh_to_xyxy
from .deformable_detr.deformable_transformer import DeformableTransformer


class OVTransformer(DeformableTransformer):
    """Two-stage Deformable DETR transformer with patch-replicated decoder queries.

    Extends ``DeformableTransformer`` so that the top-k first-stage proposals are
    replicated ``num_patch`` times, optionally conditioned on an external feature
    (``llm_feat``), and decoded with self-attention confined to each replica.
    Only the two-stage path is implemented.
    """

    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", return_intermediate_dec=False,
                 num_feature_levels=4, dec_n_points=4, enc_n_points=4,
                 two_stage=False, two_stage_num_proposals=300,
                 assign_first_stage=False):
        # Every argument is forwarded positionally to DeformableTransformer; this
        # subclass adds no constructor arguments of its own.
        super().__init__(d_model, nhead, num_encoder_layers, num_decoder_layers,
                         dim_feedforward, dropout, activation, return_intermediate_dec,
                         num_feature_levels, dec_n_points, enc_n_points, two_stage,
                         two_stage_num_proposals, assign_first_stage)

    @staticmethod
    def _select_first_stage_proposals(proposal_boxes, proposal_logit, level_ids, num_levels,
                                      topk, pre_nms_topk=1000, nms_iou_threshold=0.9):
        """Pick ``topk`` level-balanced, NMS-filtered proposal indices per image.

        For each image:
        - Take the ``pre_nms_topk`` best-scoring tokens of every level.
        - Run ``batched_nms`` grouped by ``level_ids``. If fewer than ``topk`` survive,
          fall back to the plain top-``topk`` tokens.
        - Keep at most ``topk // num_levels`` indices per level, in score order.
        - Pad with the best remaining indices until exactly ``topk`` are selected.

        Args:
            proposal_boxes: ``(bs, S, 4)`` boxes from ``box_cxcywh_to_xyxy``, clamped to [0, 1].
            proposal_logit: ``(bs, S)`` per-token proposal logits.
            level_ids: feature-level index of every token. It is indexed with per-image
                token indices, so presumably ``(S,)`` shared across the batch — confirm.
            num_levels: number of feature levels.
            topk: number of proposal indices returned per image.
            pre_nms_topk: candidates kept per level before NMS (clamped to ``S``).
            nms_iou_threshold: IoU threshold passed to ``batched_nms``.

        Returns:
            ``(bs, topk)`` tensor of token indices.
        """
        selected = []
        for boxes_b, logits_b in zip(proposal_boxes, proposal_logit):
            scores_b = logits_b.sigmoid()
            # torch.topk raises if k exceeds the number of tokens (small inputs).
            k = min(pre_nms_topk, scores_b.shape[0])

            # Pre-NMS per-level top-k. Zeroing other levels' scores restricts the pick
            # to this level. A level with fewer than k tokens is topped up with
            # zero-score indices from other levels. Such duplicates share box and
            # level id (IoU 1), so the NMS below suppresses them.
            pre_nms_inds = torch.cat([
                torch.topk(scores_b * (level_ids == lvl), k)[1]
                for lvl in range(num_levels)
            ])

            # NMS on the candidates, grouped by feature level.
            post_nms_inds = batched_nms(boxes_b[pre_nms_inds], logits_b[pre_nms_inds],
                                        level_ids[pre_nms_inds], nms_iou_threshold)
            keep_inds = pre_nms_inds[post_nms_inds]

            if len(keep_inds) < topk:
                print(f'[WARNING] nms proposals ({len(keep_inds)}) < {topk}')
                keep_inds = torch.topk(logits_b, topk)[1]

            # Keep the top Q/L indices for each of the L levels (keep_inds is score-ordered).
            q_per_l = topk // num_levels
            level_range = torch.arange(num_levels, device=level_ids.device)[:, None]
            is_level = level_ids[keep_inds][None] == level_range  # (L, K)
            keep_inds_mask = (is_level & (is_level.cumsum(1) <= q_per_l)).any(0)  # (K,)

            # Pad to Q indices with the earliest (highest-scoring) unselected entries.
            # This might let ones filtered by the pre-NMS step sneak in, which is
            # unlikely because only high-confidence indices are considered anyway.
            num_kept = int(keep_inds_mask.sum())
            if num_kept < topk:
                pad_inds = (~keep_inds_mask).nonzero()[:topk - num_kept, 0]
                keep_inds_mask[pad_inds] = True

            selected.append(keep_inds[keep_inds_mask])
        return torch.stack(selected)

    @staticmethod
    def _patch_attention_mask(num_queries, num_patch, device):
        """Build the decoder self-attention mask for ``num_patch`` query replicas.

        Returns a ``(num_queries * num_patch, num_queries * num_patch)`` float tensor.
        It is 0 on the ``num_queries``-sized diagonal blocks and ``-inf`` elsewhere.
        Presumably the decoder consumes it as an additive mask (confirm), so each
        query only attends to queries from its own replica.
        """
        size = num_queries * num_patch
        mask = torch.full((size, size), float("-inf"), device=device)
        for i in range(num_patch):
            block = slice(i * num_queries, (i + 1) * num_queries)
            mask[block, block] = 0
        return mask

    def forward(self, srcs, masks, pos_embeds, query_embed=None, llm_feat=None, num_patch=1):
        """Encode multi-scale features and decode ``num_patch`` replicas of the proposals.

        Args:
            srcs: per-level feature maps; each is unpacked as ``(bs, c, h, w)``.
            masks: per-level padding masks, flattened from dim 1 (presumably
                ``(bs, h, w)`` — confirm against caller).
            pos_embeds: per-level positional embeddings, laid out like ``srcs``.
            query_embed: only consulted by the opening assertion. The two-stage
                path rebuilds the queries from the encoder proposals.
            llm_feat: optional feature added to the content queries. Its dim 1 is
                repeated ``num_queries`` times, so it must hold one entry per patch
                replica (presumably ``(bs, num_patch, d_model)`` — confirm). When
                ``None``, this conditioning term is skipped.
            num_patch: number of copies of the proposal query set fed to the decoder.

        Returns:
            ``(hs, init_reference_out, inter_references, enc_outputs_class,
            enc_outputs_coord_unact, output_proposals.sigmoid())``.

        Raises:
            AssertionError: if ``two_stage`` is off and ``query_embed`` is ``None``.
            NotImplementedError: if ``two_stage`` is off. Single-stage decoding is
                not implemented for this class.
        """
        assert self.two_stage or query_embed is not None
        if not self.two_stage:
            raise NotImplementedError("OVTransformer only supports two_stage=True")

        # prepare input for encoder: flatten every level to (bs, h*w, c) and concat
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            _, _, h, w = src.shape
            spatial_shapes.append((h, w))
            src_flatten.append(src.flatten(2).transpose(1, 2))
            mask_flatten.append(mask.flatten(1))
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed_flatten.append(pos_embed + self.level_embed[lvl].view(1, 1, -1))
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        # Offset of each level's first token in the flattened sequence.
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )),
                                       spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios,
                              lvl_pos_embed_flatten, mask_flatten)

        # prepare input for decoder
        bs, _, c = memory.shape
        output_memory, output_proposals, level_ids = \
            self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

        # Hack implementation for two-stage Deformable DETR. The decoder head at index
        # num_layers (presumably an extra head beyond the per-layer ones) scores and
        # refines every encoder token as a proposal.
        enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
        enc_outputs_coord_unact = (self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
                                   + output_proposals)

        topk = self.two_stage_num_proposals
        proposal_logit = enc_outputs_class[..., 0]  # first class logit ranks the proposals

        if self.assign_first_stage:
            proposal_boxes = box_cxcywh_to_xyxy(enc_outputs_coord_unact.sigmoid().float()).clamp(0, 1)
            topk_proposals = self._select_first_stage_proposals(
                proposal_boxes, proposal_logit, level_ids, len(spatial_shapes), topk)
        else:
            topk_proposals = torch.topk(proposal_logit, topk, dim=1)[1]

        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
        # Decoder reference boxes do not back-propagate into the first stage.
        topk_coords_unact = topk_coords_unact.detach()
        reference_points = topk_coords_unact.sigmoid()
        init_reference_out = reference_points
        pos_trans_out = self.pos_trans_norm(
            self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
        query_embed, tgt = torch.split(pos_trans_out, c, dim=2)

        # Replicate the proposal queries once per patch (patch-major layout: replica i
        # occupies positions [i * num_queries, (i + 1) * num_queries)).
        num_queries = query_embed.shape[1]
        query_embed = query_embed.repeat(1, num_patch, 1)
        tgt = tgt.repeat(1, num_patch, 1)
        topk_feats = torch.gather(
            output_memory, 1,
            topk_proposals.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1])).detach()
        topk_feats = topk_feats.repeat(1, num_patch, 1)
        tgt = tgt + self.pix_trans_norm(self.pix_trans(topk_feats))
        reference_points = reference_points.repeat(1, num_patch, 1)
        init_reference_out = init_reference_out.repeat(1, num_patch, 1)
        if llm_feat is not None:
            # One llm_feat entry per replica, broadcast over that replica's queries.
            tgt = tgt + llm_feat.repeat_interleave(num_queries, 1)

        decoder_mask = self._patch_attention_mask(num_queries, num_patch, query_embed.device)

        # decoder
        hs, inter_references = self.decoder(tgt, reference_points, memory,
                                            spatial_shapes, level_start_index, valid_ratios,
                                            query_embed, mask_flatten, tgt_mask=decoder_mask)

        return (hs, init_reference_out, inter_references,
                enc_outputs_class, enc_outputs_coord_unact, output_proposals.sigmoid())


def build_ov_transformer(args):
    """Construct an ``OVTransformer`` from a parsed-arguments namespace.

    Reads ``hidden_dim``, ``nheads``, ``enc_layers``, ``dec_layers``,
    ``dim_feedforward``, ``dropout``, ``num_feature_levels``, ``dec_n_points``,
    ``enc_n_points``, ``two_stage``, ``num_queries`` and ``assign_first_stage``.
    It always uses ReLU and returns intermediate decoder outputs.
    """
    return OVTransformer(
        d_model=args.hidden_dim,
        nhead=args.nheads,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        activation="relu",
        return_intermediate_dec=True,
        num_feature_levels=args.num_feature_levels,
        dec_n_points=args.dec_n_points,
        enc_n_points=args.enc_n_points,
        two_stage=args.two_stage,
        two_stage_num_proposals=args.num_queries,
        assign_first_stage=args.assign_first_stage)