Update modeling_llama3.py
Browse files- modeling_llama3.py +9 -6
modeling_llama3.py
CHANGED
@@ -49,10 +49,12 @@ class Llama3TextModel(MllamaPreTrainedModel):
|
|
49 |
self.post_init()
|
50 |
|
51 |
def get_input_embeddings(self):
|
52 |
-
return self.embed_tokens.text_embeddings
|
|
|
53 |
|
54 |
def set_input_embeddings(self, value):
|
55 |
-
self.embed_tokens.text_embeddings = value
|
|
|
56 |
|
57 |
def forward(
|
58 |
self,
|
@@ -312,10 +314,11 @@ class Llama3ForCausalLM(MllamaPreTrainedModel, GenerationMixin):
|
|
312 |
#_tied_weights_keys = ["lm_head.weight"]
|
313 |
|
314 |
def __init__(self, config: MllamaTextConfig):
|
315 |
-
super().__init__(config)
|
316 |
-
self.
|
317 |
-
self.
|
318 |
-
self.
|
|
|
319 |
|
320 |
self.post_init()
|
321 |
|
|
|
49 |
self.post_init()
|
50 |
|
51 |
def get_input_embeddings(self):
    """Return the model's input-embedding module.

    NOTE(review): this commit stubbed out the original accessor
    (``return self.embed_tokens.text_embeddings``) to return ``None``.
    Returning ``None`` here breaks ``PreTrainedModel`` helpers that rely on
    this hook (``resize_token_embeddings``, ``tie_weights``, generation input
    preparation) — confirm no caller still depends on the real module.
    """
    return None
|
54 |
|
55 |
def set_input_embeddings(self, value):
    """Set the model's input-embedding module (currently a deliberate no-op).

    NOTE(review): this commit stubbed out the original setter
    (``self.embed_tokens.text_embeddings = value``). The silent no-op means
    ``resize_token_embeddings`` and weight tying will appear to succeed while
    doing nothing — confirm this is intentional; consider raising
    ``NotImplementedError`` instead so misuse fails loudly.
    """
    pass
|
58 |
|
59 |
def forward(
|
60 |
self,
|
|
|
314 |
#_tied_weights_keys = ["lm_head.weight"]
|
315 |
|
316 |
def __init__(self, config: MllamaTextConfig):
    """Build the causal-LM wrapper: text backbone plus untied LM head.

    Args:
        config: composite model config; only its text sub-config is used
            (extracted via ``config.get_text_config()``).
    """
    # Hoisted: the original called config.get_text_config() twice.
    text_config = config.get_text_config()
    super().__init__(text_config)
    self.text_config = text_config
    self.vocab_size = text_config.vocab_size
    # Backbone is built from the text sub-config but inherits the parent
    # config's attention implementation (sdpa / flash / eager).
    self.model = Llama3TextModel._from_config(
        text_config, attn_implementation=config._attn_implementation
    )
    # bias=False and _tied_weights_keys commented out at class level:
    # the head is a separate, untied projection.
    self.lm_head = nn.Linear(text_config.hidden_size, self.vocab_size, bias=False)

    self.post_init()
|
324 |
|