AlexHung29629 commited on
Commit
1cccb9f
·
verified ·
1 Parent(s): ba17429

Update modeling_llama3.py

Browse files
Files changed (1) hide show
  1. modeling_llama3.py +9 -6
modeling_llama3.py CHANGED
@@ -49,10 +49,12 @@ class Llama3TextModel(MllamaPreTrainedModel):
49
  self.post_init()
50
 
51
  def get_input_embeddings(self):
52
- return self.embed_tokens.text_embeddings
 
53
 
54
  def set_input_embeddings(self, value):
55
- self.embed_tokens.text_embeddings = value
 
56
 
57
  def forward(
58
  self,
@@ -312,10 +314,11 @@ class Llama3ForCausalLM(MllamaPreTrainedModel, GenerationMixin):
312
  #_tied_weights_keys = ["lm_head.weight"]
313
 
314
  def __init__(self, config: MllamaTextConfig):
315
- super().__init__(config)
316
- self.vocab_size = config.vocab_size
317
- self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)
318
- self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
 
319
 
320
  self.post_init()
321
 
 
49
  self.post_init()
50
 
51
  def get_input_embeddings(self):
52
+ #return self.embed_tokens.text_embeddings
53
+ return None
54
 
55
  def set_input_embeddings(self, value):
56
+ #self.embed_tokens.text_embeddings = value
57
+ pass
58
 
59
  def forward(
60
  self,
 
314
  #_tied_weights_keys = ["lm_head.weight"]
315
 
316
  def __init__(self, config: MllamaTextConfig):
317
+ super().__init__(config.get_text_config())
318
+ self.text_config = config.get_text_config()
319
+ self.vocab_size = self.text_config.vocab_size
320
+ self.model = Llama3TextModel._from_config(self.text_config, attn_implementation=config._attn_implementation)
321
+ self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
322
 
323
  self.post_init()
324