Update modeling_llama3.py
modeling_llama3.py CHANGED (+2 -2)
@@ -341,7 +341,7 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
         self.vocab_size = self.text_config.vocab_size
         self.model = Llama3TextModel._from_config(config, attn_implementation=config._attn_implementation)
         self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
-
+        self.loss_function = ForCausalLMLoss
         self.post_init()

     def get_input_embeddings(self):
@@ -409,7 +409,7 @@ class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss =
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
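For context, the change routes loss computation through a shared causal-LM loss helper (assigned as self.loss_function) instead of computing it inline in forward. Below is a minimal sketch of what a ForCausalLMLoss-style helper typically does, assuming the usual next-token objective (shift logits/labels by one and apply cross-entropy with padding ignored); the actual helper referenced in this commit may accept extra keyword arguments via **loss_kwargs, so this is an illustration, not the exact implementation.

import torch
import torch.nn as nn

def causal_lm_loss(logits, labels, vocab_size, ignore_index=-100, **kwargs):
    # Hypothetical sketch of a causal-LM loss: tokens at position i predict position i + 1.
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten batch and sequence dimensions, ignoring positions marked with ignore_index.
    loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_index)
    return loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))

With logits of shape (batch, seq_len, vocab_size) and labels of shape (batch, seq_len), the call loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) would then return a scalar training loss.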