File size: 4,900 Bytes
d8cc680 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import torch
import torch.nn as nn
from transformers import (
PreTrainedModel,
AutoModelForCausalLM,
AutoModel,
SiglipImageProcessor,
)
from .configuration_llamavision import LlamavisionConfig
class ProjectionModule(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) mapping vision features to the
    text model's hidden width.

    The layers are held in ``self.model`` as an ``nn.Sequential``, so the
    parameter names are ``model.0.*`` and ``model.2.*``. Keep that layout
    unchanged, or previously saved state dicts will no longer match.

    Args:
        mm_hidden_size: size of the last dimension of the incoming features.
        hidden_size: width of both the intermediate and the output layer.
    """

    def __init__(self, mm_hidden_size=1152, hidden_size=4096):
        super().__init__()
        up = nn.Linear(mm_hidden_size, hidden_size)
        act = nn.GELU()
        out = nn.Linear(hidden_size, hidden_size)
        self.model = nn.Sequential(up, act, out)

    def forward(self, x):
        """Return ``x`` projected from ``mm_hidden_size`` to ``hidden_size`` on its last dim."""
        projected = self.model(x)
        return projected
class Llamavision(PreTrainedModel):
    """Vision-language wrapper around a causal LM and a vision encoder.

    The penultimate hidden state of ``vision_model`` is passed through
    ``mm_projector``. The result is spliced into the prompt embeddings of
    ``text_model`` at the position of the image placeholder token, and the
    combined sequence is handed to ``text_model.generate``.
    """

    config_class = LlamavisionConfig

    def __init__(self, config):
        super().__init__(config)
        self.text_model = AutoModelForCausalLM.from_config(config.text_config)
        self.vision_model = AutoModel.from_config(config.vision_config)
        # Built with library defaults; answer_question passes an explicit
        # 384x384 resize on every call.
        self.processor = SiglipImageProcessor()
        # NOTE(review): projector widths come from ProjectionModule's defaults
        # (1152 -> 4096), not from config.vision_config / config.text_config.
        # Confirm they match the hidden sizes of the checkpoint's sub-models.
        self.mm_projector = ProjectionModule()

    @property
    def device(self):
        """Device of ``text_model``; used as the target device for all inputs."""
        return self.text_model.device

    def tokenizer_image_token(
        self, prompt, tokenizer, image_token_index=-200, return_tensors=None
    ):
        """Tokenize ``prompt`` and insert ``image_token_index`` at each ``"<image>"`` marker.

        Each text chunk between markers is tokenized separately. If the first
        chunk starts with ``tokenizer.bos_token_id``, that BOS token is kept
        once, and the first token of every later chunk is dropped. This
        presumably removes the BOS the tokenizer prepended to each chunk.

        Args:
            prompt: text containing zero or more ``"<image>"`` markers.
            tokenizer: callable whose result has an ``input_ids`` list and
                which exposes ``bos_token_id``.
            image_token_index: placeholder id emitted for each marker.
            return_tensors: accepted for API compatibility but ignored; a
                tensor is always returned.

        Returns:
            1-D ``torch.long`` tensor of token ids.
        """
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
        ]

        def insert_separator(X, sep):
            # Interleave sep between consecutive elements: [a, sep, b, sep, c].
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        if (
            len(prompt_chunks) > 0
            and len(prompt_chunks[0]) > 0
            and prompt_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        # The separator has offset + 1 entries. After `offset` leading tokens
        # are sliced off below, exactly one placeholder remains per marker.
        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        return torch.tensor(input_ids, dtype=torch.long)

    def process_tensors(
        self, input_ids, image_features, embedding_layer, image_token_index=-200
    ):
        """Splice ``image_features`` into the embeddings of ``input_ids``.

        The placeholder position is the column of the first
        ``image_token_index`` found (row-major). That column is used to split
        every row, so this assumes batch size 1 and a single placeholder. The
        placeholder itself is not embedded.

        Args:
            input_ids: 2-D tensor of token ids containing ``image_token_index``.
            image_features: features concatenated between the two text parts
                along dim 1.
            embedding_layer: module mapping token ids to embeddings.
            image_token_index: placeholder id to locate. The default matches
                ``tokenizer_image_token``.

        Returns:
            ``(embeddings, attention_mask)``, both on ``image_features.device``.
            The mask is all ones and ``torch.long``.

        Raises:
            IndexError: if ``input_ids`` contains no ``image_token_index``.
        """
        split_index = (input_ids == image_token_index).nonzero(as_tuple=True)[1][0]

        # Text before and after the placeholder; the placeholder id itself is
        # excluded because it is not a valid embedding index.
        input_ids_1 = input_ids[:, :split_index]
        input_ids_2 = input_ids[:, split_index + 1 :]

        embeddings_1 = embedding_layer(input_ids_1)
        embeddings_2 = embedding_layer(input_ids_2)

        device = image_features.device
        token_embeddings_part1 = embeddings_1.to(device)
        token_embeddings_part2 = embeddings_2.to(device)

        concatenated_embeddings = torch.cat(
            [token_embeddings_part1, image_features, token_embeddings_part2], dim=1
        )

        attention_mask = torch.ones(
            concatenated_embeddings.shape[:2], dtype=torch.long, device=device
        )
        return concatenated_embeddings, attention_mask

    def answer_question(self, image, question, tokenizer, **kwargs):
        """Generate a reply to ``question`` about ``image``.

        Args:
            image: a single image passed to ``self.processor``.
            question: user text. An ``"<image>"`` marker is prepended to it.
            tokenizer: tokenizer whose vocabulary contains ``<|eot_id|>``. The
                prompt uses Llama-3-style header tokens.
            **kwargs: forwarded to ``text_model.generate``. Values given here
                override the defaults ``eos_token_id`` (EOS and
                ``<|eot_id|>``), ``temperature=0.2`` and ``do_sample=True``.

        Returns:
            The first row of the tensor returned by ``text_model.generate``
            (token ids, not decoded text).
        """
        image_token_index = -200
        question = "<image>" + question
        prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

        input_ids = (
            self.tokenizer_image_token(
                prompt, tokenizer, image_token_index, return_tensors="pt"
            )
            .unsqueeze(0)
            .to(self.device)
        )

        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        # Merge the defaults with the caller's kwargs so that caller values
        # win. Passing both as separate keywords raises a duplicate-keyword
        # TypeError.
        generation_kwargs = {
            "eos_token_id": terminators,
            "temperature": 0.2,
            "do_sample": True,
            **kwargs,
        }

        with torch.inference_mode():
            image_inputs = self.processor(
                images=[image],
                return_tensors="pt",
                do_resize=True,
                size={"height": 384, "width": 384},
            )
            image_inputs = image_inputs["pixel_values"].to(
                device=self.device, dtype=self.dtype
            )

            image_forward_outs = self.vision_model(
                image_inputs,
                output_hidden_states=True,
            )
            # Use the penultimate layer's hidden state as the image features.
            image_features = image_forward_outs.hidden_states[-2]
            projected_embeddings = self.mm_projector(image_features).to(self.device)

            embedding_layer = self.text_model.get_input_embeddings()
            new_embeds, attn_mask = self.process_tensors(
                input_ids,
                projected_embeddings,
                embedding_layer,
                image_token_index=image_token_index,
            )
            attn_mask = attn_mask.to(self.device)
            new_embeds = new_embeds.to(self.device)

            answer = self.text_model.generate(
                inputs_embeds=new_embeds,
                attention_mask=attn_mask,
                **generation_kwargs,
            )[0]

        return answer
|