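"""BLIP-2 OPT decoder wrapper.

Blip2OPTDet extends LAVIS's Blip2OPT so that forward() returns the OPT decoder
logits and last hidden states (with the query-token positions stripped off)
together with any generated text. BLIP2Decoder loads the model and its visual
preprocessors through the LAVIS registry and freezes all parameters.
"""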
import contextlib
import logging

import torch
import torch.nn as nn
from lavis.common.registry import registry
from lavis.models import Blip2OPT, load_preprocess
from omegaconf import OmegaConf
# register under the name used by load_model_and_preprocess() below, so that
# registry.get_model_class("blip2_opt_det") resolves to this class
@registry.register_model("blip2_opt_det")
class Blip2OPTDet(Blip2OPT):
    def __init__(
        self,
        **kwargs
    ):
        super().__init__(**kwargs)
        # expose a dedicated <mask> special token on the OPT tokenizer
        self.opt_tokenizer.add_special_tokens({"mask_token": "<mask>"})
    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()
    def forward(self, samples,
                use_nucleus_sampling=False,
                num_beams=5,
                max_length=30,
                min_length=1,
                top_p=0.9,
                repetition_penalty=1.0,
                length_penalty=1.0,
                num_captions=1,
                temperature=1,
                task_button=None):
        image = samples["image"]

        # encode the image with the frozen visual encoder
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        # cross-attend the learned query tokens to the image features
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        # project Q-Former outputs into the OPT embedding space
        inputs_opt = self.opt_proj(query_output.last_hidden_state)
        atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device)
        self.opt_tokenizer.padding_side = "right"

        # three conditioning modes: raw text, pre-tokenized ids, or a prompt
        # that is first completed by generation
        if "text_input" in samples.keys():
            # text = [t + "\n" for t in samples["text_input"]]
            text = [t for t in samples["text_input"]]
            opt_tokens = self.opt_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
            ).to(image.device)
            input_ids = opt_tokens.input_ids
            attention_mask = opt_tokens.attention_mask
            output_text = text
        elif "input_ids" in samples.keys():
            input_ids = samples["input_ids"]
            attention_mask = samples["attention_mask"]
            output_text = []
        else:
            assert "prompt" in samples.keys()
            prompt = samples["prompt"]
            assert len(prompt) == image.size(0)
            opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt", padding=True).to(
                image.device
            )
            input_ids = opt_tokens.input_ids
            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
            if use_nucleus_sampling:
                query_embeds = inputs_opt.repeat_interleave(num_captions, dim=0)
                num_beams = 1
            else:
                query_embeds = inputs_opt.repeat_interleave(num_beams, dim=0)

            # complete the prompt with the OPT decoder
            with self.maybe_autocast():
                outputs = self.opt_model.generate(
                    input_ids=input_ids,
                    query_embeds=query_embeds,
                    attention_mask=attention_mask,
                    do_sample=use_nucleus_sampling,
                    top_p=top_p,
                    temperature=temperature,
                    num_beams=num_beams,
                    max_new_tokens=max_length,
                    min_length=min_length,
                    eos_token_id=self.eos_token_id,
                    repetition_penalty=repetition_penalty,
                    length_penalty=length_penalty,
                    num_return_sequences=num_captions,
                )

            # keep only the newly generated tokens
            prompt_length = opt_tokens.input_ids.shape[1]
            output_text = self.opt_tokenizer.batch_decode(
                outputs[:, prompt_length:], skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]
            if task_button == 'Question Answering' or task_button == "Captioning":
                # re-tokenize prompt + generated answer so the decoder pass
                # below scores the full text
                output_text_input = [prompt[0] + ' ' + output_text[0]]
                opt_tokens = self.opt_tokenizer(
                    output_text_input,
                    return_tensors="pt",
                    padding="longest",
                ).to(image.device)
                input_ids = opt_tokens.input_ids
                attention_mask = opt_tokens.attention_mask

        # prepend the projected query embeddings to the token embeddings
        inputs_embeds = self.opt_model.model.decoder.embed_tokens(input_ids)
        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_opt, attention_mask], dim=1)
        with self.maybe_autocast():
            outputs = self.opt_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=True
            )

        # strip the query-token positions so logits/hidden states align with the text tokens
        n_queries = query_tokens.shape[1]
        out_logits = outputs['logits'][:, n_queries:]
        out_hidden = outputs['hidden_states'][-1][:, n_queries:]

        return out_logits, out_hidden, input_ids, output_text
def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
    model_cls = registry.get_model_class(name)

    # load model
    model = model_cls.from_pretrained(model_type=model_type)

    if is_eval:
        model.eval()

    # load preprocess
    cfg = OmegaConf.load(model_cls.default_config_path(model_type))
    if cfg is not None:
        preprocess_cfg = cfg.preprocess
        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
    else:
        vis_processors, txt_processors = None, None
        logging.info(
            f"""No default preprocess for model {name} ({model_type}).
            This can happen if the model is not finetuned on downstream datasets,
            or it is not intended for direct use without finetuning.
            """
        )

    if device == "cpu" or device == torch.device("cpu"):
        model = model.float()

    return model.to(device), vis_processors, txt_processors
class BLIP2Decoder(nn.Module):
    def __init__(self, llm_name):
        super(BLIP2Decoder, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if llm_name not in ['pretrain_opt2.7b', 'caption_coco_opt2.7b',
                            'pretrain_opt6.7b', 'caption_coco_opt6.7b']:
            raise ValueError(f"{llm_name} is not supported yet")
        model_type = llm_name
        model, vis, _ = load_model_and_preprocess(name="blip2_opt_det",
                                                  model_type=model_type,
                                                  is_eval=True, device=self.device)
        self.model = model
        self.vis_processors = vis
        self.freeze_layers()

    def freeze_layers(self):
        # the BLIP-2/OPT backbone is kept frozen
        for p in self.model.parameters():
            p.requires_grad = False
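# Minimal usage sketch (not part of the original module): how BLIP2Decoder
# might be driven for a single image/prompt pair. The image path and prompt
# text are placeholder assumptions; "eval" is the standard LAVIS
# visual-processor key.
if __name__ == "__main__":
    from PIL import Image

    decoder = BLIP2Decoder("pretrain_opt2.7b")
    raw_image = Image.open("example.jpg").convert("RGB")  # placeholder image path
    image = decoder.vis_processors["eval"](raw_image).unsqueeze(0).to(decoder.device)

    samples = {
        "image": image,
        "prompt": ["Question: what objects are in the image? Answer:"],
    }
    with torch.no_grad():
        logits, hidden, input_ids, text = decoder.model(
            samples, task_button="Question Answering"
        )
    print(text)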