AlexHung29629 committed
Commit b0aaa43 · verified · 1 Parent(s): 9192dd9

Update modeling_llama3.py

Files changed (1):
  1. modeling_llama3.py  +44 −40
modeling_llama3.py CHANGED
@@ -6,7 +6,7 @@ import torch.utils.checkpoint
 from torch import nn
 
 import transformers
-from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, AutoModel, AutoModelForCausalLM
+from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, Wav2Vec2BertConfig, AutoModel, AutoModelForCausalLM
 from transformers.cache_utils import Cache, StaticCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -14,7 +14,6 @@ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutpu
 from transformers.utils import logging
 from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask, MllamaCrossAttentionDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaTextRMSNorm, MllamaRotaryEmbedding
 from transformers.models.mllama.configuration_mllama import MllamaTextConfig
-
 from .configuration_llama3 import Llama3Config
 from .mllama_audio_model import Llama3Embedding
 
@@ -25,27 +24,27 @@ class Llama3PreTrainedModel(MllamaPreTrainedModel):
     config_class = Llama3Config
     base_model_prefix = "model"
 
-class Llama3TextModel(Llama3PreTrainedModel):
-    config_class = Llama3Config
+class Llama3TextModel(MllamaPreTrainedModel):
+    config_class = MllamaTextConfig
     base_model_prefix = "language_model.model"
 
-    def __init__(self, config: Llama3Config):
+    def __init__(self, config: MllamaTextConfig, audio_config: Wav2Vec2BertConfig):
         super().__init__(config)
-        self.padding_idx = config.text_config.pad_token_id
-        self.vocab_size = config.text_config.vocab_size
-        self.embed_tokens = Llama3Embedding(config.audio_config, config.text_config)
-        self.cross_attention_layers = config.text_config.cross_attention_layers
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        #self.embed_tokens = Llama3Embedding(audio_config, config)
+        self.cross_attention_layers = config.cross_attention_layers
 
         layers = []
-        for layer_idx in range(config.text_config.num_hidden_layers):
+        for layer_idx in range(config.num_hidden_layers):
             if layer_idx in self.cross_attention_layers:
-                layers.append(MllamaCrossAttentionDecoderLayer(config.text_config, layer_idx))
+                layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx))
             else:
-                layers.append(MllamaSelfAttentionDecoderLayer(config.text_config, layer_idx))
+                layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx))
 
         self.layers = nn.ModuleList(layers)
-        self.norm = MllamaTextRMSNorm(config.text_config.hidden_size, eps=config.text_config.rms_norm_eps)
-        self.rotary_emb = MllamaRotaryEmbedding(config=config.text_config)
+        self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = MllamaRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
         self.post_init()
 
@@ -57,8 +56,8 @@ class Llama3TextModel(Llama3PreTrainedModel):
 
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        audio_features: Optional[torch.Tensor] = None,
+        #input_ids: Optional[torch.LongTensor] = None,
+        #audio_features: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         cross_attention_states: Optional[torch.FloatTensor] = None,
@@ -94,15 +93,15 @@ class Llama3TextModel(Llama3PreTrainedModel):
         torch.Size([1, 13, 4096])
         ```
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.text_config.output_attentions
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.text_config.use_return_dict
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+        #if (input_ids is None) ^ (inputs_embeds is not None):
+        #    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
@@ -110,8 +109,9 @@ class Llama3TextModel(Llama3PreTrainedModel):
             )
             use_cache = False
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
+        #if inputs_embeds is None:
+        #    inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
+
 
         hidden_states = inputs_embeds
 
@@ -214,7 +214,7 @@ class Llama3TextModel(Llama3PreTrainedModel):
         past_key_values: Cache,
        output_attentions: bool,
     ):
-        if self.config.text_config._attn_implementation == "flash_attention_2":
+        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
@@ -258,7 +258,7 @@ class Llama3TextModel(Llama3PreTrainedModel):
        )
 
        if (
-            self.config.text_config._attn_implementation == "sdpa"
+            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
@@ -306,25 +306,26 @@ class Llama3TextModel(Llama3PreTrainedModel):
 
        return causal_mask
 
-class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
-    config_class = Llama3Config
+class Llama3ForCausalLM(MllamaPreTrainedModel, GenerationMixin):
+    config_class = MllamaTextConfig
     base_model_prefix = "model"
-    _tied_weights_keys = ["lm_head.weight"]
+    #_tied_weights_keys = ["lm_head.weight"]
 
-    def __init__(self, config):
+    def __init__(self, config: MllamaTextConfig):
         super().__init__(config)
-        self.text_config = config.get_text_config()
-        self.vocab_size = self.text_config.vocab_size
+        self.vocab_size = config.vocab_size
         self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)
-        self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
 
         self.post_init()
 
     def get_input_embeddings(self):
-        return self.model.embed_tokens.text_embeddings
+        #return self.model.embed_tokens.text_embeddings
+        return None
 
     def set_input_embeddings(self, value):
-        self.model.embed_tokens.text_embeddings = value
+        #self.model.embed_tokens.text_embeddings = value
+        pass
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -340,7 +341,7 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
 
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        #input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         cross_attention_states: Optional[torch.LongTensor] = None,
@@ -357,15 +358,15 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
         num_logits_to_keep: int = 0,
         **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.text_config.output_attentions
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.text_config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.text_config.use_return_dict
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
-            input_ids=input_ids,
+            #input_ids=input_ids,
             cross_attention_states=cross_attention_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -402,7 +403,9 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
 AutoModelForCausalLM.register(Llama3Config, Llama3ForCausalLM)
 transformers.Llama3ForCausalLM = Llama3ForCausalLM
 
-class Llama3ForConditionalGeneration(Llama3PreTrainedModel, GenerationMixin):
+class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+    config_class = Llama3Config
+    base_model_prefix = "model"
     _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
 
     def __init__(self, config: Llama3Config):
@@ -415,6 +418,7 @@ class Llama3ForConditionalGeneration(Llama3PreTrainedModel, GenerationMixin):
 
         self.vision_model = MllamaVisionModel._from_config(config.vision_config)
         self.language_model = Llama3ForCausalLM._from_config(config)
+        self.embeddings = Llama3Embedding(config.audio_config, config.text_config)
         self.multi_modal_projector = nn.Linear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
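
The net effect of this commit is that the fused text+audio embedding (Llama3Embedding) moves out of Llama3TextModel and up into Llama3ForConditionalGeneration, while Llama3TextModel and Llama3ForCausalLM are rebased onto MllamaTextConfig and now consume pre-computed inputs_embeds. Below is a minimal, self-contained sketch of that structure in plain PyTorch; the Toy* classes, dimensions, and the additive audio fusion are illustrative stand-ins, not this repository's actual classes or fusion logic.

import torch
from torch import nn


class ToyEmbedding(nn.Module):
    """Stand-in for Llama3Embedding: merges token embeddings with audio features."""

    def __init__(self, vocab_size: int, hidden_size: int, audio_dim: int):
        super().__init__()
        self.text_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.audio_proj = nn.Linear(audio_dim, hidden_size)

    def forward(self, input_ids, audio_features=None):
        embeds = self.text_embeddings(input_ids)
        if audio_features is not None:
            # Illustrative fusion only: add a pooled audio vector to every position.
            embeds = embeds + self.audio_proj(audio_features).mean(dim=1, keepdim=True)
        return embeds


class ToyTextModel(nn.Module):
    """Stand-in for Llama3TextModel after this commit: no embed_tokens, consumes inputs_embeds."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, batch_first=True)

    def forward(self, inputs_embeds):
        return self.layer(inputs_embeds)


class ToyConditionalGeneration(nn.Module):
    """Stand-in for Llama3ForConditionalGeneration: owns the fused embedding module."""

    def __init__(self, vocab_size=128, hidden_size=64, audio_dim=32):
        super().__init__()
        self.embeddings = ToyEmbedding(vocab_size, hidden_size, audio_dim)
        self.language_model = ToyTextModel(hidden_size)

    def forward(self, input_ids, audio_features=None):
        # Embedding happens in the wrapper; the decoder stack only sees inputs_embeds.
        inputs_embeds = self.embeddings(input_ids, audio_features)
        return self.language_model(inputs_embeds)


model = ToyConditionalGeneration()
hidden = model(torch.randint(0, 128, (1, 13)), torch.randn(1, 5, 32))
print(hidden.shape)  # torch.Size([1, 13, 64])

This mirrors the design choice visible in the diff: keeping multimodal embedding logic out of the decoder stack lets Llama3TextModel stay close to the upstream Mllama text model, at the cost of get_input_embeddings()/set_input_embeddings() on Llama3ForCausalLM becoming no-ops.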