import logging
from contextlib import suppress

import torch
import torch.nn as nn
from einops import rearrange
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import LlamaForCausalLM, LlamaTokenizer, PreTrainedModel

from .eva_vit import create_eva_vit_g
from .pooler import Pooler


def get_autocast(precision, cache_enabled=True):
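    """Map a precision string to an autocast context factory.

    'amp' uses the default autocast dtype, the bf16/fp16 variants force
    bfloat16/float16, and 'fp32' returns contextlib.suppress as a no-op.
    """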
    if precision == "amp":
        return lambda: torch.cuda.amp.autocast(cache_enabled=cache_enabled)
    elif precision in ("amp_bfloat16", "amp_bf16", "bf16"):
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16, cache_enabled=cache_enabled)
    elif precision == "fp16":
        return lambda: torch.cuda.amp.autocast(dtype=torch.float16, cache_enabled=cache_enabled)
    elif precision == "fp32":
        return suppress
    else:
        raise ValueError("unsupported precision: {}".format(precision))


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def init_vision_encoder(model_name, img_size, drop_path_rate, use_grad_checkpoint):
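    """Build the EVA ViT-g visual encoder plus the LayerNorm (computed in fp32) applied to its features."""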
    if model_name == "eva_clip_g":
        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint)
    else:
        raise ValueError("unsupported vision encoder: {}".format(model_name))

    ln_vision = LayerNorm(visual_encoder.num_features)
    return visual_encoder, ln_vision


class ImageProcessor:
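    """Resize, tensorize and normalize a PIL image; the defaults are the CLIP mean/std."""
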
    def __init__(self, image_size=364, mean=None, std=None):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)
        self.mean = mean
        self.std = std

        self.normalize = transforms.Normalize(mean, std)
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)


class InfMLLM(PreTrainedModel):
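    """Multimodal LLM: pooled EVA ViT-g features are spliced into a LLaMA language
    model at the '<ImageHere>' placeholder of each prompt."""
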
    def __init__(self, config):
        super().__init__(config)
        vit_model = config.vit_model
        img_size = config.image_size
        lm_model = config.lm_model
        lm_tokenizer = config.lm_tokenizer
        precision = config.precision
        pool_out_size = config.pool_out_size
        self.img_processor = ImageProcessor(image_size=img_size)

        self.visual_encoder, self.ln_vision = init_vision_encoder(
            vit_model, img_size, drop_path_rate=0.0, use_grad_checkpoint=False)

        self.lm_tokenizer = LlamaTokenizer.from_pretrained(lm_tokenizer, use_fast=False, trust_remote_code=True)
        self.lm_tokenizer.pad_token = self.lm_tokenizer.unk_token
        self.lm_model = LlamaForCausalLM.from_pretrained(lm_model, trust_remote_code=True, torch_dtype='auto')

        self.pooler = Pooler(dim_in=self.visual_encoder.num_features,
                             dim_out=self.lm_model.config.hidden_size,
                             pool_out_size=pool_out_size)
        self.llama_proj = nn.Identity()

        self.precision = precision
        self._apply_lemmatizer = getattr(config, 'apply_lemmatizer', False)
        self._lemmatizer = None

    def prompt_wrap(self, img_embeds, atts_img, prompts):
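        """Splice image embeddings into each prompt at '<ImageHere>' and left-pad the batch.

        Returns the padded input embeddings, their attention mask, and a mask marking
        which positions hold image tokens.
        """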
        assert len(img_embeds) == len(atts_img) == len(prompts)

        bos = torch.ones([1, 1], dtype=torch.long, device=img_embeds.device) * self.lm_tokenizer.bos_token_id
        bos_embeds = self.lm_model.get_input_embeddings()(bos)

        emb_lists = []
        image_mask = []
        for each_img_embed, each_prompt in zip(img_embeds, prompts):
            assert '<ImageHere>' in each_prompt
            p_before, p_after = each_prompt.split('<ImageHere>')

            p_before_tokens = self.lm_tokenizer(
                p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            p_after_tokens = self.lm_tokenizer(
                p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)

            p_before_embed = self.lm_model.get_input_embeddings()(p_before_tokens.input_ids.long())
            p_after_embed = self.lm_model.get_input_embeddings()(p_after_tokens.input_ids.long())

            wrapped_emb = torch.cat([bos_embeds, p_before_embed, each_img_embed[None], p_after_embed], dim=1)
            emb_lists.append(wrapped_emb)

            image_mask.append(torch.zeros(wrapped_emb.size(1), dtype=torch.long))
            img_start = bos_embeds.size(1) + p_before_embed.size(1)
            image_mask[-1][img_start:img_start + len(each_img_embed)] = 1
            assert image_mask[-1].sum() == each_img_embed.size(0)

        emb_lens = [emb.shape[1] for emb in emb_lists]
        pad_emb = self.lm_model.get_input_embeddings()(torch.tensor(self.lm_tokenizer.pad_token_id, device=img_embeds.device))

        assert not self.training

        wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
        wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
        wrapped_image_masks = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
        for i, emb in enumerate(emb_lists):
            wrapped_embs[i, -emb_lens[i]:] = emb
            wrapped_atts[i, -emb_lens[i]:] = 1
            wrapped_image_masks[i, -emb_lens[i]:] = image_mask[i]
        return wrapped_embs, wrapped_atts, wrapped_image_masks

    @torch.no_grad()
    def forward_image_feature(self, image):
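        """Encode images with the ViT, pool to a fixed number of tokens and project
        them into the language model's embedding space.

        Accepts (B, C, H, W) or (B, T, F, C, H, W) input with T == F == 1.
        """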
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            if image.ndim == 4:
                image = image.unsqueeze(1).unsqueeze(1)
            assert image.ndim == 6

            b, t, f = image.shape[:3]
            assert t == 1 and f == 1
            image = rearrange(image, "b t f c h w -> (b t f) c h w")

            image_embeds = self.ln_vision(self.visual_encoder(image))

            image_embeds = rearrange(image_embeds, "(b t f) L D -> b t f L D", t=t, f=f)
            query_output = self.pooler(image_embeds)
            query_output = query_output.squeeze(1)
            embeds_img = self.llama_proj(query_output)

        return embeds_img

    @torch.no_grad()
    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=30,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
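        """Generate free-form text for a batch.

        `samples` must provide "image" (a batch of preprocessed images) and
        "prompts" (a list of strings, each containing '<ImageHere>'); returns the
        decoded, stripped generations.
        """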
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            image = samples["image"]
            embeds_img = self.forward_image_feature(image)
            atts_img = torch.ones(embeds_img.size()[:-1], dtype=torch.long).to(image.device)

            prompts = samples["prompts"]
            assert isinstance(prompts, (tuple, list))

            inputs_embeds, attention_mask, masks_img = self.prompt_wrap(embeds_img, atts_img, prompts)

            model_args = dict(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=use_nucleus_sampling,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
                eos_token_id=self.lm_tokenizer.eos_token_id,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
            )
            outputs = self.lm_model.generate(**model_args)

            output_text = self.lm_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]

        return output_text

    @torch.no_grad()
    def predict_answers(
        self,
        samples,
        num_beams=5,
        max_len=10,
        min_len=1,
        length_penalty=0,
    ):
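        """Beam-search short answers (VQA-style) for a batch of image/prompt pairs,
        optionally lemmatizing the outputs."""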
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            image = samples["image"]
            embeds_img = self.forward_image_feature(image)
            atts_img = torch.ones(embeds_img.size()[:-1], dtype=torch.long).to(image.device)

            prompts = samples["prompts"]
            assert isinstance(prompts, (tuple, list))

            inputs_embeds, attention_mask, masks_img = self.prompt_wrap(embeds_img, atts_img, prompts)

            model_args = dict(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=False,
                num_beams=num_beams,
                max_new_tokens=max_len,
                min_length=min_len,
                eos_token_id=self.lm_tokenizer.eos_token_id,
                length_penalty=length_penalty,
            )

            outputs = self.lm_model.generate(**model_args)
            output_text = self.lm_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]

        if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]):
            output_text = self._lemmatize(output_text)

        return output_text

    def _lemmatize(self, answers):
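        """Replace nouns and verbs in each answer with their lemmas; other tokens are kept as-is."""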
        def apply(answer):
            doc = self.lemmatizer(answer)

            words = []
            for token in doc:
                if token.pos_ in ["NOUN", "VERB"]:
                    words.append(token.lemma_)
                else:
                    words.append(token.text)
            answer = " ".join(words)

            return answer

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        if self._lemmatizer is None:
            try:
                import spacy

                self._lemmatizer = spacy.load("en_core_web_sm")
            except (ImportError, OSError):
                logging.error(
                    """
                    Please install spacy and the en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                exit(1)

        return self._lemmatizer
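

# Usage sketch (hypothetical paths; assumes a checkpoint whose config.json provides
# vit_model, image_size, lm_model, lm_tokenizer, precision and pool_out_size, with
# the corresponding weights available locally; import InfMLLM from this module):
#
#     from PIL import Image
#     from transformers import PretrainedConfig
#
#     config = PretrainedConfig.from_pretrained("/path/to/infmllm_checkpoint")
#     model = InfMLLM(config).eval().cuda()
#
#     pixel_values = model.img_processor(Image.open("example.jpg").convert("RGB"))
#     samples = {
#         "image": pixel_values.unsqueeze(0).cuda(),
#         # Each prompt must contain the literal '<ImageHere>' placeholder.
#         "prompts": ["<ImageHere> Describe the image briefly."],
#     }
#     print(model.generate(samples, num_beams=5, max_length=30))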