Update modeling_internvl_chat.py
modeling_internvl_chat.py  CHANGED  (+35 -6)
@@ -24,6 +24,8 @@ from .modeling_internlm2 import InternLM2ForCausalLM
 
 logger = logging.get_logger(__name__)
 
+import os
+image_token_num = 0
 
 def version_cmp(v1, v2, op='eq'):
     import operator
@@ -53,6 +55,8 @@ class InternVLChatModel(PreTrainedModel):
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
         use_flash_attn = use_flash_attn if has_flash_attn else False
+        #use_flash_attn = True
+        #use_flash_attn = False
         config.vision_config.use_flash_attn = True if use_flash_attn else False
         config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
 
@@ -182,6 +186,7 @@ class InternVLChatModel(PreTrainedModel):
         return x
 
     def extract_feature(self, pixel_values):
+
         if self.select_layer == -1:
             vit_embeds = self.vision_model(
                 pixel_values=pixel_values,
@@ -193,9 +198,11 @@ class InternVLChatModel(PreTrainedModel):
                 output_hidden_states=True,
                 return_dict=True).hidden_states[self.select_layer]
         vit_embeds = vit_embeds[:, 1:, :]
-
+
         h = w = int(vit_embeds.shape[1] ** 0.5)
+        os.environ['IMAGE_H'] = str(h)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+        # import pdb; pdb.set_trace()
         vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
         vit_embeds = self.mlp1(vit_embeds)
@@ -233,13 +240,14 @@ class InternVLChatModel(PreTrainedModel):
             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
             query = query.replace('<image>', image_tokens, 1)
             queries.append(query)
-
+
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
         input_ids = model_inputs['input_ids'].to(self.device)
         attention_mask = model_inputs['attention_mask'].to(self.device)
         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
         generation_config['eos_token_id'] = eos_token_id
+
         generation_output = self.generate(
             pixel_values=pixel_values,
             input_ids=input_ids,
@@ -317,14 +325,32 @@ class InternVLChatModel(PreTrainedModel):
             output_hidden_states: Optional[bool] = None,
             **generate_kwargs,
     ) -> torch.LongTensor:
-
+
         assert self.img_context_token_id is not None
         if pixel_values is not None:
+
             if visual_features is not None:
                 vit_embeds = visual_features
             else:
-                vit_embeds = self.extract_feature(pixel_values)
+                #vit_embeds = self.extract_feature(pixel_values)
+                # Assuming pixel_values is already defined
+                batch_size = 10
+                num_samples = pixel_values.size(0)  # Total number of samples
+                vit_embeds_list = []
+
+                # Loop through the batches
+                for start_idx in range(0, num_samples, batch_size):
+                    end_idx = min(start_idx + batch_size, num_samples)  # Ensure the end index doesn't exceed the size
+                    batch = pixel_values[start_idx:end_idx]  # Slice the batch
+                    vit_embeds_batch = self.extract_feature(batch)  # Process the batch
+                    vit_embeds_list.append(vit_embeds_batch)  # Collect the results
+
+                # Concatenate the embeddings if required
+                vit_embeds = torch.cat(vit_embeds_list, dim=0)
+
+
             input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
             B, N, C = input_embeds.shape
             input_embeds = input_embeds.reshape(B * N, C)
 
@@ -332,11 +358,14 @@ class InternVLChatModel(PreTrainedModel):
             selected = (input_ids == self.img_context_token_id)
             assert selected.sum() != 0
             input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
-
+
+            image_token_num = int(vit_embeds.shape[0] * vit_embeds.shape[1] / B)
+            os.environ['IMAGE_TOKEN_NUM'] = str(image_token_num)
+
             input_embeds = input_embeds.reshape(B, N, C)
         else:
             input_embeds = self.language_model.get_input_embeddings()(input_ids)
-
+
         outputs = self.language_model.generate(
             inputs_embeds=input_embeds,
             attention_mask=attention_mask,