Update modeling_glm.py
Browse files- modeling_glm.py +27 -12
modeling_glm.py
CHANGED
@@ -417,7 +417,7 @@ class GlmSdpaAttention(GlmAttention):
|
|
417 |
)
|
418 |
|
419 |
bsz, q_len, _ = hidden_states.size()
|
420 |
-
|
421 |
query_states = self.q_proj(hidden_states)
|
422 |
key_states = self.k_proj(hidden_states)
|
423 |
value_states = self.v_proj(hidden_states)
|
@@ -425,7 +425,7 @@ class GlmSdpaAttention(GlmAttention):
|
|
425 |
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
426 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
427 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
428 |
-
|
429 |
cos, sin = position_embeddings
|
430 |
query_states, key_states = apply_rotary_pos_emb(
|
431 |
query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
|
@@ -763,21 +763,36 @@ class GlmModel(GlmPreTrainedModel):
|
|
763 |
assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
|
764 |
inputs_embeds = self.embed_tokens(input_ids)
|
765 |
new_input_embeds = []
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
770 |
image_count = 0
|
771 |
for i in range(len(input_ids)):
|
772 |
input_id = input_ids[i].tolist()
|
773 |
-
if
|
774 |
boi_token_pos = input_id.index(self.config.boi_token_id)
|
775 |
assert boi_token_pos >= 0, "begin_of_image not found!"
|
776 |
num_image_padding_tokens = input_id.count(self.config.boi_token_id)
|
777 |
-
assert
|
778 |
-
|
779 |
-
|
780 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
781 |
image_count += 1
|
782 |
else:
|
783 |
new_input_embeds.append(inputs_embeds[i] + (0 * images_features[0].sum()))
|
@@ -1316,4 +1331,4 @@ __all__ = [
|
|
1316 |
"GlmModel",
|
1317 |
"GlmForCausalLM",
|
1318 |
"GlmForSequenceClassification",
|
1319 |
-
]
|
|
|
417 |
)
|
418 |
|
419 |
bsz, q_len, _ = hidden_states.size()
|
420 |
+
|
421 |
query_states = self.q_proj(hidden_states)
|
422 |
key_states = self.k_proj(hidden_states)
|
423 |
value_states = self.v_proj(hidden_states)
|
|
|
425 |
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
426 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
427 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
428 |
+
|
429 |
cos, sin = position_embeddings
|
430 |
query_states, key_states = apply_rotary_pos_emb(
|
431 |
query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
|
|
|
763 |
assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
|
764 |
inputs_embeds = self.embed_tokens(input_ids)
|
765 |
new_input_embeds = []
|
766 |
+
multi_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids]
|
767 |
+
images_features = None
|
768 |
+
if not is_empty(images) and images.bool().any():
|
769 |
+
imgs = list()
|
770 |
+
for i in range(len(multi_flags)):
|
771 |
+
if multi_flags[i]:
|
772 |
+
imgs.append(images[i])
|
773 |
+
imgs = torch.stack(imgs, dim=0)
|
774 |
+
else:
|
775 |
+
imgs = torch.unsqueeze(images[0], 0)
|
776 |
+
images_features = self.vision(imgs).to(inputs_embeds.dtype)
|
777 |
image_count = 0
|
778 |
for i in range(len(input_ids)):
|
779 |
input_id = input_ids[i].tolist()
|
780 |
+
if multi_flags[i]:
|
781 |
boi_token_pos = input_id.index(self.config.boi_token_id)
|
782 |
assert boi_token_pos >= 0, "begin_of_image not found!"
|
783 |
num_image_padding_tokens = input_id.count(self.config.boi_token_id)
|
784 |
+
assert (
|
785 |
+
num_image_padding_tokens == images_features[image_count].shape[0]
|
786 |
+
), f"Wrong image padding token number: {num_image_padding_tokens}"
|
787 |
+
new_input_embeds.append(
|
788 |
+
torch.cat(
|
789 |
+
(
|
790 |
+
inputs_embeds[i, :boi_token_pos],
|
791 |
+
images_features[image_count].to(inputs_embeds.device),
|
792 |
+
inputs_embeds[i, boi_token_pos + num_image_padding_tokens :],
|
793 |
+
)
|
794 |
+
)
|
795 |
+
)
|
796 |
image_count += 1
|
797 |
else:
|
798 |
new_input_embeds.append(inputs_embeds[i] + (0 * images_features[0].sum()))
|
|
|
1331 |
"GlmModel",
|
1332 |
"GlmForCausalLM",
|
1333 |
"GlmForSequenceClassification",
|
1334 |
+
]
|