Update modeling_llama3.py
Browse files- modeling_llama3.py +2 -2
modeling_llama3.py
CHANGED
@@ -331,12 +331,12 @@ class Llama3TextModel(Llama3PreTrainedModel):
|
|
331 |
return causal_mask
|
332 |
|
333 |
class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
|
334 |
-
config_class =
|
335 |
base_model_prefix = "model"
|
336 |
_tied_weights_keys = ["lm_head.weight"]
|
337 |
|
338 |
def __init__(self, config):
|
339 |
-
super().__init__(config
|
340 |
self.text_config = config.get_text_config()
|
341 |
self.vocab_size = self.text_config.vocab_size
|
342 |
self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)
|
|
|
331 |
return causal_mask
|
332 |
|
333 |
class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
|
334 |
+
config_class = Llama3Config
|
335 |
base_model_prefix = "model"
|
336 |
_tied_weights_keys = ["lm_head.weight"]
|
337 |
|
338 |
def __init__(self, config):
|
339 |
+
super().__init__(config)
|
340 |
self.text_config = config.get_text_config()
|
341 |
self.vocab_size = self.text_config.vocab_size
|
342 |
self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)
|