AlexHung29629 commited on
Commit
36c6eac
·
verified ·
1 Parent(s): 8cf6588

Update modeling_llama3.py

Browse files
Files changed (1) hide show
  1. modeling_llama3.py +2 -2
modeling_llama3.py CHANGED
@@ -331,12 +331,12 @@ class Llama3TextModel(Llama3PreTrainedModel):
331
  return causal_mask
332
 
333
  class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
334
- config_class = Llama3ConfigConfig
335
  base_model_prefix = "model"
336
  _tied_weights_keys = ["lm_head.weight"]
337
 
338
  def __init__(self, config):
339
- super().__init__(config.get_text_config())
340
  self.text_config = config.get_text_config()
341
  self.vocab_size = self.text_config.vocab_size
342
  self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)
 
331
  return causal_mask
332
 
333
  class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
334
+ config_class = Llama3Config
335
  base_model_prefix = "model"
336
  _tied_weights_keys = ["lm_head.weight"]
337
 
338
  def __init__(self, config):
339
+ super().__init__(config)
340
  self.text_config = config.get_text_config()
341
  self.vocab_size = self.text_config.vocab_size
342
  self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)