import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from util.misc import (NestedTensor, inverse_sigmoid,
                       nested_tensor_from_tensor_list)

from .blip2_decoder import BLIP2Decoder
from .deformable_detr.backbone import build_backbone
from .deformable_detr.deformable_detr import DeformableDETR
from .transformer import build_ov_transformer


class ContextDET(DeformableDETR):
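    """Deformable DETR variant conditioned on LLM hidden states.

    A BLIP-2/OPT language decoder produces per-token hidden states; token
    spans in the text (given or predicted) are mean-pooled into conditional
    queries that steer the detection transformer toward the referred objects.
    """
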
    def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
                 aux_loss=True, with_box_refine=False, two_stage=False, llm_decoder=None):
        super().__init__(backbone, transformer, num_classes, num_queries, num_feature_levels,
                         aux_loss, with_box_refine, two_stage)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm_decoder = llm_decoder
        hidden_dim = transformer.d_model
        out_size = self.llm_decoder.model.opt_proj.out_features
        self.llm_proj = nn.Linear(out_size, hidden_dim, device=self.device)
        self.start_end_proj = nn.Linear(hidden_dim, 2)
        for layer in [self.llm_proj, self.start_end_proj]:
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)
        # word_embed_proj_dim = llm_decoder.model.opt_model.config.word_embed_proj_dim
        vocab_size = llm_decoder.model.opt_model.config.vocab_size
        self.fc_logits = nn.Linear(hidden_dim, vocab_size)

    def forward(self, samples, blip2_samples, mask_infos=None, task_button=None, threshold=0.3):
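        """Detect objects conditioned on language context.

        Args:
            samples: batched images as a NestedTensor (or a plain tensor list).
            blip2_samples: inputs for the BLIP-2/OPT language decoder.
            mask_infos: optional per-image dicts keyed by (start, end) token
                spans; sampled from ground truth during training.
            task_button: selects the language task (e.g. 'Cloze Test').
            threshold: sigmoid threshold on start/end scores at inference.
        """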
        logits, hidden_states, input_ids, output_text = self.llm_decoder.model.forward(
            blip2_samples, task_button=task_button)
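        # The LLM runs frozen: detach before projecting its hidden states
        # into the detector's embedding space.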
        hidden_states = hidden_states.detach()
        hidden_states = self.llm_proj(hidden_states)
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
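        # Multi-scale features and positional encodings from the CNN backbone.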
        features, pos = self.backbone(samples)
        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
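        # If the transformer expects more levels than the backbone returns,
        # synthesize the extra levels with strided projections of the last map.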
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)
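        # Language-side outputs: MLM logits over the vocabulary plus per-token
        # start/end scores that locate the spans to ground.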
        out = {}
        start_end_proj = self.start_end_proj(hidden_states)
        out['pred_mlm_logits'] = self.fc_logits(hidden_states)
        out['pred_start'] = start_end_proj[:, :, 0:1]
        out['pred_end'] = start_end_proj[:, :, 1:2]
        out['output_text'] = output_text
        if self.training:
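            # Sample up to two annotated (start, end) spans per image so the
            # number of conditional queries is constant across the batch.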
            k = min([len(mask_info) for mask_info in mask_infos])
            k = min(k, 2)
            # random.sample needs a sequence, not a dict_keys view (TypeError on Python >= 3.11).
            select_ids = [random.sample(list(mask_info.keys()), k) for mask_info in mask_infos]
            # select_ids = [random.choices(list(mask_info.keys()), k=4) for mask_info in mask_infos]
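            # Mean-pool the hidden states of each sampled span into one
            # conditional query vector per span.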
            llm_feat = []
            for b in range(len(select_ids)):
                llm_feat_b = []
                hidden_states_b = hidden_states[b, :, :]
                for start, end in select_ids[b]:
                    llm_feat_b.append(hidden_states_b[start:end + 1].mean(dim=0, keepdim=True))
                llm_feat.append(torch.cat(llm_feat_b)[None])
            llm_feat = torch.cat(llm_feat)
            query_embeds = None
            if not self.two_stage:
                query_embeds = self.query_embed.weight
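            # Deformable transformer pass, conditioned on the pooled span
            # features (llm_feat) in addition to the learned object queries.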
            hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = (
                self.transformer(srcs, masks, pos, query_embeds, llm_feat, k)
            )
            outputs_classes = []
            outputs_coords = []
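            # Standard Deformable DETR head: each decoder layer refines the
            # reference boxes from the previous layer.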
            for lvl in range(hs.shape[0]):
                if lvl == 0:
                    reference = init_reference
                else:
                    reference = inter_references[lvl - 1]
                reference = inverse_sigmoid(reference)
                outputs_class = self.class_embed[lvl](hs[lvl])
                tmp = self.bbox_embed[lvl](hs[lvl])
                if reference.shape[-1] == 4:
                    tmp += reference
                else:
                    assert reference.shape[-1] == 2
                    tmp[..., :2] += reference
                outputs_coord = tmp.sigmoid()
                outputs_classes.append(outputs_class)
                outputs_coords.append(outputs_coord)
            outputs_class = torch.stack(outputs_classes)
            outputs_coord = torch.stack(outputs_coords)
            out.update({'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
                        'init_reference': init_reference})
            out['select_ids'] = select_ids
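            # Mirror the span selections into the auxiliary outputs so the
            # per-layer losses can match predictions to spans.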
            if self.aux_loss:
                out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
                for temp in out['aux_outputs']:
                    temp['select_ids'] = select_ids
            if self.two_stage:
                enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
                out['enc_outputs'] = {
                    'pred_logits': enc_outputs_class,
                    'pred_boxes': enc_outputs_coord,
                    'anchors': anchors,
                }
        else:
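            # Inference: spans are either given via mask_infos, recovered from
            # [MASK] tokens (cloze test), or thresholded from start/end scores.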
            bs = len(samples.tensors)
            mask_infos_pred = [{} for _ in range(bs)]
            llm_feat = []
            tokenizer = self.llm_decoder.model.opt_tokenizer
            if mask_infos is None:
                if task_button == 'Cloze Test':
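                    # Fill each [MASK] with the argmax MLM prediction and treat
                    # the token position as a span to ground.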
                    mask_infos = []
                    output_texts = []
                    for b in range(bs):
                        mask_infos_b = {}
                        output_texts_b = []
                        for ind, token in enumerate(input_ids[b]):
                            if token == tokenizer.mask_token_id:
                                mask_infos_b[(ind, ind)] = ''
                                pred_token = out['pred_mlm_logits'][b, ind:ind + 1, :]
                                pred_token = pred_token.argmax(1).item()
                                output_texts_b.append(pred_token)
                                # Separator token so consecutive fills don't run together when decoded.
                                output_texts_b.append(1437)
                                input_ids[b, ind:ind + 1] = pred_token
                            else:
                                output_texts_b.append(token.item())
                        mask_infos.append(mask_infos_b)
                        # Skip the leading special token before decoding.
                        output_texts.append(tokenizer.decode(output_texts_b[1:]))
                    out['output_text'] = output_texts
                else:
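                    # No spans given: threshold the sigmoid start/end scores,
                    # falling back to the argmax position when nothing passes.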
                    mask_infos = []
                    for b in range(bs):
                        starts = (out['pred_start'][b, :, 0].sigmoid() > threshold).nonzero().squeeze(1)
                        ends = (out['pred_end'][b, :, 0].sigmoid() > threshold).nonzero().squeeze(1)
                        if len(starts) == 0:
                            starts = out['pred_start'][b, :].argmax(0)
                        if len(ends) == 0:
                            ends = out['pred_end'][b, :].argmax(0)
                        mask_infos_b = {}
                        for start, end in zip(starts, ends):
                            mask_infos_b[(int(start), int(end))] = ''
                        mask_infos.append(mask_infos_b)
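            # Pool hidden states over every span and decode the span's tokens
            # as the predicted object name.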
            for b in range(bs):
                llm_feat_b = []
                hidden_states_b = hidden_states[b, :, :]
                for start, end in mask_infos[b].keys():
                    llm_feat_b.append(hidden_states_b[start:end + 1].mean(dim=0, keepdim=True))
                    pred_name = tokenizer.decode(input_ids[b, start:end + 1]).strip()
                    mask_infos_pred[b][(int(start), int(end))] = pred_name
                # Keep per-image indexing aligned even when no span was found;
                # torch.cat on an empty list would raise, and images with zero
                # spans are skipped later (k == 0).
                if llm_feat_b:
                    llm_feat.append(torch.cat(llm_feat_b)[None])
                else:
                    llm_feat.append(hidden_states_b.new_zeros(1, 0, hidden_states_b.shape[-1]))
            out['mask_infos_pred'] = mask_infos_pred
            query_embeds = None
            if not self.two_stage:
                query_embeds = self.query_embed.weight
            outputs_classes_list = []
            outputs_coords_list = []
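            # Decode each image separately; its conditional queries are run
            # through the transformer in chunks of at most 4 spans.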
            for b in range(bs):
                srcs_b = [i[b:b + 1] for i in srcs]
                masks_b = [i[b:b + 1] for i in masks]
                pos_b = [i[b:b + 1] for i in pos]
                k = len(mask_infos[b])
                if k == 0:
                    outputs_classes_list.append(torch.zeros(0, 2).to(self.device))
                    outputs_coords_list.append(torch.zeros(0, 4).to(self.device))
                    continue
                num_repeat = math.ceil(k / 4)
                outputs_classes = []
                outputs_coords = []
                for ind in range(num_repeat):
                    llm_feat_b = llm_feat[b][:, ind * 4:(ind + 1) * 4]
                    hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = (
                        self.transformer(srcs_b, masks_b, pos_b, query_embeds, llm_feat_b, llm_feat_b.shape[1])
                    )
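                    # Only the last decoder layer's outputs are used at inference.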
                    lvl = hs.shape[0] - 1
                    reference = inter_references[lvl - 1]
                    reference = inverse_sigmoid(reference)
                    outputs_class = self.class_embed[lvl](hs[lvl])
                    tmp = self.bbox_embed[lvl](hs[lvl])
                    if reference.shape[-1] == 4:
                        tmp += reference
                    else:
                        assert reference.shape[-1] == 2
                        tmp[..., :2] += reference
                    outputs_coord = tmp.sigmoid()
                    outputs_classes.append(outputs_class.flatten(0, 1))
                    outputs_coords.append(outputs_coord.flatten(0, 1))
                outputs_classes = torch.cat(outputs_classes)[None]
                outputs_coords = torch.cat(outputs_coords)[None]
                outputs_classes_list.append(outputs_classes)
                outputs_coords_list.append(outputs_coords)
            out.update({'pred_logits': outputs_classes_list,
                        'pred_boxes': outputs_coords_list})
        return out