import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from util.misc import (NestedTensor, inverse_sigmoid,
                       nested_tensor_from_tensor_list)

from .blip2_decoder import BLIP2Decoder
from .deformable_detr.backbone import build_backbone
from .deformable_detr.deformable_detr import DeformableDETR
from .transformer import build_ov_transformer


class ContextDET(DeformableDETR):
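    """Deformable-DETR variant conditioned on LLM hidden states.

    A BLIP-2-style OPT decoder (``llm_decoder``) produces token-level hidden
    states; features pooled over selected (start, end) token spans are fed to
    the transformer as extra latent queries, so predicted boxes are grounded
    in spans of the generated text.
    """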
    def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
                 aux_loss=True, with_box_refine=False, two_stage=False, llm_decoder=None):
        super().__init__(backbone, transformer, num_classes, num_queries, num_feature_levels,
                         aux_loss, with_box_refine, two_stage)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm_decoder = llm_decoder
        hidden_dim = transformer.d_model
        # Map LLM hidden states into the detector's embedding space.
        out_size = self.llm_decoder.model.opt_proj.out_features
        self.llm_proj = nn.Linear(out_size, hidden_dim, device=self.device)
        # Per-token logits marking the start/end of object-referring spans.
        self.start_end_proj = nn.Linear(hidden_dim, 2)
        for layer in [self.llm_proj, self.start_end_proj]:
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)

        # Masked-language-modeling head over the OPT vocabulary (cloze test).
        vocab_size = llm_decoder.model.opt_model.config.vocab_size
        self.fc_logits = nn.Linear(hidden_dim, vocab_size)

    def forward(self, samples, blip2_samples, mask_infos=None, task_button=None, threshold=0.3):
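        """Run the LLM decoder, then detect objects conditioned on its states.

        `mask_infos` is expected to be a list with one dict per image, mapping
        (start, end) token spans to names; `task_button` selects the language
        task (e.g. 'Cloze Test'); `threshold` gates the start/end span heads
        at inference time.
        """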
        logits, hidden_states, input_ids, output_text = self.llm_decoder.model.forward(
            blip2_samples, task_button=task_button)
        # Stop gradients from flowing back into the LLM through this path.
        hidden_states = hidden_states.detach()
        hidden_states = self.llm_proj(hidden_states)

        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

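        # Standard Deformable-DETR multi-scale feature preparation: project each
        # backbone level to hidden_dim, then synthesize extra levels from the
        # last feature map when num_feature_levels exceeds the backbone's.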
        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        # Language-side predictions from the projected LLM states.
        out = {}
        start_end_proj = self.start_end_proj(hidden_states)
        out['pred_mlm_logits'] = self.fc_logits(hidden_states)
        out['pred_start'] = start_end_proj[:, :, 0:1]
        out['pred_end'] = start_end_proj[:, :, 1:2]
        out['output_text'] = output_text

        if self.training:
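            # Training: sample at most two provided (start, end) spans per image
            # and pool the LLM states over each span into a conditioning vector.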
            k = min([len(mask_info) for mask_info in mask_infos])
            k = min(k, 2)
            # random.sample requires a sequence, so materialize the key view.
            select_ids = [random.sample(list(mask_info.keys()), k) for mask_info in mask_infos]

            llm_feat = []
            for b in range(len(select_ids)):
                llm_feat_b = []
                hidden_states_b = hidden_states[b, :, :]
                for start, end in select_ids[b]:
                    llm_feat_b.append(hidden_states_b[start: end + 1].mean(dim=0, keepdim=True))
                llm_feat.append(torch.cat(llm_feat_b)[None])
            llm_feat = torch.cat(llm_feat)
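            # llm_feat: (batch, k, hidden_dim) span embeddings used as latent queries.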
            query_embeds = None
            if not self.two_stage:
                query_embeds = self.query_embed.weight
            hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = (
                self.transformer(srcs, masks, pos, query_embeds, llm_feat, k)
            )
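            # Decode boxes level by level with iterative refinement: each decoder
            # layer offsets the inverse-sigmoid reference points of the previous
            # layer, as in Deformable DETR with box refinement.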
            outputs_classes = []
            outputs_coords = []
            for lvl in range(hs.shape[0]):
                if lvl == 0:
                    reference = init_reference
                else:
                    reference = inter_references[lvl - 1]
                reference = inverse_sigmoid(reference)
                outputs_class = self.class_embed[lvl](hs[lvl])
                tmp = self.bbox_embed[lvl](hs[lvl])
                if reference.shape[-1] == 4:
                    tmp += reference
                else:
                    assert reference.shape[-1] == 2
                    tmp[..., :2] += reference
                outputs_coord = tmp.sigmoid()
                outputs_classes.append(outputs_class)
                outputs_coords.append(outputs_coord)
            outputs_class = torch.stack(outputs_classes)
            outputs_coord = torch.stack(outputs_coords)

            out.update({'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
                        'init_reference': init_reference})
            out['select_ids'] = select_ids

            if self.aux_loss:
                out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
                for temp in out['aux_outputs']:
                    temp['select_ids'] = select_ids

            if self.two_stage:
                enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
                out['enc_outputs'] = {
                    'pred_logits': enc_outputs_class,
                    'pred_boxes': enc_outputs_coord,
                    'anchors': anchors,
                }
        else:
            bs = len(samples.tensors)
            mask_infos_pred = [{} for _ in range(bs)]
            llm_feat = []
            tokenizer = self.llm_decoder.model.opt_tokenizer
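            # Inference: if no spans are given, either fill <mask> tokens via the
            # MLM head ('Cloze Test') or threshold the start/end heads to find
            # object-referring spans in the generated text.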
            if mask_infos is None:
                if task_button == 'Cloze Test':
                    mask_infos = []
                    output_texts = []
                    for b in range(bs):
                        mask_infos_b = {}
                        output_texts_b = []
                        for ind, token in enumerate(input_ids[b]):
                            if token == tokenizer.mask_token_id:
                                mask_infos_b[(ind, ind)] = ''
                                pred_token = out['pred_mlm_logits'][b, ind:ind + 1, :]
                                pred_token = pred_token.argmax(1).item()
                                output_texts_b.append(pred_token)
                                # 1437 is used as a separator id here (presumably
                                # a whitespace token in the OPT vocabulary).
                                output_texts_b.append(1437)
                                input_ids[b, ind: ind + 1] = pred_token
                            else:
                                output_texts_b.append(token.item())
                        mask_infos.append(mask_infos_b)
                        # Skip the first (BOS) token before decoding back to text.
                        output_texts.append(tokenizer.decode(output_texts_b[1:]))
                    out['output_text'] = output_texts
                else:
                    mask_infos = []
                    for b in range(bs):
                        starts = (out['pred_start'][b, :, 0].sigmoid() > threshold).nonzero().squeeze(1)
                        ends = (out['pred_end'][b, :, 0].sigmoid() > threshold).nonzero().squeeze(1)
                        # Fall back to the argmax position when nothing clears the threshold.
                        if len(starts) == 0:
                            starts = out['pred_start'][b, :].argmax(0)
                        if len(ends) == 0:
                            ends = out['pred_end'][b, :].argmax(0)
                        mask_infos_b = {}
                        for start, end in zip(starts, ends):
                            mask_infos_b[(int(start), int(end))] = ''
                        mask_infos.append(mask_infos_b)
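            # Pool the projected LLM states over each span; the decoded span text
            # becomes the predicted object name for that latent query.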
            for b in range(bs):
                llm_feat_b = []
                hidden_states_b = hidden_states[b, :, :]
                for start, end in mask_infos[b].keys():
                    llm_feat_b.append(hidden_states_b[start: end + 1].mean(dim=0, keepdim=True))
                    pred_name = tokenizer.decode(input_ids[b, start: end + 1]).strip()
                    mask_infos_pred[b][(int(start), int(end))] = pred_name
                if llm_feat_b:
                    llm_feat.append(torch.cat(llm_feat_b)[None])
                else:
                    # torch.cat on an empty list raises, so keep per-image alignment
                    # with an empty placeholder when no span was found.
                    llm_feat.append(hidden_states.new_zeros(1, 0, hidden_states.shape[-1]))
            out['mask_infos_pred'] = mask_infos_pred

            query_embeds = None
            if not self.two_stage:
                query_embeds = self.query_embed.weight

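            # Span counts vary per image, so run the transformer once per image,
            # feeding at most four latent span queries per pass and concatenating
            # the results.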
            outputs_classes_list = []
            outputs_coords_list = []
            for b in range(bs):
                srcs_b = [i[b: b + 1] for i in srcs]
                masks_b = [i[b: b + 1] for i in masks]
                pos_b = [i[b: b + 1] for i in pos]
                k = len(mask_infos[b])
                if k == 0:
                    # No spans for this image: emit empty class/box predictions.
                    outputs_classes_list.append(torch.zeros(0, 2).to(self.device))
                    outputs_coords_list.append(torch.zeros(0, 4).to(self.device))
                    continue
                num_repeat = math.ceil(k / 4)
                outputs_classes = []
                outputs_coords = []
                for ind in range(num_repeat):
                    llm_feat_b = llm_feat[b][:, ind * 4: (ind + 1) * 4]
                    hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = (
                        self.transformer(srcs_b, masks_b, pos_b, query_embeds, llm_feat_b, llm_feat_b.shape[1])
                    )
                    # Only the last decoder layer is decoded at inference.
                    lvl = hs.shape[0] - 1
                    reference = inter_references[lvl - 1]
                    reference = inverse_sigmoid(reference)
                    outputs_class = self.class_embed[lvl](hs[lvl])
                    tmp = self.bbox_embed[lvl](hs[lvl])
                    if reference.shape[-1] == 4:
                        tmp += reference
                    else:
                        assert reference.shape[-1] == 2
                        tmp[..., :2] += reference
                    outputs_coord = tmp.sigmoid()
                    outputs_classes.append(outputs_class.flatten(0, 1))
                    outputs_coords.append(outputs_coord.flatten(0, 1))
                outputs_classes = torch.cat(outputs_classes)[None]
                outputs_coords = torch.cat(outputs_coords)[None]
                outputs_classes_list.append(outputs_classes)
                outputs_coords_list.append(outputs_coords)

            out.update({'pred_logits': outputs_classes_list,
                        'pred_boxes': outputs_coords_list})
        return out