THUDM-Space committed on
Commit
595da78
·
verified ·
1 Parent(s): 62f99e7

Update modeling_glm.py

Browse files
Files changed (1) hide show
  1. modeling_glm.py +27 -12
modeling_glm.py CHANGED
@@ -417,7 +417,7 @@ class GlmSdpaAttention(GlmAttention):
417
  )
418
 
419
  bsz, q_len, _ = hidden_states.size()
420
-
421
  query_states = self.q_proj(hidden_states)
422
  key_states = self.k_proj(hidden_states)
423
  value_states = self.v_proj(hidden_states)
@@ -425,7 +425,7 @@ class GlmSdpaAttention(GlmAttention):
425
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
426
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
427
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
428
-
429
  cos, sin = position_embeddings
430
  query_states, key_states = apply_rotary_pos_emb(
431
  query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
@@ -763,21 +763,36 @@ class GlmModel(GlmPreTrainedModel):
763
  assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
764
  inputs_embeds = self.embed_tokens(input_ids)
765
  new_input_embeds = []
766
- boi_token_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids]
767
- if is_empty(images):
768
- images = torch.zeros([1, 3, 672, 672]).to(input_ids.device)
769
- images_features = self.vision(images).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
 
 
 
 
 
 
 
770
  image_count = 0
771
  for i in range(len(input_ids)):
772
  input_id = input_ids[i].tolist()
773
- if boi_token_flags[i]:
774
  boi_token_pos = input_id.index(self.config.boi_token_id)
775
  assert boi_token_pos >= 0, "begin_of_image not found!"
776
  num_image_padding_tokens = input_id.count(self.config.boi_token_id)
777
- assert num_image_padding_tokens == images_features[image_count].shape[0], f"Wrong image padding token number: {num_image_padding_tokens}"
778
- new_input_embeds.append(torch.cat(
779
- (inputs_embeds[i, :boi_token_pos], images_features[image_count],
780
- inputs_embeds[i, boi_token_pos + num_image_padding_tokens:])))
 
 
 
 
 
 
 
 
781
  image_count += 1
782
  else:
783
  new_input_embeds.append(inputs_embeds[i] + (0 * images_features[0].sum()))
@@ -1316,4 +1331,4 @@ __all__ = [
1316
  "GlmModel",
1317
  "GlmForCausalLM",
1318
  "GlmForSequenceClassification",
1319
- ]
 
417
  )
418
 
419
  bsz, q_len, _ = hidden_states.size()
420
+
421
  query_states = self.q_proj(hidden_states)
422
  key_states = self.k_proj(hidden_states)
423
  value_states = self.v_proj(hidden_states)
 
425
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
426
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
427
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
428
+
429
  cos, sin = position_embeddings
430
  query_states, key_states = apply_rotary_pos_emb(
431
  query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
 
763
  assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
764
  inputs_embeds = self.embed_tokens(input_ids)
765
  new_input_embeds = []
766
+ multi_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids]
767
+ images_features = None
768
+ if not is_empty(images) and images.bool().any():
769
+ imgs = list()
770
+ for i in range(len(multi_flags)):
771
+ if multi_flags[i]:
772
+ imgs.append(images[i])
773
+ imgs = torch.stack(imgs, dim=0)
774
+ else:
775
+ imgs = torch.unsqueeze(images[0], 0)
776
+ images_features = self.vision(imgs).to(inputs_embeds.dtype)
777
  image_count = 0
778
  for i in range(len(input_ids)):
779
  input_id = input_ids[i].tolist()
780
+ if multi_flags[i]:
781
  boi_token_pos = input_id.index(self.config.boi_token_id)
782
  assert boi_token_pos >= 0, "begin_of_image not found!"
783
  num_image_padding_tokens = input_id.count(self.config.boi_token_id)
784
+ assert (
785
+ num_image_padding_tokens == images_features[image_count].shape[0]
786
+ ), f"Wrong image padding token number: {num_image_padding_tokens}"
787
+ new_input_embeds.append(
788
+ torch.cat(
789
+ (
790
+ inputs_embeds[i, :boi_token_pos],
791
+ images_features[image_count].to(inputs_embeds.device),
792
+ inputs_embeds[i, boi_token_pos + num_image_padding_tokens :],
793
+ )
794
+ )
795
+ )
796
  image_count += 1
797
  else:
798
  new_input_embeds.append(inputs_embeds[i] + (0 * images_features[0].sum()))
 
1331
  "GlmModel",
1332
  "GlmForCausalLM",
1333
  "GlmForSequenceClassification",
1334
+ ]