Fill-Mask
Transformers
PyTorch
eurobert
code
custom_code
Nicolas-BZRD committed on
Commit aba2d76 · verified · 1 Parent(s): 222b7b4

Upload 3 files

Files changed (3)
  1. config.json +7 -0
  2. configuration_eurobert.py +216 -0
  3. modeling_eurobert.py +881 -0
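
The functional change here is the `auto_map` block added to config.json (see the diff below): it tells the transformers Auto classes which custom classes in this repository to load, so the model can be used through the standard Auto API with `trust_remote_code=True`. A minimal, illustrative sketch, not part of the commit, assuming the EuroBERT/EuroBERT-210m checkpoint referenced by `_CHECKPOINT_FOR_DOC` in modeling_eurobert.py:

# Illustrative only: load the custom EuroBERT classes through the Auto API.
# trust_remote_code=True is required because the classes live in this repository.
from transformers import AutoConfig, AutoModelForMaskedLM

model_id = "EuroBERT/EuroBERT-210m"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)            # resolves to EuroBertConfig
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)   # resolves to EuroBertForMaskedLM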
config.json CHANGED
@@ -2,6 +2,13 @@
   "architectures": [
     "EuroBertForMaskedLM"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_eurobert.EuroBertConfig",
+    "AutoModel": "modeling_eurobert.EuroBertModel",
+    "AutoModelForPreTraining": "modeling_eurobert.EuroBertPreTrainedModel",
+    "AutoModelForMaskedLM": "modeling_eurobert.EuroBertForMaskedLM",
+    "AutoModelForSequenceClassification": "modeling_eurobert.EuroBertForSequenceClassification"
+  },
   "attention_bias": false,
   "attention_dropout": 0.0,
   "bos_token": "<|begin_of_text|>",
configuration_eurobert.py ADDED
@@ -0,0 +1,216 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/eurobert/modular_eurobert.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_eurobert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Nicolas Boizard, Duarte M. Alves, Hippolyte Gisserot-Boukhlef and the EuroBert team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.utils import logging
from transformers.models.llama import LlamaConfig


logger = logging.get_logger(__name__)


class EuroBertConfig(LlamaConfig):
    r"""
    This is the configuration class to store the configuration of a [`EuroBertModel`]. It is used to instantiate an EuroBert
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the EuroBERT-210m.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 128256):
            Vocabulary size of the EuroBert model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`EuroBertModel`]
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the encoder and pooler.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with. EuroBert supports up to 8192 tokens,
            EuroBert-pretrained up to 2048.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        bos_token_id (`int`, *optional*, defaults to 128000):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 128001):
            End of stream token id.
        pad_token_id (`int`, *optional*, defaults to 128001):
            Padding token id.
        mask_token_id (`int`, *optional*, defaults to 128002):
            Mask token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 250000.0):
            The base period of the RoPE embeddings. EuroBert used base period of 250000.0,
            EuroBert-pretrained 10000.0.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'eurobert3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'eurobert3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'eurobert3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'eurobert3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_attention_heads
        classifier_pooling (`str`, *optional*, defaults to `"late"`):
            The pooling strategy to use for the classifier. Can be one of ['bos', 'mean', 'late'].

    ```python
    >>> from transformers import EuroBertModel, EuroBertConfig

    >>> # Initializing a EuroBert eurobert-base style configuration
    >>> configuration = EuroBertConfig()

    >>> # Initializing a model from the eurobert-base style configuration
    >>> model = EuroBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "eurobert"

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        bos_token_id=128000,
        eos_token_id=128001,
        pad_token_id=128001,
        mask_token_id=128002,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=250000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        classifier_pooling="late",
        **kwargs,
    ):
        # use_cache is specific to decoder models and should be set to False for encoder models
        use_cache = kwargs.pop("use_cache", None)
        if use_cache:
            logger.warning_once(
                "The `use_cache` argument to EuroBertConfig is set to `False`, as caching is never used for encoder models."
            )

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=False,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            pretraining_tp=pretraining_tp,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            mlp_bias=mlp_bias,
            head_dim=head_dim,
            **kwargs,
        )
        self.mask_token_id = mask_token_id
        self.clf_pooling = classifier_pooling


__all__ = ["EuroBertConfig"]
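
For illustration only (not part of the file above), a small sketch of how the configuration behaves, based on the defaults and the `__init__` logic shown: `classifier_pooling` is stored as `clf_pooling`, `num_key_value_heads` falls back to `num_attention_heads` (plain multi-head attention), and `use_cache` is always forced to `False` for this encoder.

# Assumes configuration_eurobert.py is importable, e.g. after downloading it locally,
# or obtained via AutoConfig.from_pretrained(..., trust_remote_code=True).
from configuration_eurobert import EuroBertConfig

config = EuroBertConfig(classifier_pooling="mean")
assert config.model_type == "eurobert"
assert config.clf_pooling == "mean"                               # classifier_pooling is stored as clf_pooling
assert config.num_key_value_heads == config.num_attention_heads   # None defaults to MHA
assert config.use_cache is False                                  # caching is disabled for the encoder
assert config.mask_token_id == 128002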
modeling_eurobert.py ADDED
@@ -0,0 +1,881 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/eurobert/modular_eurobert.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_eurobert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Nicolas Boizard, Duarte M. Alves, Hippolyte Gisserot-Boukhlef and the EuroBert team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_eurobert import EuroBertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "EuroBERT/EuroBERT-210m"
_CONFIG_FOR_DOC = "EuroBertConfig"


class EuroBertRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        EuroBertRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class EuroBertAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: EuroBertConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            is_causal=False,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

EUROBERT_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`EuroBertConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare EuroBert Model outputting raw hidden-states without any specific head on top.",
    EUROBERT_START_DOCSTRING,
)
class EuroBertPreTrainedModel(PreTrainedModel):
    config_class = EuroBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["EuroBertDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class EuroBertRotaryEmbedding(nn.Module):
    def __init__(self, config: EuroBertConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class EuroBertMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class EuroBertDecoderLayer(nn.Module):
    def __init__(self, config: EuroBertConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = EuroBertAttention(config=config, layer_idx=layer_idx)

        self.mlp = EuroBertMLP(config)
        self.input_layernorm = EuroBertRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = EuroBertRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


EUROBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
              cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""

@add_start_docstrings(
    "The bare EuroBert Model outputting raw hidden-states without any specific head on top.",
    EUROBERT_START_DOCSTRING,
)
class EuroBertModel(EuroBertPreTrainedModel):
    """
    Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EuroBertDecoderLayer`]

    Args:
        config: EuroBertConfig
    """

    def __init__(self, config: EuroBertConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [EuroBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = EuroBertRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = EuroBertRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.mask_converter = AttentionMaskConverter(is_causal=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None:
            mask = self.mask_converter.to_4d(attention_mask, attention_mask.shape[1], inputs_embeds.dtype)
        else:
            mask = None

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the encoder layers
        if position_ids is None:
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # encoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for encoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    mask,
                    position_ids,
                    None,
                    output_attentions,
                    False,
                    None,
                    position_embeddings,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=mask,
                    position_ids=position_ids,
                    output_attentions=output_attentions,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last encoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        return output if return_dict else output.to_tuple()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask

@add_start_docstrings(
    "The EuroBert Model with a decoder head on top that is used for masked language modeling.",
    EUROBERT_START_DOCSTRING,
)
class EuroBertForMaskedLM(EuroBertPreTrainedModel):
    def __init__(self, config: EuroBertConfig):
        super().__init__(config)
        self.model = EuroBertModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, config.mlp_bias)
        self.post_init()

    @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_output = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        prediction_scores = self.lm_head(encoder_output[0])
        masked_lm_loss = None
        if labels is not None:
            labels = labels.to(prediction_scores.device)
            masked_lm_loss = self.loss_function(prediction_scores, labels, vocab_size=self.config.vocab_size)

        if not return_dict:
            output = (prediction_scores,) + encoder_output[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )


@add_start_docstrings(
    "The EuroBert Model with a sequence classification head on top that performs pooling.",
    EUROBERT_START_DOCSTRING,
)
class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
    def __init__(self, config: EuroBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.clf_pooling = config.clf_pooling

        self.model = EuroBertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.GELU()
        self.out_proj = nn.Linear(config.hidden_size, self.num_labels)
        self.post_init()

    @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_output = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_output[0]

        if self.clf_pooling in ["bos", "mean"]:
            if self.clf_pooling == "bos":
                pooled_output = last_hidden_state[:, 0]

            elif self.clf_pooling == "mean":
                if attention_mask is None:
                    pooled_output = last_hidden_state.mean(dim=1)
                else:
                    pooled_output = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1)
                    pooled_output /= attention_mask.sum(dim=1, keepdim=True)

            pooled_output = self.dense(pooled_output)
            pooled_output = self.activation(pooled_output)
            logits = self.out_proj(pooled_output)

        elif self.clf_pooling == "late":
            x = self.dense(last_hidden_state)
            x = self.activation(x)
            logits = self.out_proj(x)
            if attention_mask is None:
                logits = logits.mean(dim=1)
            else:
                logits = (logits * attention_mask.unsqueeze(-1)).sum(dim=1)
                logits /= attention_mask.sum(dim=1, keepdim=True)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + encoder_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )


__all__ = ["EuroBertPreTrainedModel", "EuroBertModel", "EuroBertForMaskedLM", "EuroBertForSequenceClassification"]
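
Taken together, the two files implement a bidirectional encoder with a masked-language-modeling head and a sequence-classification head. An end-to-end fill-mask sketch, again illustrative rather than part of the commit (it assumes the EuroBERT/EuroBERT-210m checkpoint and a tokenizer that defines a mask token):

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "EuroBERT/EuroBERT-210m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)

text = f"The capital of France is {tokenizer.mask_token}."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch, seq_len, vocab_size)

# Decode the top prediction at the masked position.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_pos].argmax(dim=-1)))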