zwt123home123 committed
Commit aa5e2c3 · verified · 1 Parent(s): 1532d66

Update modeling_internvl_chat.py

Files changed (1)
  1. modeling_internvl_chat.py +35 -6
modeling_internvl_chat.py CHANGED
@@ -24,6 +24,8 @@ from .modeling_internlm2 import InternLM2ForCausalLM
 
 logger = logging.get_logger(__name__)
 
+import os
+image_token_num = 0
 
 def version_cmp(v1, v2, op='eq'):
     import operator
 
@@ -53,6 +55,8 @@ class InternVLChatModel(PreTrainedModel):
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
         use_flash_attn = use_flash_attn if has_flash_attn else False
+        # use_flash_attn = True
+        # use_flash_attn = False
         config.vision_config.use_flash_attn = True if use_flash_attn else False
         config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
 
 
@@ -182,6 +186,7 @@ class InternVLChatModel(PreTrainedModel):
         return x
 
     def extract_feature(self, pixel_values):
+
         if self.select_layer == -1:
             vit_embeds = self.vision_model(
                 pixel_values=pixel_values,
 
@@ -193,9 +198,11 @@ class InternVLChatModel(PreTrainedModel):
                 output_hidden_states=True,
                 return_dict=True).hidden_states[self.select_layer]
         vit_embeds = vit_embeds[:, 1:, :]
-
+
         h = w = int(vit_embeds.shape[1] ** 0.5)
+        os.environ['IMAGE_H'] = str(h)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+        # import pdb; pdb.set_trace()
         vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
         vit_embeds = self.mlp1(vit_embeds)
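Exporting the grid size through os.environ is a process-local side channel: anything running in the same process after extract_feature() can read it back. A minimal sketch of a reader (the consumer code below is hypothetical and not part of this commit):

import os

# Hypothetical downstream reader: recover the ViT patch-grid side length
# that extract_feature() exported for the most recent forward pass.
image_h = int(os.environ.get('IMAGE_H', '0'))
if image_h:
    tokens_per_tile_pre_shuffle = image_h * image_h  # h * w patch tokens per tile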
 
@@ -233,13 +240,14 @@ class InternVLChatModel(PreTrainedModel):
             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
             query = query.replace('<image>', image_tokens, 1)
             queries.append(query)
-
+
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
         input_ids = model_inputs['input_ids'].to(self.device)
         attention_mask = model_inputs['attention_mask'].to(self.device)
         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
         generation_config['eos_token_id'] = eos_token_id
+
         generation_output = self.generate(
             pixel_values=pixel_values,
             input_ids=input_ids,
 
@@ -317,14 +325,32 @@ class InternVLChatModel(PreTrainedModel):
             output_hidden_states: Optional[bool] = None,
             **generate_kwargs,
     ) -> torch.LongTensor:
-
+
         assert self.img_context_token_id is not None
         if pixel_values is not None:
+
             if visual_features is not None:
                 vit_embeds = visual_features
             else:
-                vit_embeds = self.extract_feature(pixel_values)
+                # vit_embeds = self.extract_feature(pixel_values)
+                # Assuming pixel_values is already defined
+                batch_size = 10
+                num_samples = pixel_values.size(0)  # Total number of samples
+                vit_embeds_list = []
+
+                # Loop through the batches
+                for start_idx in range(0, num_samples, batch_size):
+                    end_idx = min(start_idx + batch_size, num_samples)  # Ensure the end index doesn't exceed the size
+                    batch = pixel_values[start_idx:end_idx]  # Slice the batch
+                    vit_embeds_batch = self.extract_feature(batch)  # Process the batch
+                    vit_embeds_list.append(vit_embeds_batch)  # Collect the results
+
+                # Concatenate the embeddings if required
+                vit_embeds = torch.cat(vit_embeds_list, dim=0)
+
+
             input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
             B, N, C = input_embeds.shape
             input_embeds = input_embeds.reshape(B * N, C)
 
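The loop above replaces the single extract_feature() call so that at most 10 tiles hit the vision encoder at once, bounding peak activation memory; the per-chunk embeddings are then concatenated back along the tile dimension. A compact equivalent using torch.split (a sketch, assuming extract_feature() accepts any leading batch size; the helper name is ours, not from the commit):

import torch

def extract_feature_chunked(model, pixel_values, batch_size=10):
    # Encode at most `batch_size` tiles per vision-encoder pass,
    # then reassemble the chunks along the tile dimension (dim 0).
    chunks = [model.extract_feature(part)
              for part in torch.split(pixel_values, batch_size, dim=0)]
    return torch.cat(chunks, dim=0)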
 
 
@@ -332,11 +358,14 @@ class InternVLChatModel(PreTrainedModel):
             selected = (input_ids == self.img_context_token_id)
             assert selected.sum() != 0
             input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
-
+
+            image_token_num = int(vit_embeds.shape[0] * vit_embeds.shape[1] / B)
+            os.environ['IMAGE_TOKEN_NUM'] = str(image_token_num)
+
             input_embeds = input_embeds.reshape(B, N, C)
         else:
             input_embeds = self.language_model.get_input_embeddings()(input_ids)
-
+
         outputs = self.language_model.generate(
             inputs_embeds=input_embeds,
             attention_mask=attention_mask,