AlexHung29629 committed (verified)
Commit 19d6896 · 1 Parent(s): 36c6eac

Update modeling_llama3.py

Files changed (1): modeling_llama3.py (+2 -2)
modeling_llama3.py CHANGED

@@ -341,7 +341,7 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
         self.vocab_size = self.text_config.vocab_size
         self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)
         self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
-
+        self.loss_function = ForCausalLMLoss
         self.post_init()

     def get_input_embeddings(self):
@@ -409,7 +409,7 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = ForCausalLMLoss(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
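
What the change does: instead of hard-coding a call to ForCausalLMLoss inside forward, the constructor now binds it to self.loss_function, and forward calls that attribute. This makes the loss pluggable: a subclass or training script can swap in a different loss without overriding forward. Below is a minimal, self-contained sketch of the pattern; the names for_causal_lm_loss and TinyCausalLM are illustrative (not from this repo), and the loss body only approximates what transformers' ForCausalLMLoss computes (shifted next-token cross entropy).

import torch
import torch.nn.functional as F

def for_causal_lm_loss(logits, labels, vocab_size, ignore_index=-100, **kwargs):
    # Next-token prediction: position t is supervised by the label at t+1,
    # so drop the last logit and the first label before flattening.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)

class TinyCausalLM(torch.nn.Module):
    def __init__(self, hidden_size=16, vocab_size=32):
        super().__init__()
        self.vocab_size = vocab_size
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        # As in the diff: bind the loss as an instance attribute rather
        # than calling a module-level function directly in forward().
        self.loss_function = for_causal_lm_loss

    def forward(self, hidden_states, labels=None, **loss_kwargs):
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Mirrors the post-change call site in the diff above.
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
        return loss, logits

With this in place, setting model.loss_function = my_custom_loss (for example, a label-smoothed variant) changes training behavior with no edit to forward; that flexibility is what the two-line diff buys.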