update modeling file
Browse files- modeling_camelidae.py +3 -0
modeling_camelidae.py
CHANGED
|
@@ -20,6 +20,7 @@
|
|
| 20 |
""" PyTorch LLaMA model."""
|
| 21 |
import math
|
| 22 |
from typing import List, Optional, Tuple, Union
|
|
|
|
| 23 |
|
| 24 |
import numpy as np
|
| 25 |
import copy
|
|
@@ -53,6 +54,7 @@ logger = logging.get_logger(__name__)
|
|
| 53 |
_CONFIG_FOR_DOC = "CamelidaeConfig"
|
| 54 |
|
| 55 |
|
|
|
|
| 56 |
class MoEModelOutputWithPast(ModelOutput):
|
| 57 |
last_hidden_state: torch.FloatTensor = None
|
| 58 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
|
@@ -61,6 +63,7 @@ class MoEModelOutputWithPast(ModelOutput):
|
|
| 61 |
router_logits: Optional[Tuple[torch.FloatTensor]] = None
|
| 62 |
|
| 63 |
|
|
|
|
| 64 |
class MoECausalLMOutputWithPast(ModelOutput):
|
| 65 |
loss: Optional[torch.FloatTensor] = None
|
| 66 |
aux_loss: Optional[torch.FloatTensor] = None
|
|
|
|
| 20 |
""" PyTorch LLaMA model."""
|
| 21 |
import math
|
| 22 |
from typing import List, Optional, Tuple, Union
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
|
| 25 |
import numpy as np
|
| 26 |
import copy
|
|
|
|
| 54 |
_CONFIG_FOR_DOC = "CamelidaeConfig"
|
| 55 |
|
| 56 |
|
| 57 |
+
@dataclass
|
| 58 |
class MoEModelOutputWithPast(ModelOutput):
|
| 59 |
last_hidden_state: torch.FloatTensor = None
|
| 60 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
|
|
|
| 63 |
router_logits: Optional[Tuple[torch.FloatTensor]] = None
|
| 64 |
|
| 65 |
|
| 66 |
+
@dataclass
|
| 67 |
class MoECausalLMOutputWithPast(ModelOutput):
|
| 68 |
loss: Optional[torch.FloatTensor] = None
|
| 69 |
aux_loss: Optional[torch.FloatTensor] = None
|