import contextlib
import logging

import torch
import torch.nn as nn
from lavis.common.registry import registry
from lavis.models import Blip2OPT, load_preprocess
from omegaconf import OmegaConf


@registry.register_model("blip2_opt_det")
class Blip2OPTDet(Blip2OPT):
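    """Blip2OPT variant whose forward pass returns the OPT logits, the final
    hidden states, the token ids fed to OPT, and the decoded text, so that a
    downstream module (e.g. a detection head) can consume them."""
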
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.opt_tokenizer.add_special_tokens({"mask_token": "<mask>"})

    def maybe_autocast(self, dtype=torch.float16):
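        # Autocast only helps on GPU; on CPU return a no-op context so callers
        # can always write "with self.maybe_autocast():".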
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @torch.no_grad()
    def forward(self, samples,
                use_nucleus_sampling=False,
                num_beams=5,
                max_length=30,
                min_length=1,
                top_p=0.9,
                repetition_penalty=1.0,
                length_penalty=1.0,
                num_captions=1,
                temperature=1,
                task_button=None):
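        # Vision side: encode the image with the ViT and build an all-ones
        # attention mask over its patch features.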
        image = samples["image"]
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )
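
        # Q-Former: cross-attend the learned query tokens to the image features.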
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
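
        # Project the query outputs into the OPT embedding space; they act as a
        # soft visual prefix for the language model.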
        inputs_opt = self.opt_proj(query_output.last_hidden_state)
        atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device)

        self.opt_tokenizer.padding_side = "right"
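
        # Three ways to supply the text: raw strings ("text_input"), pre-tokenized
        # ids ("input_ids"), or a "prompt" that OPT first completes by generation.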
if "text_input" in samples.keys(): |
|
|
|
text = [t for t in samples["text_input"]] |
|
opt_tokens = self.opt_tokenizer( |
|
text, |
|
return_tensors="pt", |
|
padding="longest", |
|
).to(image.device) |
|
input_ids = opt_tokens.input_ids |
|
attention_mask = opt_tokens.attention_mask |
|
output_text = text |
|
elif "input_ids" in samples.keys(): |
|
input_ids = samples["input_ids"] |
|
attention_mask = samples["attention_mask"] |
|
output_text = [] |
|
else: |
|
assert "prompt" in samples.keys() |
|
prompt = samples["prompt"] |
|
assert len(prompt) == image.size(0) |
|

            opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt", padding=True).to(
                image.device
            )
            input_ids = opt_tokens.input_ids
            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
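
            # Nucleus sampling draws num_captions samples per image; beam search
            # instead expands the visual prefix to num_beams candidates.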
            if use_nucleus_sampling:
                query_embeds = inputs_opt.repeat_interleave(num_captions, dim=0)
                num_beams = 1
            else:
                query_embeds = inputs_opt.repeat_interleave(num_beams, dim=0)
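
            # Let OPT generate a continuation of the prompt, conditioned on the
            # visual prefix (query_embeds).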
            with self.maybe_autocast():
                outputs = self.opt_model.generate(
                    input_ids=input_ids,
                    query_embeds=query_embeds,
                    attention_mask=attention_mask,
                    do_sample=use_nucleus_sampling,
                    top_p=top_p,
                    temperature=temperature,
                    num_beams=num_beams,
                    max_new_tokens=max_length,
                    min_length=min_length,
                    eos_token_id=self.eos_token_id,
                    repetition_penalty=repetition_penalty,
                    length_penalty=length_penalty,
                    num_return_sequences=num_captions,
                )
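
            # Strip the prompt tokens and decode only the newly generated text.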
            prompt_length = opt_tokens.input_ids.shape[1]
            output_text = self.opt_tokenizer.batch_decode(
                outputs[:, prompt_length:], skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]
            if task_button == 'Question Answering' or task_button == "Captioning":
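                # Re-tokenize prompt + generated answer so the scoring pass below
                # sees the full sequence.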
                output_text_input = [prompt[0] + ' ' + output_text[0]]
                opt_tokens = self.opt_tokenizer(
                    output_text_input,
                    return_tensors="pt",
                    padding="longest",
                ).to(image.device)
                input_ids = opt_tokens.input_ids
                attention_mask = opt_tokens.attention_mask
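
        # Second pass through OPT with the visual prefix prepended, to collect
        # logits and hidden states at the text positions.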
        inputs_embeds = self.opt_model.model.decoder.embed_tokens(input_ids)
        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_opt, attention_mask], dim=1)
        with self.maybe_autocast():
            outputs = self.opt_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=True,
            )
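        # Drop the query-token positions; keep only the text positions.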
        n_queries = query_tokens.shape[1]
        out_logits = outputs['logits'][:, n_queries:]
        out_hidden = outputs['hidden_states'][-1][:, n_queries:]
        return out_logits, out_hidden, input_ids, output_text


def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
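    # Local counterpart of lavis.models.load_model_and_preprocess, so that models
    # registered in this file (here "blip2_opt_det") can be loaded by name.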
    model_cls = registry.get_model_class(name)

    model = model_cls.from_pretrained(model_type=model_type)

    if is_eval:
        model.eval()
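
    # Load the default preprocessing config shipped with this model type and build
    # the vision/text processors from it.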
    cfg = OmegaConf.load(model_cls.default_config_path(model_type))
    if cfg is not None:
        preprocess_cfg = cfg.preprocess

        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
    else:
        vis_processors, txt_processors = None, None
        logging.info(
            f"""No default preprocess for model {name} ({model_type}).
            This can happen if the model is not finetuned on downstream datasets,
            or it is not intended for direct use without finetuning.
            """
        )

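    # float16 inference is not supported on CPU, so cast the model to full precision.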
if device == "cpu" or device == torch.device("cpu"): |
|
model = model.float() |
|
|
|
return model.to(device), vis_processors, txt_processors |
|
|
|
|
|
class BLIP2Decoder(nn.Module):
    def __init__(self, llm_name):
        super(BLIP2Decoder, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if llm_name not in ['pretrain_opt2.7b', 'caption_coco_opt2.7b',
                            'pretrain_opt6.7b', 'caption_coco_opt6.7b']:
            raise ValueError(f"{llm_name} is not supported yet")
        model_type = llm_name
        model, vis, _ = load_model_and_preprocess(name="blip2_opt_det",
                                                  model_type=model_type,
                                                  is_eval=True, device=self.device)
        self.model = model
        self.vis_processors = vis
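        # The BLIP-2 decoder is used as a frozen feature extractor; no gradients
        # flow into it.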
        self.freeze_layers()

    def freeze_layers(self):
        for p in self.model.parameters():
            p.requires_grad = False
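

# Minimal usage sketch, not part of the original module: assumes a LAVIS
# environment where this file is importable and the BLIP-2 OPT checkpoint can be
# fetched; "example.jpg" is a hypothetical local image.
if __name__ == "__main__":
    from PIL import Image

    decoder = BLIP2Decoder(llm_name="pretrain_opt2.7b")
    img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    img = decoder.vis_processors["eval"](img).unsqueeze(0).to(decoder.device)
    samples = {"image": img, "prompt": ["a photo of"]}
    logits, hidden, input_ids, text = decoder.model(samples, task_button="Captioning")
    print(text)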