Upload modeling_eurobert.py (#15)
Upload modeling_eurobert.py (50bf12545ca008d3f26f0fa9064b3b687593cd2f)
Co-authored-by: Hippolyte Gisserot-Boukhlef <[email protected]>
- modeling_eurobert.py +86 -7
modeling_eurobert.py
CHANGED
@@ -30,7 +30,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
@@ -708,7 +708,7 @@ class EuroBertModel(EuroBertPreTrainedModel):
 
 
 @add_start_docstrings(
-    "The EuroBert Model with a
+    "The EuroBert Model with a decoder head on top that is used for masked language modeling.",
     EUROBERT_START_DOCSTRING,
 )
 class EuroBertForMaskedLM(EuroBertPreTrainedModel):
@@ -766,7 +766,7 @@ class EuroBertForMaskedLM(EuroBertPreTrainedModel):
 
 
 @add_start_docstrings(
-    "The EuroBert Model with a
+    "The EuroBert Model with a sequence classification head on top that performs pooling.",
    EUROBERT_START_DOCSTRING,
 )
 class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
@@ -778,7 +778,7 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
         self.model = EuroBertModel(config)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.GELU()
-        self.
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
         self.post_init()
 
     @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
@@ -830,12 +830,12 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
 
             pooled_output = self.dense(pooled_output)
             pooled_output = self.activation(pooled_output)
-            logits = self.
+            logits = self.classifier(pooled_output)
 
         elif self.clf_pooling == "late":
             x = self.dense(last_hidden_state)
             x = self.activation(x)
-            logits = self.
+            logits = self.classifier(x)
             if attention_mask is None:
                 logits = logits.mean(dim=1)
             else:
@@ -878,4 +878,83 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
         )
 
 
-
+@add_start_docstrings(
+    """
+    The EuroBert Model with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    EUROBERT_START_DOCSTRING,
+)
+class EuroBertForTokenClassification(EuroBertPreTrainedModel):
+    def __init__(self, config: EuroBertConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = EuroBertModel(config)
+
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. A classification loss is computed (Cross-Entropy); positions
+            labeled `-100` are ignored.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "EuroBertPreTrainedModel",
+    "EuroBertModel",
+    "EuroBertForMaskedLM",
+    "EuroBertForSequenceClassification",
+    "EuroBertForTokenClassification",
+]
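A note on the `clf_pooling == "late"` branch patched above: the classifier now produces per-token logits, which are averaged over non-padded positions when an `attention_mask` is supplied. A minimal sketch of that masked mean, independent of the file (tensor names and shapes here are illustrative, not taken from the diff):

    import torch

    # Hypothetical shapes: logits (batch, seq_len, num_labels), attention_mask (batch, seq_len).
    logits = torch.randn(2, 5, 3)
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])

    # Zero out padded positions, then divide by the count of real tokens per sequence.
    mask = attention_mask.unsqueeze(-1).to(logits.dtype)   # (batch, seq_len, 1)
    pooled = (logits * mask).sum(dim=1) / mask.sum(dim=1)  # (batch, num_labels)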
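With `EuroBertForTokenClassification` exported in `__all__`, the checkpoint can serve token-level tasks such as NER. A hedged usage sketch, assuming the repository's `auto_map` exposes the new class to the auto classes (the checkpoint id and label count below are illustrative):

    from transformers import AutoModelForTokenClassification, AutoTokenizer

    model_id = "EuroBERT/EuroBERT-210m"  # illustrative; substitute the checkpoint you use
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForTokenClassification.from_pretrained(
        model_id, num_labels=9, trust_remote_code=True,  # e.g. 9 CoNLL-2003 NER tags
    )

    inputs = tokenizer("EuroBERT was trained on European languages.", return_tensors="pt")
    logits = model(**inputs).logits      # (1, seq_len, num_labels)
    predictions = logits.argmax(dim=-1)  # one label id per token

When fine-tuning, padded and special-token positions are conventionally labeled `-100`, which `CrossEntropyLoss` ignores by default.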