Update modeling_llama3.py
modeling_llama3.py  +44 -40  CHANGED
@@ -6,7 +6,7 @@ import torch.utils.checkpoint
 from torch import nn
 
 import transformers
-from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, AutoModel, AutoModelForCausalLM
+from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, Wav2Vec2BertConfig, AutoModel, AutoModelForCausalLM
 from transformers.cache_utils import Cache, StaticCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -14,7 +14,6 @@ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutpu
 from transformers.utils import logging
 from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask, MllamaCrossAttentionDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaTextRMSNorm, MllamaRotaryEmbedding
 from transformers.models.mllama.configuration_mllama import MllamaTextConfig
-
 from .configuration_llama3 import Llama3Config
 from .mllama_audio_model import Llama3Embedding
 
@@ -25,27 +24,27 @@ class Llama3PreTrainedModel(MllamaPreTrainedModel):
     config_class = Llama3Config
     base_model_prefix = "model"
 
-class Llama3TextModel(Llama3PreTrainedModel):
-    config_class =
+class Llama3TextModel(MllamaPreTrainedModel):
+    config_class = MllamaTextConfig
     base_model_prefix = "language_model.model"
 
-    def __init__(self, config:
+    def __init__(self, config: MllamaTextConfig, audio_config: Wav2Vec2BertConfig):
         super().__init__(config)
-        self.padding_idx = config.
-        self.vocab_size = config.
-        self.embed_tokens = Llama3Embedding(
-        self.cross_attention_layers = config.
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        #self.embed_tokens = Llama3Embedding(audio_config, config)
+        self.cross_attention_layers = config.cross_attention_layers
 
         layers = []
-        for layer_idx in range(config.
+        for layer_idx in range(config.num_hidden_layers):
             if layer_idx in self.cross_attention_layers:
-                layers.append(MllamaCrossAttentionDecoderLayer(config
+                layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx))
             else:
-                layers.append(MllamaSelfAttentionDecoderLayer(config
+                layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx))
 
         self.layers = nn.ModuleList(layers)
-        self.norm = MllamaTextRMSNorm(config.
-        self.rotary_emb = MllamaRotaryEmbedding(config=config
+        self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = MllamaRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
         self.post_init()
 
@@ -57,8 +56,8 @@ class Llama3TextModel(Llama3PreTrainedModel):
 
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        audio_features: Optional[torch.Tensor] = None,
+        #input_ids: Optional[torch.LongTensor] = None,
+        #audio_features: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         cross_attention_states: Optional[torch.FloatTensor] = None,
@@ -94,15 +93,15 @@ class Llama3TextModel(Llama3PreTrainedModel):
         torch.Size([1, 13, 4096])
         ```
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        use_cache = use_cache if use_cache is not None else self.config.
-        return_dict = return_dict if return_dict is not None else self.config.
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+        #if (input_ids is None) ^ (inputs_embeds is not None):
+        #    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
@@ -110,8 +109,9 @@ class Llama3TextModel(Llama3PreTrainedModel):
             )
             use_cache = False
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
+        #if inputs_embeds is None:
+        #    inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
+
 
         hidden_states = inputs_embeds
 
@@ -214,7 +214,7 @@ class Llama3TextModel(Llama3PreTrainedModel):
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        if self.config.
+        if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
@@ -258,7 +258,7 @@ class Llama3TextModel(Llama3PreTrainedModel):
         )
 
         if (
-            self.config.
+            self.config._attn_implementation == "sdpa"
             and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
@@ -306,25 +306,26 @@ class Llama3TextModel(Llama3PreTrainedModel):
 
         return causal_mask
 
-class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
-    config_class =
+class Llama3ForCausalLM(MllamaPreTrainedModel, GenerationMixin):
+    config_class = MllamaTextConfig
     base_model_prefix = "model"
-    _tied_weights_keys = ["lm_head.weight"]
+    #_tied_weights_keys = ["lm_head.weight"]
 
-    def __init__(self, config):
+    def __init__(self, config: MllamaTextConfig):
         super().__init__(config)
-        self.
-        self.vocab_size = self.text_config.vocab_size
+        self.vocab_size = config.vocab_size
         self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)
-        self.lm_head = nn.Linear(
+        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
 
         self.post_init()
 
     def get_input_embeddings(self):
-        return self.model.embed_tokens.text_embeddings
+        #return self.model.embed_tokens.text_embeddings
+        return None
 
     def set_input_embeddings(self, value):
-        self.model.embed_tokens.text_embeddings = value
+        #self.model.embed_tokens.text_embeddings = value
+        pass
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -340,7 +341,7 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
 
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        #input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         cross_attention_states: Optional[torch.LongTensor] = None,
@@ -357,15 +358,15 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
         num_logits_to_keep: int = 0,
         **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
-            input_ids=input_ids,
+            #input_ids=input_ids,
             cross_attention_states=cross_attention_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -402,7 +403,9 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
 AutoModelForCausalLM.register(Llama3Config, Llama3ForCausalLM)
 transformers.Llama3ForCausalLM = Llama3ForCausalLM
 
-class Llama3ForConditionalGeneration(Llama3PreTrainedModel, GenerationMixin):
+class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+    config_class = Llama3Config
+    base_model_prefix = "model"
     _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
 
     def __init__(self, config: Llama3Config):
@@ -415,6 +418,7 @@ class Llama3ForConditionalGeneration(Llama3PreTrainedModel, GenerationMixin):
 
         self.vision_model = MllamaVisionModel._from_config(config.vision_config)
         self.language_model = Llama3ForCausalLM._from_config(config)
+        self.embeddings = Llama3Embedding(config.audio_config, config.text_config)
         self.multi_modal_projector = nn.Linear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
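For orientation, a minimal usage sketch of how the reworked classes fit together after this commit. It is not part of the change itself: the checkpoint path is a placeholder, the flat imports assume the repo's configuration_llama3.py and modeling_llama3.py are importable on the path, and it assumes the checkpoint's Llama3Config bundles the vision_config, text_config, and audio_config sub-configs referenced in the diff above.

# Minimal sketch, assumptions as stated above; "path/to/llama3-audio-checkpoint" is a placeholder.
from configuration_llama3 import Llama3Config
from modeling_llama3 import Llama3ForConditionalGeneration

config = Llama3Config.from_pretrained("path/to/llama3-audio-checkpoint")
model = Llama3ForConditionalGeneration(config)

# After this change the audio/text embedding (Llama3Embedding) lives on the
# composite model as model.embeddings rather than inside Llama3TextModel,
# and the text stack (model.language_model / its inner Llama3TextModel) is
# driven with precomputed inputs_embeds instead of input_ids.
print(type(model.vision_model).__name__)    # MllamaVisionModel
print(type(model.language_model).__name__)  # Llama3ForCausalLM
print(type(model.embeddings).__name__)      # Llama3Embedding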