Spaces:
Runtime error
Runtime error
import torch | |
from torchvision.ops.boxes import batched_nms | |
from util.box_ops import box_cxcywh_to_xyxy | |
from .deformable_detr.deformable_transformer import DeformableTransformer | |
class OVTransformer(DeformableTransformer):
    """Two-stage Deformable DETR transformer with per-patch query groups.

    On top of ``DeformableTransformer`` this class:

    * replicates the first-stage top-k decoder queries ``num_patch`` times,
    * optionally adds ``llm_feat`` to the decoder content queries, and
    * passes the decoder a block-diagonal ``tgt_mask`` so that queries are
      only allowed to attend to queries of the same patch copy.

    Only the two-stage mode is implemented; ``forward`` raises
    ``NotImplementedError`` when ``two_stage`` is False.
    """

    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", return_intermediate_dec=False,
                 num_feature_levels=4, dec_n_points=4, enc_n_points=4,
                 two_stage=False, two_stage_num_proposals=300,
                 assign_first_stage=False):
        # Every argument is forwarded positionally to DeformableTransformer;
        # this subclass defines no extra parameters of its own.
        super().__init__(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout,
                         activation, return_intermediate_dec, num_feature_levels, dec_n_points, enc_n_points,
                         two_stage, two_stage_num_proposals, assign_first_stage)

    def _flatten_multiscale_inputs(self, srcs, masks, pos_embeds):
        """Concatenate all feature levels into one token sequence for the encoder.

        Returns:
            ``(src_flatten, mask_flatten, lvl_pos_embed_flatten, spatial_shapes,
            level_start_index, valid_ratios)`` where ``spatial_shapes`` is a
            ``(num_levels, 2)`` long tensor of ``(h, w)`` and
            ``level_start_index`` holds each level's first token offset.
        """
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            _, _, h, w = src.shape
            spatial_shapes.append((h, w))
            # (bs, c, h, w) -> (bs, h*w, c)
            src_flatten.append(src.flatten(2).transpose(1, 2))
            mask_flatten.append(mask.flatten(1))
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            # Add the per-level embedding so tokens from different levels differ.
            lvl_pos_embed_flatten.append(pos_embed + self.level_embed[lvl].view(1, 1, -1))
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        # Exclusive prefix sum of h*w: offset of each level's first token.
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
        return (src_flatten, mask_flatten, lvl_pos_embed_flatten,
                spatial_shapes, level_start_index, valid_ratios)

    def _select_first_stage_proposals(self, enc_outputs_coord_unact, proposal_logit, level_ids, spatial_shapes):
        """Choose ``two_stage_num_proposals`` token indices per image.

        Without ``assign_first_stage`` this is a plain top-k over
        ``proposal_logit``. With it, each image goes through a per-level
        pre-NMS top-k, NMS at IoU 0.9 (separately per level), a quota of
        ``topk // num_levels`` survivors per level, and padding back to
        exactly ``topk`` indices.

        Returns:
            A ``(bs, topk)`` long tensor of indices along the token axis.
        """
        topk = self.two_stage_num_proposals
        if not self.assign_first_stage:
            return torch.topk(proposal_logit, topk, dim=1)[1]

        proposal_boxes = box_cxcywh_to_xyxy(enc_outputs_coord_unact.sigmoid().float()).clamp(0, 1)
        num_levels = len(spatial_shapes)
        topk_proposals = []
        for b in range(proposal_logit.shape[0]):
            prop_boxes_b = proposal_boxes[b]
            prop_logits_b = proposal_logit[b]
            prop_scores_b = prop_logits_b.sigmoid()
            # Pre-NMS top-k per level: other levels' scores are zeroed by the mask.
            # k is clamped so images with fewer than 1000 tokens don't make topk raise.
            pre_nms_topk = min(1000, prop_logits_b.shape[0])
            pre_nms_inds = torch.cat([
                torch.topk(prop_scores_b * (level_ids == lvl), pre_nms_topk)[1]
                for lvl in range(num_levels)
            ])
            # NMS within each level (level id is the batching key).
            post_nms_inds = batched_nms(prop_boxes_b[pre_nms_inds],
                                        prop_logits_b[pre_nms_inds],
                                        level_ids[pre_nms_inds], 0.9)
            keep_inds = pre_nms_inds[post_nms_inds]
            if len(keep_inds) < topk:
                print(f'[WARNING] nms proposals ({len(keep_inds)}) < {self.two_stage_num_proposals}')
                keep_inds = torch.topk(prop_logits_b, topk)[1]
            # Keep at most topk // L indices per level, taking the earliest ones
            # in keep_inds order (torchvision's batched_nms and torch.topk both
            # return indices sorted by decreasing score).
            q_per_l = topk // num_levels
            level_shapes = torch.arange(num_levels, device=level_ids.device)[:, None]
            is_level_ordered = level_ids[keep_inds][None] == level_shapes  # (L, K)
            keep_inds_mask = is_level_ordered & (is_level_ordered.cumsum(1) <= q_per_l)
            keep_inds_mask = keep_inds_mask.any(0)  # (K,)
            # Pad up to topk with the earliest not-yet-kept indices (may re-admit
            # low-ranked ones; unlikely to matter since candidates are high-scoring).
            num_kept = int(keep_inds_mask.sum())
            if num_kept < topk:
                pad_inds = (~keep_inds_mask).nonzero()[:topk - num_kept]
                keep_inds_mask[pad_inds] = True
            topk_proposals.append(keep_inds[keep_inds_mask])
        return torch.stack(topk_proposals)

    @staticmethod
    def _build_patch_attention_mask(num_queries, num_patch, device):
        """Return a ``(Q*P, Q*P)`` block-diagonal mask of ``0`` / ``-inf``.

        Entry ``(i, j)`` is ``0`` when ``i // num_queries == j // num_queries``
        (same patch copy) and ``-inf`` otherwise. It is passed to the decoder
        as ``tgt_mask``; presumably added to self-attention logits there so
        queries cannot attend across patch copies (confirm in the decoder).
        """
        patch_ids = torch.arange(num_queries * num_patch, device=device) // num_queries
        same_patch = patch_ids[:, None] == patch_ids[None, :]
        return torch.zeros(same_patch.shape, device=device).masked_fill(~same_patch, float("-inf"))

    def forward(self, srcs, masks, pos_embeds, query_embed=None, llm_feat=None, num_patch=1):
        """Encode multi-scale features, select proposals, and decode per patch.

        Args:
            srcs: per-level feature maps, each unpacked as ``(bs, c, h, w)``.
            masks: per-level masks matching ``srcs`` (presumably padding masks).
            pos_embeds: per-level positional embeddings matching ``srcs``.
            query_embed: not used in two-stage mode (the only supported mode).
            llm_feat: optional tensor added to the decoder content queries after
                ``repeat_interleave(num_queries, 1)``; it must then broadcast
                against ``(bs, num_queries * num_patch, d_model)``, so presumably
                ``(bs, num_patch, d_model)`` — confirm against the caller.
                When ``None`` nothing is added.
            num_patch: number of copies of the selected queries to decode; each
                copy only attends to itself via the decoder ``tgt_mask``.

        Returns:
            ``(hs, init_reference_out, inter_references_out, enc_outputs_class,
            enc_outputs_coord_unact, output_proposals.sigmoid())``.

        Raises:
            AssertionError: if not two-stage and ``query_embed`` is None.
            NotImplementedError: if ``two_stage`` is False.
        """
        assert self.two_stage or query_embed is not None
        if not self.two_stage:
            # Fail fast, before paying for the encoder: no single-stage path exists.
            raise NotImplementedError("OVTransformer only supports two_stage=True")

        (src_flatten, mask_flatten, lvl_pos_embed_flatten,
         spatial_shapes, level_start_index, valid_ratios) = self._flatten_multiscale_inputs(srcs, masks, pos_embeds)

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios,
                              lvl_pos_embed_flatten, mask_flatten)

        # first stage: per-token proposals from the encoder output
        bs, _, c = memory.shape
        output_memory, output_proposals, level_ids = \
            self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
        # The first-stage heads sit at index num_layers of the decoder's head
        # lists (hack implementation inherited from two-stage Deformable DETR).
        enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
        enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
        # Channel 0 of the first-stage class output ranks the proposals.
        proposal_logit = enc_outputs_class[..., 0]
        topk_proposals = self._select_first_stage_proposals(
            enc_outputs_coord_unact, proposal_logit, level_ids, spatial_shapes)

        # Build decoder queries from the selected proposals; detach() stops
        # gradients flowing back through the reference coordinates.
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).detach()
        reference_points = topk_coords_unact.sigmoid()
        pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
        query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
        num_queries = query_embed.shape[1]
        topk_feats = torch.stack([output_memory[b][topk_proposals[b]] for b in range(bs)]).detach()

        # Replicate every query-side tensor once per patch along the query axis.
        query_embed = query_embed.repeat(1, num_patch, 1)
        tgt = tgt.repeat(1, num_patch, 1)
        topk_feats = topk_feats.repeat(1, num_patch, 1)
        tgt = tgt + self.pix_trans_norm(self.pix_trans(topk_feats))
        init_reference_out = reference_points.repeat(1, num_patch, 1)
        reference_points = reference_points.repeat(1, num_patch, 1)
        if llm_feat is not None:
            # Each llm_feat entry along dim 1 is repeated for num_queries consecutive queries.
            tgt = tgt + llm_feat.repeat_interleave(num_queries, 1)

        # decoder
        decoder_mask = self._build_patch_attention_mask(num_queries, num_patch, query_embed.device)
        hs, inter_references = self.decoder(tgt, reference_points, memory,
                                            spatial_shapes, level_start_index, valid_ratios,
                                            query_embed, mask_flatten, tgt_mask=decoder_mask)
        return (hs,
                init_reference_out,
                inter_references,
                enc_outputs_class,
                enc_outputs_coord_unact,
                output_proposals.sigmoid())
def build_ov_transformer(args):
    """Construct an ``OVTransformer`` from an argument namespace.

    ``args`` may be any object exposing the attributes read below.
    ``activation`` is fixed to ``"relu"`` and ``return_intermediate_dec``
    to ``True``. ``args.num_queries`` becomes ``two_stage_num_proposals``.
    """
    transformer_kwargs = {
        "d_model": args.hidden_dim,
        "nhead": args.nheads,
        "num_encoder_layers": args.enc_layers,
        "num_decoder_layers": args.dec_layers,
        "dim_feedforward": args.dim_feedforward,
        "dropout": args.dropout,
        "activation": "relu",
        "return_intermediate_dec": True,
        "num_feature_levels": args.num_feature_levels,
        "dec_n_points": args.dec_n_points,
        "enc_n_points": args.enc_n_points,
        "two_stage": args.two_stage,
        "two_stage_num_proposals": args.num_queries,
        "assign_first_stage": args.assign_first_stage,
    }
    return OVTransformer(**transformer_kwargs)