AlexHung29629 commited on
Commit
14befaf
·
verified ·
1 Parent(s): cddd5fd

Update modeling_llama3.py

Browse files
Files changed (1) hide show
  1. modeling_llama3.py +1 -11
modeling_llama3.py CHANGED
@@ -548,20 +548,10 @@ class Llama3ForConditionalGeneration(Llama3PreTrainedModel, GenerationMixin):
548
  if cross_attention_mask is not None and cache_position is not None:
549
  cross_attention_mask = cross_attention_mask[:, :, cache_position]
550
  full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
551
-
552
- if audio_features is not None:
553
- if input_ids is None:
554
- raise ValueError("You must provide `input_ids` if you pass `audio_features`.")
555
-
556
- inputs_embeds = self.audio_model(
557
- audio_features=audio_features,
558
- input_ids=input_ids,
559
- return_dict=False,
560
- )
561
- input_ids = None
562
 
563
  outputs = self.language_model(
564
  input_ids=input_ids,
 
565
  attention_mask=attention_mask,
566
  position_ids=position_ids,
567
  cross_attention_states=cross_attention_states,
 
548
  if cross_attention_mask is not None and cache_position is not None:
549
  cross_attention_mask = cross_attention_mask[:, :, cache_position]
550
  full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
 
 
 
 
 
 
 
 
 
 
 
551
 
552
  outputs = self.language_model(
553
  input_ids=input_ids,
554
+ audio_features=audio_features,
555
  attention_mask=attention_mask,
556
  position_ids=position_ids,
557
  cross_attention_states=cross_attention_states,