Commit ce85381 (verified) · dav1dliu committed · 1 Parent(s): a1f43ea

Replace Qwen3 configs/modeling with SDAR version

config.json CHANGED
@@ -1,11 +1,11 @@
 {
   "architectures": [
-    "Qwen3ForCausalLM"
+    "SDARForCausalLM"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_qwen3.Qwen3Config",
-    "AutoModel": "modeling_qwen3.Qwen3Model",
-    "AutoModelForCausalLM": "modeling_qwen3.Qwen3ForCausalLM"
+    "AutoConfig": "configuration_sdar.SDARConfig",
+    "AutoModel": "modeling_sdar.SDARModel",
+    "AutoModelForCausalLM": "modeling_sdar.SDARForCausalLM"
   },
   "attention_bias": false,
   "attention_dropout": 0.0,
@@ -19,7 +19,7 @@
   "intermediate_size": 6144,
   "max_position_embeddings": 32768,
   "max_window_layers": 28,
-  "model_type": "qwen3",
+  "model_type": "sdar",
   "num_attention_heads": 16,
   "num_hidden_layers": 28,
  "num_key_value_heads": 8,
configuration_qwen3.py → configuration_sdar.py RENAMED
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Qwen3 model configuration"""
+"""SDAR model configuration"""
 
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_rope_utils import rope_config_validation
@@ -22,12 +22,12 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 
-class Qwen3Config(PretrainedConfig):
+class SDARConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
-    Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    This is the configuration class to store the configuration of a [`SDARModel`]. It is used to instantiate a
+    SDAR model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of
-    Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).
+    SDAR-1.7B [DiffuOpen/SDAR-1.7B-Chat](https://huggingface.co/DiffuOpen/SDAR-1.7B-Chat/).
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -35,8 +35,8 @@ class Qwen3Config(PretrainedConfig):
 
     Args:
         vocab_size (`int`, *optional*, defaults to 151936):
-            Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the
-            `inputs_ids` passed when calling [`Qwen3Model`]
+            Vocabulary size of the SDAR model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`SDARModel`]
         hidden_size (`int`, *optional*, defaults to 4096):
             Dimension of the hidden representations.
         intermediate_size (`int`, *optional*, defaults to 22016):
@@ -118,22 +118,22 @@ class Qwen3Config(PretrainedConfig):
            The dropout ratio for the attention probabilities.
 
     ```python
-    >>> from transformers import Qwen3Model, Qwen3Config
+    >>> from transformers import SDARModel, SDARConfig
 
-    >>> # Initializing a Qwen3 style configuration
-    >>> configuration = Qwen3Config()
+    >>> # Initializing a SDAR style configuration
+    >>> configuration = SDARConfig()
 
-    >>> # Initializing a model from the Qwen3-8B style configuration
-    >>> model = Qwen3Model(configuration)
+    >>> # Initializing a model from the SDAR-8B style configuration
+    >>> model = SDARModel(configuration)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
 
-    model_type = "qwen3"
+    model_type = "sdar"
     keys_to_ignore_at_inference = ["past_key_values"]
 
-    # Default tensor parallel plan for base model `Qwen3`
+    # Default tensor parallel plan for base model `SDAR`
     base_model_tp_plan = {
         "layers.*.self_attn.q_proj": "colwise",
         "layers.*.self_attn.k_proj": "colwise",
@@ -209,4 +209,4 @@
     )
 
 
-__all__ = ["Qwen3Config"]
+ __all__ = ["SDARConfig"]
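For reference, the same configuration could also be constructed directly from the values visible in `config.json` above, with the remaining fields falling back to the `SDARConfig` defaults (a sketch that assumes the module is importable locally, e.g. from a checkout of this repo):

```python
# Sketch: build an SDARConfig matching the fields shown in the config.json diff above.
from configuration_sdar import SDARConfig  # local module registered via auto_map

config = SDARConfig(
    num_attention_heads=16,
    num_key_value_heads=8,
    num_hidden_layers=28,
    intermediate_size=6144,
    max_position_embeddings=32768,
    max_window_layers=28,
    attention_bias=False,
    attention_dropout=0.0,
)
assert config.model_type == "sdar"
```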
modeling_qwen3_origin.py DELETED
@@ -1,1065 +0,0 @@
1
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
- # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
3
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
- # the file from the modular. If any change should be done, please apply the change to the
5
- # modular_qwen3.py file directly. One of our CI enforces this.
6
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
- # coding=utf-8
8
- # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
- #
10
- # Licensed under the Apache License, Version 2.0 (the "License");
11
- # you may not use this file except in compliance with the License.
12
- # You may obtain a copy of the License at
13
- #
14
- # http://www.apache.org/licenses/LICENSE-2.0
15
- #
16
- # Unless required by applicable law or agreed to in writing, software
17
- # distributed under the License is distributed on an "AS IS" BASIS,
18
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
- # See the License for the specific language governing permissions and
20
- # limitations under the License.
21
-
22
- from typing import Callable, Optional, Tuple, Union
23
-
24
- import torch
25
- from torch import nn
26
- from einops import rearrange
27
-
28
- from transformers.activations import ACT2FN
29
- from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
30
- from transformers.generation import GenerationMixin
31
- from transformers.integrations import use_kernel_forward_from_hub
32
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
33
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
34
- from transformers.modeling_layers import GradientCheckpointingLayer
35
- from transformers.modeling_outputs import (
36
- BaseModelOutputWithPast,
37
- CausalLMOutputWithPast,
38
- QuestionAnsweringModelOutput,
39
- SequenceClassifierOutputWithPast,
40
- TokenClassifierOutput,
41
- )
42
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
43
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
- from transformers.processing_utils import Unpack
45
- from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
46
- from .configuration_qwen3 import Qwen3Config
47
-
48
- from torch.nn import CrossEntropyLoss
49
- from fla.modules.activations import swiglu_linear
50
- from fla.modules import FusedLinearDiffusionCrossEntropyLoss
51
- from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
52
-
53
- if is_torch_flex_attn_available():
54
- from torch.nn.attention.flex_attention import BlockMask, flex_attention
55
-
56
- from transformers.integrations.flex_attention import make_flex_block_causal_mask
57
-
58
- # flex attn needs torch.compile to accelerate
59
- @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
60
- def fused_flex_attention(query, key, value, attention_mask, **kwargs):
61
- return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
62
-
63
- logger = logging.get_logger(__name__)
64
-
65
-
66
- @use_kernel_forward_from_hub("RMSNorm")
67
- class Qwen3RMSNorm(nn.Module):
68
- def __init__(self, hidden_size, eps=1e-6):
69
- """
70
- Qwen3RMSNorm is equivalent to T5LayerNorm
71
- """
72
- super().__init__()
73
- self.weight = nn.Parameter(torch.ones(hidden_size))
74
- self.variance_epsilon = eps
75
-
76
- def forward(self, hidden_states):
77
- input_dtype = hidden_states.dtype
78
- # hidden_states = hidden_states.to(torch.float32)
79
- # variance = hidden_states.pow(2).mean(-1, keepdim=True)
80
- # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
81
- return flash_rms_norm(
82
- x=hidden_states, weight=self.weight, bias=None, eps=self.variance_epsilon).to(input_dtype)
83
-
84
- def extra_repr(self):
85
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
86
-
87
-
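The commented-out lines in `forward` above are the reference computation that `flash_rms_norm` replaces with a fused Triton kernel. A plain-PyTorch sketch of the same normalization, useful as a numerical cross-check:

```python
import torch

def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the last dim, then a learned per-channel scale.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
```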
88
- class Qwen3MLP(nn.Module):
89
- def __init__(self, config):
90
- super().__init__()
91
- self.config = config
92
- self.hidden_size = config.hidden_size
93
- self.intermediate_size = config.intermediate_size
94
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
95
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
96
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
97
- self.act_fn = ACT2FN[config.hidden_act]
98
-
99
- def forward(self, x):
100
- # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
101
- down_proj = swiglu_linear(self.gate_proj(x), self.up_proj(x),
102
- self.down_proj.weight, self.down_proj.bias)
103
- return down_proj
104
-
105
-
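`swiglu_linear` from `fla` fuses the SwiGLU gating with the down projection; the commented-out line above shows the unfused equivalent. A plain reference, assuming `hidden_act` is `"silu"` as in the Qwen3-style configs:

```python
import torch.nn.functional as F

def reference_swiglu_mlp(x, gate_proj, up_proj, down_proj):
    # Equivalent to: down_proj(act_fn(gate_proj(x)) * up_proj(x)) with act_fn = SiLU.
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))
```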
106
- def rotate_half(x):
107
- """Rotates half the hidden dims of the input."""
108
- x1 = x[..., : x.shape[-1] // 2]
109
- x2 = x[..., x.shape[-1] // 2 :]
110
- return torch.cat((-x2, x1), dim=-1)
111
-
112
-
113
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
114
- """Applies Rotary Position Embedding to the query and key tensors.
115
-
116
- Args:
117
- q (`torch.Tensor`): The query tensor.
118
- k (`torch.Tensor`): The key tensor.
119
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
120
- sin (`torch.Tensor`): The sine part of the rotary embedding.
121
- position_ids (`torch.Tensor`, *optional*):
122
- Deprecated and unused.
123
- unsqueeze_dim (`int`, *optional*, defaults to 1):
124
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
125
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
126
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
127
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
128
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
129
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
130
- Returns:
131
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
132
- """
133
- cos = cos.unsqueeze(unsqueeze_dim)
134
- sin = sin.unsqueeze(unsqueeze_dim)
135
- q_embed = (q * cos) + (rotate_half(q) * sin)
136
- k_embed = (k * cos) + (rotate_half(k) * sin)
137
- return q_embed, k_embed
138
-
139
-
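A quick sanity check for the rotary helpers above (a sketch): because each `(x1, x2)` pair is rotated by an angle, applying the embedding leaves vector norms unchanged.

```python
import torch

head_dim, seq_len = 8, 4
q = torch.randn(1, 2, seq_len, head_dim)  # (batch, heads, seq, head_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)[None]  # (1, seq, head_dim)
cos, sin = emb.cos(), emb.sin()

q_rot, _ = apply_rotary_pos_emb(q, q, cos, sin)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
```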
140
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
141
- """
142
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
143
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
144
- """
145
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
146
- if n_rep == 1:
147
- return hidden_states
148
- hidden_states = hidden_states[:, :, None, :, :].expand(
149
- batch, num_key_value_heads, n_rep, slen, head_dim)
150
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
151
-
152
-
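With the config above (16 attention heads, 8 key/value heads) each key/value head is shared by `n_rep = 2` query heads. A small shape check for `repeat_kv` (sketch):

```python
import torch

batch, kv_heads, n_rep, seq, head_dim = 2, 8, 2, 5, 64
kv = torch.randn(batch, kv_heads, seq, head_dim)
expanded = repeat_kv(kv, n_rep)

assert expanded.shape == (batch, kv_heads * n_rep, seq, head_dim)
# consecutive query heads within a group see the same key/value head
assert torch.equal(expanded[:, 0], expanded[:, 1])
```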
153
- def eager_attention_forward(
154
- module: nn.Module,
155
- query: torch.Tensor,
156
- key: torch.Tensor,
157
- value: torch.Tensor,
158
- attention_mask: Optional[torch.Tensor],
159
- scaling: float,
160
- dropout: float = 0.0,
161
- **kwargs,
162
- ):
163
- key_states = repeat_kv(key, module.num_key_value_groups)
164
- value_states = repeat_kv(value, module.num_key_value_groups)
165
-
166
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
167
- if attention_mask is not None:
168
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
169
- attn_weights = attn_weights + causal_mask
170
-
171
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
172
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
173
- attn_output = torch.matmul(attn_weights, value_states)
174
- attn_output = attn_output.transpose(1, 2).contiguous()
175
-
176
- return attn_output, attn_weights
177
-
178
-
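`eager_attention_forward` is the plain matmul-softmax path; the other `_attn_implementation` backends are expected to produce the same result with fused kernels. A sketch comparing it against PyTorch SDPA (assumes PyTorch 2.1+ for the `scale` argument and `num_key_value_groups == 1`, i.e. no GQA expansion):

```python
import torch
import torch.nn.functional as F
from types import SimpleNamespace

b, h, s, d = 1, 4, 6, 32
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
dummy = SimpleNamespace(num_key_value_groups=1, training=False)  # stand-in module

out_eager, _ = eager_attention_forward(dummy, q, k, v, attention_mask=None, scaling=d ** -0.5)
out_sdpa = F.scaled_dot_product_attention(q, k, v, scale=d ** -0.5).transpose(1, 2)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5)
```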
179
- class Qwen3Attention(nn.Module):
180
- """Multi-headed attention from 'Attention Is All You Need' paper"""
181
-
182
- def __init__(self, config: Qwen3Config, layer_idx: int):
183
- super().__init__()
184
- self.config = config
185
- self.layer_idx = layer_idx
186
- self.num_attention_heads = config.num_attention_heads
187
- self.num_key_value_heads = config.num_key_value_heads
188
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
189
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
190
- self.scaling = self.head_dim**-0.5
191
- self.attention_dropout = config.attention_dropout
192
- self.is_causal = False
193
-
194
- self.q_proj = nn.Linear(
195
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
196
- )
197
- self.k_proj = nn.Linear(
198
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
199
- )
200
- self.v_proj = nn.Linear(
201
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
202
- )
203
- self.o_proj = nn.Linear(
204
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
205
- )
206
- self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
207
- self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
208
- self.sliding_window = config.sliding_window
209
- if not (
210
- self.config.use_sliding_window
211
- and getattr(self.config, "sliding_window", None) is not None
212
- and self.layer_idx >= self.config.max_window_layers
213
- ):
214
- self.sliding_window = None
215
-
216
- def forward(
217
- self,
218
- hidden_states: torch.Tensor,
219
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
220
- attention_mask: Optional[torch.Tensor],
221
- past_key_value: Optional[Cache] = None,
222
- cache_position: Optional[torch.LongTensor] = None,
223
- **kwargs: Unpack[FlashAttentionKwargs],
224
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
225
- input_shape = hidden_states.shape[:-1]
226
- bsz, q_len = input_shape
227
- hidden_shape = (*input_shape, -1, self.head_dim)
228
-
229
- query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
230
- key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
231
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
232
-
233
- cos, sin = position_embeddings
234
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
235
-
236
- if past_key_value is not None:
237
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
238
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
239
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
240
-
241
- attention_interface: Callable = eager_attention_forward
242
- if self.config._attn_implementation != "eager":
243
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
244
- logger.warning_once(
245
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
246
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
247
- )
248
- else:
249
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
250
-
251
- if self.config._attn_implementation == 'flex_attention':
252
- # Although there exists `flex_attention_forward` in `AttentionInterface`,
253
- # we still use our customized `fused_flex_attention` for debugging.
254
- pad_length = kwargs.get("pad_length", 0)
255
- # Used for SFT (packing + varlen), seq_len changes at each step
256
- # seq_len must be divisible by BLOCK_SIZE in flex attn
257
- pad_q = torch.zeros(
258
- bsz, self.num_attention_heads, pad_length, self.head_dim, device=query_states.device, dtype=query_states.dtype)
259
- pad_kv = torch.zeros(
260
- bsz, self.num_key_value_heads, pad_length, self.head_dim, device=query_states.device, dtype=query_states.dtype)
261
- attn_output, attn_weights = fused_flex_attention(
262
- query=torch.cat([query_states, pad_q], dim=2),
263
- key=torch.cat([key_states, pad_kv], dim=2),
264
- value=torch.cat([value_states, pad_kv], dim=2),
265
- attention_mask=attention_mask,
266
- enable_gqa=True,
267
- scale=self.scaling,
268
- return_lse=True
269
- )
270
-
271
- attn_output = attn_output[..., :q_len, :].contiguous()
272
- attn_weights = attn_weights.to(value_states.dtype)
273
- attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
274
- else:
275
- attn_output, attn_weights = attention_interface(
276
- self,
277
- query_states,
278
- key_states,
279
- value_states,
280
- attention_mask,
281
- dropout=0.0 if not self.training else self.attention_dropout,
282
- scaling=self.scaling,
283
- sliding_window=self.sliding_window, # diff with Llama
284
- **kwargs,
285
- ) # output: [b, l, h, d]
286
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
287
- attn_output = self.o_proj(attn_output)
288
- return attn_output, attn_weights
289
-
290
-
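The flex-attention branch above pads queries and keys so the (packed, variable-length) sequence becomes divisible by the kernel's block size. A sketch of how such a `pad_length` could be computed, assuming a 128-token block size (the actual block size depends on the flex-attention kernel configuration):

```python
def compute_pad_length(seq_len: int, block_size: int = 128) -> int:
    # Smallest padding that makes seq_len a multiple of block_size.
    return (-seq_len) % block_size

assert compute_pad_length(300) == 84   # 300 + 84 = 384 = 3 * 128
assert compute_pad_length(256) == 0
```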
291
- class Qwen3DecoderLayer(GradientCheckpointingLayer):
292
- def __init__(self, config: Qwen3Config, layer_idx: int):
293
- super().__init__()
294
- self.hidden_size = config.hidden_size
295
- self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
296
- self.mlp = Qwen3MLP(config)
297
- self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
298
- self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
299
- if (
300
- config.sliding_window and config._attn_implementation != "flash_attention_2"
301
- ): # diff with Llama is this warning
302
- logger.warning_once(
303
- f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
304
- "unexpected results may be encountered."
305
- )
306
-
307
- def forward(
308
- self,
309
- hidden_states: torch.Tensor,
310
- attention_mask: Optional[torch.Tensor] = None,
311
- position_ids: Optional[torch.LongTensor] = None,
312
- past_key_value: Optional[Cache] = None,
313
- output_attentions: Optional[bool] = False,
314
- use_cache: Optional[bool] = False,
315
- cache_position: Optional[torch.LongTensor] = None,
316
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
317
- **kwargs: Unpack[FlashAttentionKwargs],
318
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
319
- residual = hidden_states
320
- hidden_states = self.input_layernorm(hidden_states)
321
-
322
- # Self Attention
323
- hidden_states, self_attn_weights = self.self_attn(
324
- hidden_states=hidden_states,
325
- attention_mask=attention_mask,
326
- position_ids=position_ids,
327
- past_key_value=past_key_value,
328
- output_attentions=output_attentions,
329
- use_cache=use_cache,
330
- cache_position=cache_position,
331
- position_embeddings=position_embeddings,
332
- **kwargs,
333
- )
334
- hidden_states = residual + hidden_states
335
-
336
- # Fully Connected
337
- residual = hidden_states
338
- hidden_states = self.post_attention_layernorm(hidden_states)
339
- hidden_states = self.mlp(hidden_states)
340
- hidden_states = residual + hidden_states
341
-
342
- outputs = (hidden_states,)
343
- if output_attentions:
344
- outputs += (self_attn_weights,)
345
-
346
- return outputs
347
-
348
-
349
- @auto_docstring
350
- class Qwen3PreTrainedModel(PreTrainedModel):
351
- config_class = Qwen3Config
352
- base_model_prefix = "model"
353
- supports_gradient_checkpointing = True
354
- _no_split_modules = ["Qwen3DecoderLayer"]
355
- _skip_keys_device_placement = ["past_key_values"]
356
- _supports_flash_attn_2 = True
357
- _supports_sdpa = True
358
- _supports_flex_attn = True
359
- _supports_cache_class = True
360
- _supports_quantized_cache = True
361
- _supports_static_cache = True
362
- _supports_attention_backend = True
363
-
364
- def _init_weights(self, module):
365
- std = self.config.initializer_range
366
- if isinstance(module, nn.Linear):
367
- module.weight.data.normal_(mean=0.0, std=std)
368
- if module.bias is not None:
369
- module.bias.data.zero_()
370
- elif isinstance(module, nn.Embedding):
371
- module.weight.data.normal_(mean=0.0, std=std)
372
- if module.padding_idx is not None:
373
- module.weight.data[module.padding_idx].zero_()
374
- elif isinstance(module, Qwen3RMSNorm):
375
- module.weight.data.fill_(1.0)
376
-
377
-
378
- class Qwen3RotaryEmbedding(nn.Module):
379
- def __init__(self, config: Qwen3Config, device=None):
380
- super().__init__()
381
- # BC: "rope_type" was originally "type"
382
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
383
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
384
- else:
385
- self.rope_type = "default"
386
- self.max_seq_len_cached = config.max_position_embeddings
387
- self.original_max_seq_len = config.max_position_embeddings
388
-
389
- self.config = config
390
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
391
-
392
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
393
- self.register_buffer("inv_freq", inv_freq, persistent=False)
394
- self.original_inv_freq = self.inv_freq
395
-
396
- @torch.no_grad()
397
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
398
- def forward(self, x, position_ids):
399
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
400
- position_ids_expanded = position_ids[:, None, :].float()
401
-
402
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
403
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
404
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
405
- emb = torch.cat((freqs, freqs), dim=-1)
406
- cos = emb.cos() * self.attention_scaling
407
- sin = emb.sin() * self.attention_scaling
408
-
409
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
410
-
411
-
412
- @auto_docstring
413
- class Qwen3Model(Qwen3PreTrainedModel):
414
- def __init__(self, config: Qwen3Config):
415
- super().__init__(config)
416
- self.padding_idx = config.pad_token_id
417
- self.vocab_size = config.vocab_size
418
-
419
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
420
- self.layers = nn.ModuleList(
421
- [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
422
- )
423
- self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
424
- self.rotary_emb = Qwen3RotaryEmbedding(config=config)
425
- self.gradient_checkpointing = False
426
-
427
- # Initialize weights and apply final processing
428
- self.post_init()
429
-
430
- def get_input_embeddings(self):
431
- return self.embed_tokens
432
-
433
- def set_input_embeddings(self, value):
434
- self.embed_tokens = value
435
-
436
- @can_return_tuple
437
- @auto_docstring
438
- def forward(
439
- self,
440
- input_ids: Optional[torch.LongTensor] = None,
441
- attention_mask: Optional[torch.Tensor] = None,
442
- position_ids: Optional[torch.LongTensor] = None,
443
- past_key_values: Optional[Cache] = None,
444
- inputs_embeds: Optional[torch.FloatTensor] = None,
445
- use_cache: Optional[bool] = None,
446
- output_attentions: Optional[bool] = None,
447
- output_hidden_states: Optional[bool] = None,
448
- cache_position: Optional[torch.LongTensor] = None,
449
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
450
- ) -> BaseModelOutputWithPast:
451
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
452
- output_hidden_states = (
453
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
454
- )
455
- use_cache = use_cache if use_cache is not None else self.config.use_cache
456
-
457
- if (input_ids is None) ^ (inputs_embeds is not None):
458
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
459
-
460
- if self.gradient_checkpointing and self.training and use_cache:
461
- logger.warning_once(
462
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
463
- )
464
- use_cache = False
465
-
466
- # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
467
- if not isinstance(past_key_values, (type(None), Cache)):
468
- raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
469
-
470
- if inputs_embeds is None:
471
- inputs_embeds = self.embed_tokens(input_ids)
472
-
473
- if use_cache and past_key_values is None:
474
- past_key_values = DynamicCache()
475
-
476
- if cache_position is None:
477
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
478
- cache_position = torch.arange(
479
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
480
- )
481
-
482
- if position_ids is None:
483
- position_ids = cache_position.unsqueeze(0)
484
-
485
- causal_mask = self._update_causal_mask(
486
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
487
- )
488
-
489
- hidden_states = inputs_embeds
490
-
491
- # create position embeddings to be shared across the decoder layers
492
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
493
-
494
- # decoder layers
495
- all_hidden_states = () if output_hidden_states else None
496
- all_self_attns = () if output_attentions else None
497
-
498
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
499
- if output_hidden_states:
500
- all_hidden_states += (hidden_states,)
501
-
502
- layer_outputs = decoder_layer(
503
- hidden_states,
504
- attention_mask=causal_mask,
505
- position_ids=position_ids,
506
- past_key_value=past_key_values,
507
- output_attentions=output_attentions,
508
- use_cache=use_cache,
509
- cache_position=cache_position,
510
- position_embeddings=position_embeddings,
511
- **flash_attn_kwargs,
512
- )
513
-
514
- hidden_states = layer_outputs[0]
515
-
516
- if output_attentions:
517
- all_self_attns += (layer_outputs[1],)
518
-
519
- hidden_states = self.norm(hidden_states)
520
-
521
- # add hidden states from the last decoder layer
522
- if output_hidden_states:
523
- all_hidden_states += (hidden_states,)
524
-
525
- return BaseModelOutputWithPast(
526
- last_hidden_state=hidden_states,
527
- past_key_values=past_key_values if use_cache else None,
528
- hidden_states=all_hidden_states,
529
- attentions=all_self_attns,
530
- )
531
-
532
- def _update_causal_mask(
533
- self,
534
- attention_mask: Union[torch.Tensor, "BlockMask"],
535
- input_tensor: torch.Tensor,
536
- cache_position: torch.Tensor,
537
- past_key_values: Cache,
538
- output_attentions: bool = False,
539
- ):
540
- if self.config._attn_implementation == "flash_attention_2":
541
- if attention_mask is not None and past_key_values is not None:
542
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
543
- if is_padding_right:
544
- raise ValueError(
545
- "You are attempting to perform batched generation with padding_side='right'"
546
- " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
547
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
548
- )
549
- if attention_mask is not None and 0.0 in attention_mask:
550
- return attention_mask
551
- return None
552
- if self.config._attn_implementation == "flex_attention":
553
- # Use flex block mask directly
554
- assert isinstance(attention_mask, BlockMask)
555
- return attention_mask
556
-
557
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
558
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
559
- # to infer the attention mask.
560
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
561
- using_static_cache = isinstance(past_key_values, StaticCache)
562
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
563
-
564
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
565
- if (
566
- self.config._attn_implementation == "sdpa"
567
- and not (using_static_cache or using_sliding_window_cache)
568
- and not output_attentions
569
- ):
570
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
571
- attention_mask,
572
- inputs_embeds=input_tensor,
573
- past_key_values_length=past_seen_tokens,
574
- sliding_window=self.config.sliding_window,
575
- is_training=self.training,
576
- ):
577
- return None
578
-
579
- dtype = input_tensor.dtype
580
- min_dtype = torch.finfo(dtype).min
581
- sequence_length = input_tensor.shape[1]
582
- # SlidingWindowCache or StaticCache
583
- if using_sliding_window_cache or using_static_cache:
584
- target_length = past_key_values.get_max_cache_shape()
585
- # DynamicCache or no cache
586
- else:
587
- target_length = (
588
- attention_mask.shape[-1]
589
- if isinstance(attention_mask, torch.Tensor)
590
- else past_seen_tokens + sequence_length + 1
591
- )
592
-
593
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
594
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
595
- attention_mask,
596
- sequence_length=sequence_length,
597
- target_length=target_length,
598
- dtype=dtype,
599
- cache_position=cache_position,
600
- batch_size=input_tensor.shape[0],
601
- config=self.config,
602
- past_key_values=past_key_values,
603
- )
604
-
605
- if (
606
- self.config._attn_implementation == "sdpa"
607
- and attention_mask is not None
608
- and attention_mask.device.type in ["cuda", "xpu", "npu"]
609
- and not output_attentions
610
- ):
611
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
612
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
613
- # Details: https://github.com/pytorch/pytorch/issues/110213
614
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
615
-
616
- return causal_mask
617
-
618
- @staticmethod
619
- def _prepare_4d_causal_attention_mask_with_cache_position(
620
- attention_mask: torch.Tensor,
621
- sequence_length: int,
622
- target_length: int,
623
- dtype: torch.dtype,
624
- cache_position: torch.Tensor,
625
- batch_size: int,
626
- config: Qwen3Config,
627
- past_key_values: Cache,
628
- ):
629
- """
630
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
631
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
632
-
633
- Args:
634
- attention_mask (`torch.Tensor`):
635
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
636
- sequence_length (`int`):
637
- The sequence length being processed.
638
- target_length (`int`):
639
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
640
- dtype (`torch.dtype`):
641
- The dtype to use for the 4D attention mask.
642
- cache_position (`torch.Tensor`):
643
- Indices depicting the position of the input sequence tokens in the sequence.
644
- batch_size (`torch.Tensor`):
645
- Batch size.
646
- config (`Qwen3Config`):
647
- The model's configuration class
648
- past_key_values (`Cache`):
649
- The cache class that is being used currently to generate
650
- """
651
- if attention_mask is not None and attention_mask.dim() == 4:
652
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
653
- causal_mask = attention_mask
654
- else:
655
- min_dtype = torch.finfo(dtype).min
656
- causal_mask = torch.full(
657
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
658
- )
659
- diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
660
- -1, 1
661
- )
662
- text_config = config.get_text_config()
663
- if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
664
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
665
- # the check is needed to verify is current checkpoint was trained with sliding window or not
666
- if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
667
- sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
668
- cache_position.reshape(-1, 1) - text_config.sliding_window
669
- )
670
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
671
- causal_mask *= diagonal_attend_mask
672
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
673
- if attention_mask is not None:
674
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
675
- if attention_mask.shape[-1] > target_length:
676
- attention_mask = attention_mask[:, :target_length]
677
- mask_length = attention_mask.shape[-1]
678
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
679
- causal_mask.device
680
- )
681
- padding_mask = padding_mask == 0
682
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
683
- padding_mask, min_dtype
684
- )
685
- return causal_mask
686
-
687
-
688
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
689
-
690
-
691
- @auto_docstring
692
- class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
693
- _tied_weights_keys = ["lm_head.weight"]
694
- _tp_plan = {"lm_head": "colwise_rep"}
695
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
696
-
697
- def __init__(self, config):
698
- super().__init__(config)
699
- self.model = Qwen3Model(config)
700
- self.vocab_size = config.vocab_size
701
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
702
-
703
- # Initialize weights and apply final processing
704
- self.post_init()
705
-
706
- def get_input_embeddings(self):
707
- return self.model.embed_tokens
708
-
709
- def set_input_embeddings(self, value):
710
- self.model.embed_tokens = value
711
-
712
- def get_output_embeddings(self):
713
- return self.lm_head
714
-
715
- def set_output_embeddings(self, new_embeddings):
716
- self.lm_head = new_embeddings
717
-
718
- def set_decoder(self, decoder):
719
- self.model = decoder
720
-
721
- def get_decoder(self):
722
- return self.model
723
-
724
- @can_return_tuple
725
- @auto_docstring
726
- def forward(
727
- self,
728
- input_ids: Optional[torch.LongTensor] = None,
729
- attention_mask: Optional[torch.Tensor] = None,
730
- position_ids: Optional[torch.LongTensor] = None,
731
- past_key_values: Optional[Cache] = None,
732
- inputs_embeds: Optional[torch.FloatTensor] = None,
733
- labels: Optional[torch.LongTensor] = None,
734
- use_cache: Optional[bool] = None,
735
- output_attentions: Optional[bool] = None,
736
- output_hidden_states: Optional[bool] = None,
737
- cache_position: Optional[torch.LongTensor] = None,
738
- logits_to_keep: Union[int, torch.Tensor] = 0,
739
- **kwargs: Unpack[KwargsForCausalLM],
740
- ) -> CausalLMOutputWithPast:
741
- r"""
742
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
743
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
744
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
745
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
746
-
747
- Example:
748
-
749
- ```python
750
- >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
751
-
752
- >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
753
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
754
-
755
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
756
- >>> inputs = tokenizer(prompt, return_tensors="pt")
757
-
758
- >>> # Generate
759
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
760
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
761
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
762
- ```"""
763
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
764
- output_hidden_states = (
765
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
766
- )
767
-
768
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
769
- outputs: BaseModelOutputWithPast = self.model(
770
- input_ids=input_ids,
771
- attention_mask=attention_mask,
772
- position_ids=position_ids,
773
- past_key_values=past_key_values,
774
- inputs_embeds=inputs_embeds,
775
- use_cache=use_cache,
776
- output_attentions=output_attentions,
777
- output_hidden_states=output_hidden_states,
778
- cache_position=cache_position,
779
- **kwargs,
780
- )
781
-
782
- hidden_states = outputs.last_hidden_state
783
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
784
- logits_to_keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
785
- hidden_states = hidden_states[:, logits_to_keep, :].contiguous()
786
- fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
787
- if fuse_linear_and_cross_entropy:
788
- logits = None
789
- else:
790
- logits = self.lm_head(hidden_states)
791
-
792
- loss = None
793
- if labels is not None:
794
- if fuse_linear_and_cross_entropy:
795
- loss_fct = FusedLinearDiffusionCrossEntropyLoss(
796
- reduction='sum')
797
- else:
798
- loss_fct = CrossEntropyLoss() # nn.CE
799
-
800
- # you don't have to shift labels
801
- # labels = labels.to(hidden_states.device)
802
- # labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
803
- if fuse_linear_and_cross_entropy:
804
- loss = loss_fct( # it will return (sum_loss, unreduced_loss)
805
- x=hidden_states, # conduct `view(-1, V)` inside the function
806
- target=labels,
807
- weight=self.lm_head.weight,
808
- bias=self.lm_head.bias,
809
- p_mask=kwargs['p_mask'],
810
- )
811
- else:
812
- loss = loss_fct(
813
- logits.view(-1, self.config.vocab_size), labels.view(-1))
814
-
815
- return CausalLMOutputWithPast(
816
- loss=loss,
817
- logits=logits,
818
- past_key_values=outputs.past_key_values,
819
- hidden_states=outputs.hidden_states,
820
- attentions=outputs.attentions,
821
- )
822
-
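During training (when `config.fuse_cross_entropy` is set) the loss path above calls `FusedLinearDiffusionCrossEntropyLoss` from `fla`, which fuses the `lm_head` projection with the per-token cross-entropy and applies the `p_mask` weighting used by masked-diffusion training objectives. A rough unfused sketch of what such a loss computes; the exact weighting and reduction are assumptions here, not taken from the fla implementation:

```python
import torch
import torch.nn.functional as F

def unfused_diffusion_ce(hidden, labels, lm_head_weight, p_mask, ignore_index=-100):
    # hidden: (B, L, D), labels: (B, L), p_mask: (B, L) per-token masking probability.
    logits = hidden @ lm_head_weight.T  # (B, L, V); the fused kernel avoids materializing this
    ce = F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1),
        reduction="none", ignore_index=ignore_index,
    )
    # Assumed re-weighting: upweight tokens that were masked with low probability.
    return (ce / p_mask.view(-1).clamp_min(1e-6)).sum()
```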
823
-
824
- @auto_docstring(
825
- custom_intro="""
826
- The Qwen3 Model transformer with a sequence classification head on top (linear layer).
827
-
828
- [`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
829
- (e.g. GPT-2) do.
830
-
831
- Since it does classification on the last token, it requires to know the position of the last token. If a
832
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
833
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
834
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
835
- each row of the batch).
836
- """
837
- )
838
- class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
839
- def __init__(self, config):
840
- super().__init__(config)
841
- self.num_labels = config.num_labels
842
- self.model = Qwen3Model(config)
843
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
844
-
845
- # Initialize weights and apply final processing
846
- self.post_init()
847
-
848
- def get_input_embeddings(self):
849
- return self.model.embed_tokens
850
-
851
- def set_input_embeddings(self, value):
852
- self.model.embed_tokens = value
853
-
854
- @can_return_tuple
855
- @auto_docstring
856
- def forward(
857
- self,
858
- input_ids: Optional[torch.LongTensor] = None,
859
- attention_mask: Optional[torch.Tensor] = None,
860
- position_ids: Optional[torch.LongTensor] = None,
861
- past_key_values: Optional[Cache] = None,
862
- inputs_embeds: Optional[torch.FloatTensor] = None,
863
- labels: Optional[torch.LongTensor] = None,
864
- use_cache: Optional[bool] = None,
865
- output_attentions: Optional[bool] = None,
866
- output_hidden_states: Optional[bool] = None,
867
- ) -> SequenceClassifierOutputWithPast:
868
- r"""
869
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
870
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
871
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
872
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
873
- """
874
-
875
- transformer_outputs: BaseModelOutputWithPast = self.model(
876
- input_ids,
877
- attention_mask=attention_mask,
878
- position_ids=position_ids,
879
- past_key_values=past_key_values,
880
- inputs_embeds=inputs_embeds,
881
- use_cache=use_cache,
882
- output_attentions=output_attentions,
883
- output_hidden_states=output_hidden_states,
884
- )
885
- hidden_states = transformer_outputs.last_hidden_state
886
- logits = self.score(hidden_states)
887
-
888
- if input_ids is not None:
889
- batch_size = input_ids.shape[0]
890
- else:
891
- batch_size = inputs_embeds.shape[0]
892
-
893
- if self.config.pad_token_id is None and batch_size != 1:
894
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
895
- if self.config.pad_token_id is None:
896
- last_non_pad_token = -1
897
- elif input_ids is not None:
898
- # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
899
- non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
900
- token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
901
- last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
902
- else:
903
- last_non_pad_token = -1
904
- logger.warning_once(
905
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
906
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
907
- )
908
-
909
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
910
-
911
- loss = None
912
- if labels is not None:
913
- loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
914
-
915
- return SequenceClassifierOutputWithPast(
916
- loss=loss,
917
- logits=pooled_logits,
918
- past_key_values=transformer_outputs.past_key_values,
919
- hidden_states=transformer_outputs.hidden_states,
920
- attentions=transformer_outputs.attentions,
921
- )
922
-
923
-
924
- @auto_docstring
925
- class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
926
- def __init__(self, config):
927
- super().__init__(config)
928
- self.num_labels = config.num_labels
929
- self.model = Qwen3Model(config)
930
- if getattr(config, "classifier_dropout", None) is not None:
931
- classifier_dropout = config.classifier_dropout
932
- elif getattr(config, "hidden_dropout", None) is not None:
933
- classifier_dropout = config.hidden_dropout
934
- else:
935
- classifier_dropout = 0.1
936
- self.dropout = nn.Dropout(classifier_dropout)
937
- self.score = nn.Linear(config.hidden_size, config.num_labels)
938
-
939
- # Initialize weights and apply final processing
940
- self.post_init()
941
-
942
- def get_input_embeddings(self):
943
- return self.model.embed_tokens
944
-
945
- def set_input_embeddings(self, value):
946
- self.model.embed_tokens = value
947
-
948
- @can_return_tuple
949
- @auto_docstring
950
- def forward(
951
- self,
952
- input_ids: Optional[torch.LongTensor] = None,
953
- attention_mask: Optional[torch.Tensor] = None,
954
- position_ids: Optional[torch.LongTensor] = None,
955
- past_key_values: Optional[Cache] = None,
956
- inputs_embeds: Optional[torch.FloatTensor] = None,
957
- labels: Optional[torch.LongTensor] = None,
958
- use_cache: Optional[bool] = None,
959
- output_attentions: Optional[bool] = None,
960
- output_hidden_states: Optional[bool] = None,
961
- ) -> TokenClassifierOutput:
962
- r"""
963
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
964
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
965
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
966
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
967
- """
968
-
969
- outputs: BaseModelOutputWithPast = self.model(
970
- input_ids,
971
- attention_mask=attention_mask,
972
- position_ids=position_ids,
973
- past_key_values=past_key_values,
974
- inputs_embeds=inputs_embeds,
975
- use_cache=use_cache,
976
- output_attentions=output_attentions,
977
- output_hidden_states=output_hidden_states,
978
- )
979
- sequence_output = outputs.last_hidden_state
980
- sequence_output = self.dropout(sequence_output)
981
- logits = self.score(sequence_output)
982
-
983
- loss = None
984
- if labels is not None:
985
- loss = self.loss_function(logits, labels, self.config)
986
-
987
- return TokenClassifierOutput(
988
- loss=loss,
989
- logits=logits,
990
- hidden_states=outputs.hidden_states,
991
- attentions=outputs.attentions,
992
- )
993
-
994
-
995
- @auto_docstring
996
- class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
997
- base_model_prefix = "transformer"
998
-
999
- def __init__(self, config):
1000
- super().__init__(config)
1001
- self.transformer = Qwen3Model(config)
1002
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
1003
-
1004
- # Initialize weights and apply final processing
1005
- self.post_init()
1006
-
1007
- def get_input_embeddings(self):
1008
- return self.transformer.embed_tokens
1009
-
1010
- def set_input_embeddings(self, value):
1011
- self.transformer.embed_tokens = value
1012
-
1013
- @can_return_tuple
1014
- @auto_docstring
1015
- def forward(
1016
- self,
1017
- input_ids: Optional[torch.LongTensor] = None,
1018
- attention_mask: Optional[torch.Tensor] = None,
1019
- position_ids: Optional[torch.LongTensor] = None,
1020
- past_key_values: Optional[Cache] = None,
1021
- inputs_embeds: Optional[torch.FloatTensor] = None,
1022
- start_positions: Optional[torch.LongTensor] = None,
1023
- end_positions: Optional[torch.LongTensor] = None,
1024
- output_attentions: Optional[bool] = None,
1025
- output_hidden_states: Optional[bool] = None,
1026
- **kwargs,
1027
- ) -> QuestionAnsweringModelOutput:
1028
- outputs: BaseModelOutputWithPast = self.transformer(
1029
- input_ids,
1030
- attention_mask=attention_mask,
1031
- position_ids=position_ids,
1032
- past_key_values=past_key_values,
1033
- inputs_embeds=inputs_embeds,
1034
- output_attentions=output_attentions,
1035
- output_hidden_states=output_hidden_states,
1036
- )
1037
-
1038
- sequence_output = outputs.last_hidden_state
1039
-
1040
- logits = self.qa_outputs(sequence_output)
1041
- start_logits, end_logits = logits.split(1, dim=-1)
1042
- start_logits = start_logits.squeeze(-1).contiguous()
1043
- end_logits = end_logits.squeeze(-1).contiguous()
1044
-
1045
- loss = None
1046
- if start_positions is not None and end_positions is not None:
1047
- loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1048
-
1049
- return QuestionAnsweringModelOutput(
1050
- loss=loss,
1051
- start_logits=start_logits,
1052
- end_logits=end_logits,
1053
- hidden_states=outputs.hidden_states,
1054
- attentions=outputs.attentions,
1055
- )
1056
-
1057
-
1058
- __all__ = [
1059
- "Qwen3ForCausalLM",
1060
- "Qwen3ForQuestionAnswering",
1061
- "Qwen3Model",
1062
- "Qwen3PreTrainedModel",
1063
- "Qwen3ForSequenceClassification",
1064
- "Qwen3ForTokenClassification",
1065
- ]
modeling_qwen3.py → modeling_sdar.py RENAMED
@@ -1,3 +1,5 @@
 
 
1
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
  # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
3
  # Do NOT edit this file manually as any edits will be overwritten by the generation of
@@ -42,7 +44,7 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
42
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
  from transformers.processing_utils import Unpack
44
  from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
45
- from .configuration_qwen3 import Qwen3Config
46
 
47
  from fla.modules.activations import swiglu_linear
48
  from fla.modules import (
@@ -81,10 +83,10 @@ def fused_flex_attention(query, key, value, attention_mask=None, **kwargs):
81
 
82
 
83
  @use_kernel_forward_from_hub("RMSNorm")
84
- class Qwen3RMSNorm(nn.Module):
85
  def __init__(self, hidden_size, eps=1e-6):
86
  """
87
- Qwen3RMSNorm is equivalent to T5LayerNorm
88
  """
89
  super().__init__()
90
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -107,7 +109,7 @@ class Qwen3RMSNorm(nn.Module):
107
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
108
 
109
 
110
- class Qwen3MLP(nn.Module):
111
  def __init__(self, config):
112
  super().__init__()
113
  self.config = config
@@ -205,10 +207,10 @@ def eager_attention_forward(
205
  return attn_output, attn_weights
206
 
207
 
208
- class Qwen3Attention(nn.Module):
209
  """Multi-headed attention from 'Attention Is All You Need' paper"""
210
 
211
- def __init__(self, config: Qwen3Config, layer_idx: int):
212
  super().__init__()
213
  self.config = config
214
  self.layer_idx = layer_idx
@@ -236,9 +238,9 @@ class Qwen3Attention(nn.Module):
236
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
237
  )
238
  # unlike olmo, only on the head dim!
239
- self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
240
  # thus post q_norm does not need reshape
241
- self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
242
  self.sliding_window = config.sliding_window
243
  if not (
244
  self.config.use_sliding_window
@@ -275,7 +277,8 @@ class Qwen3Attention(nn.Module):
275
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
276
  key_states, value_states = past_key_value.update(
277
  key_states, value_states, self.layer_idx)
278
- elif past_key_value is not None and not kwargs.get("store_kv", False) and len(past_key_value) > self.layer_idx:  # read the cache only, do not store
 
279
  past_key_states, past_value_states = past_key_value[self.layer_idx]
280
  key_states = torch.cat(
281
  [past_key_states, key_states], dim=-2
@@ -283,75 +286,9 @@ class Qwen3Attention(nn.Module):
283
  value_states = torch.cat(
284
  [past_value_states, value_states], dim=-2
285
  )
286
- # if past_key_value is not None:
287
- # # sin and cos are specific to RoPE models; cache_position needed for the static cache
288
- # cache_kwargs = {"sin": sin, "cos": cos,
289
- # "cache_position": cache_position}
290
- # key_states, value_states = past_key_value.update(
291
- # key_states, value_states, self.layer_idx, cache_kwargs)
292
-
293
- # attention_interface: Callable = eager_attention_forward
294
- # if self.config._attn_implementation != "eager":
295
- # if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
296
- # logger.warning_once(
297
- # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
298
- # 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
299
- # )
300
- # else:
301
- # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
302
-
303
- # if self.config._attn_implementation == 'flex_attention':
304
- # # Although `AttentionInterface` has `flex_attention_forward` implementation,
305
- # # we still use our customized `fused_flex_attention`
306
- # pad_length = kwargs.get("pad_length", None)
307
- # if pad_length is not None:
308
- # # Used for SFT (packing + varlen), seq_len changes at each step
309
- # # seq_len must be divisible by BLOCK_SIZE in flex attn
310
- # pad_q = torch.zeros(
311
- # bsz, self.num_attention_heads, pad_length, self.head_dim, device=query_states.device, dtype=query_states.dtype)
312
- # pad_kv = torch.zeros(
313
- # bsz, self.num_key_value_heads, pad_length, self.head_dim, device=query_states.device, dtype=query_states.dtype)
314
- # attn_output, attn_weights = fused_flex_attention(
315
- # query=torch.cat([query_states, pad_q], dim=2),
316
- # key=torch.cat([key_states, pad_kv], dim=2),
317
- # value=torch.cat([value_states, pad_kv], dim=2),
318
- # attention_mask=attention_mask,
319
- # enable_gqa=True,
320
- # scale=self.scaling,
321
- # return_lse=True
322
- # )
323
- # attn_output = attn_output[..., :q_len,
324
- # :].transpose(1, 2).contiguous()
325
- # attn_weights = attn_weights.to(value_states.dtype)
326
- # else:
327
- # attn_output, attn_weights = fused_flex_attention(
328
- # query=query_states,
329
- # key=key_states,
330
- # value=value_states,
331
- # attention_mask=attention_mask,
332
- # enable_gqa=True,
333
- # scale=self.scaling,
334
- # return_lse=True
335
- # )
336
- # attn_output = attn_output.transpose(1, 2).contiguous()
337
- # attn_weights = attn_weights.to(value_states.dtype)
338
- # else:
339
- # attn_output, attn_weights = attention_interface(
340
- # self,
341
- # query_states,
342
- # key_states,
343
- # value_states,
344
- # attention_mask,
345
- # dropout=0.0 if not self.training else self.attention_dropout,
346
- # scaling=self.scaling,
347
- # sliding_window=self.sliding_window, # diff with Llama
348
- # **kwargs,
349
- # )
350
- # q: (b, h, l, d); k,v: (b, h', l, d); attn_output: (b, l, h, d);
351
- # key_states = repeat_kv(key_states, 2)
352
- # value_states = repeat_kv(value_states, 2)
353
  attention_mask = attention_mask.bool() if attention_mask is not None else None
354
- if torch.all(attention_mask): # decoding phase
355
  query_states = query_states.transpose(1, 2)
356
  key_states = key_states.transpose(1, 2)
357
  value_states = value_states.transpose(1, 2)
@@ -362,7 +299,7 @@ class Qwen3Attention(nn.Module):
362
  causal=False,
363
  softmax_scale=self.scaling)
364
 
365
- else:
366
  attn_output = F.scaled_dot_product_attention(
367
  query=query_states,
368
  key=key_states,
@@ -378,15 +315,15 @@ class Qwen3Attention(nn.Module):
378
  return attn_output, None #, attn_weights
379
 
380
 
381
- class Qwen3DecoderLayer(GradientCheckpointingLayer):
382
- def __init__(self, config: Qwen3Config, layer_idx: int):
383
  super().__init__()
384
  self.hidden_size = config.hidden_size
385
- self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
386
- self.mlp = Qwen3MLP(config)
387
- self.input_layernorm = Qwen3RMSNorm(
388
  config.hidden_size, eps=config.rms_norm_eps)
389
- self.post_attention_layernorm = Qwen3RMSNorm(
390
  config.hidden_size, eps=config.rms_norm_eps)
391
  if (
392
  config.sliding_window and config._attn_implementation != "flash_attention_2"
@@ -443,11 +380,11 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer):
443
 
444
 
445
  @auto_docstring
446
- class Qwen3PreTrainedModel(PreTrainedModel):
447
- config_class = Qwen3Config
448
  base_model_prefix = "model"
449
  supports_gradient_checkpointing = True
450
- _no_split_modules = ["Qwen3DecoderLayer"]
451
  _skip_keys_device_placement = ["past_key_values"]
452
  _supports_flash_attn_2 = True
453
  _supports_sdpa = True
@@ -467,12 +404,12 @@ class Qwen3PreTrainedModel(PreTrainedModel):
467
  module.weight.data.normal_(mean=0.0, std=std)
468
  if module.padding_idx is not None:
469
  module.weight.data[module.padding_idx].zero_()
470
- elif isinstance(module, Qwen3RMSNorm):
471
  module.weight.data.fill_(1.0)
472
 
473
 
474
- class Qwen3RotaryEmbedding(nn.Module):
475
- def __init__(self, config: Qwen3Config, device=None):
476
  super().__init__()
477
  # BC: "rope_type" was originally "type"
478
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -512,8 +449,8 @@ class Qwen3RotaryEmbedding(nn.Module):
512
 
513
 
514
  @auto_docstring
515
- class Qwen3Model(Qwen3PreTrainedModel):
516
- def __init__(self, config: Qwen3Config):
517
  super().__init__(config)
518
  self.padding_idx = config.pad_token_id
519
  self.vocab_size = config.vocab_size
@@ -521,11 +458,11 @@ class Qwen3Model(Qwen3PreTrainedModel):
521
  self.embed_tokens = nn.Embedding(
522
  config.vocab_size, config.hidden_size, self.padding_idx)
523
  self.layers = nn.ModuleList(
524
- [Qwen3DecoderLayer(config, layer_idx)
525
  for layer_idx in range(config.num_hidden_layers)]
526
  )
527
- self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
528
- self.rotary_emb = Qwen3RotaryEmbedding(config=config)
529
  self.gradient_checkpointing = False
530
 
531
  # Initialize weights and apply final processing
@@ -745,7 +682,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
745
  dtype: torch.dtype,
746
  cache_position: torch.Tensor,
747
  batch_size: int,
748
- config: Qwen3Config,
749
  past_key_values: Cache,
750
  ):
751
  """
@@ -765,7 +702,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
765
  Indices depicting the position of the input sequence tokens in the sequence.
766
  batch_size (`torch.Tensor`):
767
  Batch size.
768
- config (`Qwen3Config`):
769
  The model's configuration class
770
  past_key_values (`Cache`):
771
  The cache class that is being used currently to generate
@@ -814,14 +751,14 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
814
 
815
 
816
  @auto_docstring
817
- class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
818
  _tied_weights_keys = ["lm_head.weight"]
819
  _tp_plan = {"lm_head": "colwise_rep"}
820
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
821
 
822
  def __init__(self, config):
823
  super().__init__(config)
824
- self.model = Qwen3Model(config)
825
  self.vocab_size = config.vocab_size
826
  self.lm_head = nn.Linear(
827
  config.hidden_size, config.vocab_size, bias=False)
@@ -873,10 +810,10 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
873
  Example:
874
 
875
  ```python
876
- >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
877
 
878
- >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
879
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
880
 
881
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
882
  >>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -958,251 +895,8 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
958
  )
959
 
960
 
961
- @auto_docstring(
962
- custom_intro="""
963
- The Qwen3 Model transformer with a sequence classification head on top (linear layer).
964
-
965
- [`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
966
- (e.g. GPT-2) do.
967
-
968
- Since it does classification on the last token, it requires to know the position of the last token. If a
969
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
970
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
971
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
972
- each row of the batch).
973
- """
974
- )
975
- class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
976
- def __init__(self, config):
977
- super().__init__(config)
978
- self.num_labels = config.num_labels
979
- self.model = Qwen3Model(config)
980
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
981
-
982
- # Initialize weights and apply final processing
983
- self.post_init()
984
-
985
- def get_input_embeddings(self):
986
- return self.model.embed_tokens
987
-
988
- def set_input_embeddings(self, value):
989
- self.model.embed_tokens = value
990
-
991
- @can_return_tuple
992
- @auto_docstring
993
- def forward(
994
- self,
995
- input_ids: Optional[torch.LongTensor] = None,
996
- attention_mask: Optional[torch.Tensor] = None,
997
- position_ids: Optional[torch.LongTensor] = None,
998
- past_key_values: Optional[Cache] = None,
999
- inputs_embeds: Optional[torch.FloatTensor] = None,
1000
- labels: Optional[torch.LongTensor] = None,
1001
- use_cache: Optional[bool] = None,
1002
- output_attentions: Optional[bool] = None,
1003
- output_hidden_states: Optional[bool] = None,
1004
- ) -> SequenceClassifierOutputWithPast:
1005
- r"""
1006
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1007
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1008
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1009
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1010
- """
1011
-
1012
- transformer_outputs: BaseModelOutputWithPast = self.model(
1013
- input_ids,
1014
- attention_mask=attention_mask,
1015
- position_ids=position_ids,
1016
- past_key_values=past_key_values,
1017
- inputs_embeds=inputs_embeds,
1018
- use_cache=use_cache,
1019
- output_attentions=output_attentions,
1020
- output_hidden_states=output_hidden_states,
1021
- )
1022
- hidden_states = transformer_outputs.last_hidden_state
1023
- logits = self.score(hidden_states)
1024
-
1025
- if input_ids is not None:
1026
- batch_size = input_ids.shape[0]
1027
- else:
1028
- batch_size = inputs_embeds.shape[0]
1029
-
1030
- if self.config.pad_token_id is None and batch_size != 1:
1031
- raise ValueError(
1032
- "Cannot handle batch sizes > 1 if no padding token is defined.")
1033
- if self.config.pad_token_id is None:
1034
- last_non_pad_token = -1
1035
- elif input_ids is not None:
1036
- # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1037
- non_pad_mask = (input_ids != self.config.pad_token_id).to(
1038
- logits.device, torch.int32)
1039
- token_indices = torch.arange(
1040
- input_ids.shape[-1], device=logits.device, dtype=torch.int32)
1041
- last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1042
- else:
1043
- last_non_pad_token = -1
1044
- logger.warning_once(
1045
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1046
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1047
- )
1048
-
1049
- pooled_logits = logits[torch.arange(
1050
- batch_size, device=logits.device), last_non_pad_token]
1051
-
1052
- loss = None
1053
- if labels is not None:
1054
- loss = self.loss_function(
1055
- logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1056
-
1057
- return SequenceClassifierOutputWithPast(
1058
- loss=loss,
1059
- logits=pooled_logits,
1060
- past_key_values=transformer_outputs.past_key_values,
1061
- hidden_states=transformer_outputs.hidden_states,
1062
- attentions=transformer_outputs.attentions,
1063
- )
1064
-
1065
-
1066
- @auto_docstring
1067
- class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
1068
- def __init__(self, config):
1069
- super().__init__(config)
1070
- self.num_labels = config.num_labels
1071
- self.model = Qwen3Model(config)
1072
- if getattr(config, "classifier_dropout", None) is not None:
1073
- classifier_dropout = config.classifier_dropout
1074
- elif getattr(config, "hidden_dropout", None) is not None:
1075
- classifier_dropout = config.hidden_dropout
1076
- else:
1077
- classifier_dropout = 0.1
1078
- self.dropout = nn.Dropout(classifier_dropout)
1079
- self.score = nn.Linear(config.hidden_size, config.num_labels)
1080
-
1081
- # Initialize weights and apply final processing
1082
- self.post_init()
1083
-
1084
- def get_input_embeddings(self):
1085
- return self.model.embed_tokens
1086
-
1087
- def set_input_embeddings(self, value):
1088
- self.model.embed_tokens = value
1089
-
1090
- @can_return_tuple
1091
- @auto_docstring
1092
- def forward(
1093
- self,
1094
- input_ids: Optional[torch.LongTensor] = None,
1095
- attention_mask: Optional[torch.Tensor] = None,
1096
- position_ids: Optional[torch.LongTensor] = None,
1097
- past_key_values: Optional[Cache] = None,
1098
- inputs_embeds: Optional[torch.FloatTensor] = None,
1099
- labels: Optional[torch.LongTensor] = None,
1100
- use_cache: Optional[bool] = None,
1101
- output_attentions: Optional[bool] = None,
1102
- output_hidden_states: Optional[bool] = None,
1103
- ) -> TokenClassifierOutput:
1104
- r"""
1105
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1106
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1107
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1108
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1109
- """
1110
-
1111
- outputs: BaseModelOutputWithPast = self.model(
1112
- input_ids,
1113
- attention_mask=attention_mask,
1114
- position_ids=position_ids,
1115
- past_key_values=past_key_values,
1116
- inputs_embeds=inputs_embeds,
1117
- use_cache=use_cache,
1118
- output_attentions=output_attentions,
1119
- output_hidden_states=output_hidden_states,
1120
- )
1121
- sequence_output = outputs.last_hidden_state
1122
- sequence_output = self.dropout(sequence_output)
1123
- logits = self.score(sequence_output)
1124
-
1125
- loss = None
1126
- if labels is not None:
1127
- loss = self.loss_function(logits, labels, self.config)
1128
-
1129
- return TokenClassifierOutput(
1130
- loss=loss,
1131
- logits=logits,
1132
- hidden_states=outputs.hidden_states,
1133
- attentions=outputs.attentions,
1134
- )
1135
-
1136
-
1137
- @auto_docstring
1138
- class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
1139
- base_model_prefix = "transformer"
1140
-
1141
- def __init__(self, config):
1142
- super().__init__(config)
1143
- self.transformer = Qwen3Model(config)
1144
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
1145
-
1146
- # Initialize weights and apply final processing
1147
- self.post_init()
1148
-
1149
- def get_input_embeddings(self):
1150
- return self.transformer.embed_tokens
1151
-
1152
- def set_input_embeddings(self, value):
1153
- self.transformer.embed_tokens = value
1154
-
1155
- @can_return_tuple
1156
- @auto_docstring
1157
- def forward(
1158
- self,
1159
- input_ids: Optional[torch.LongTensor] = None,
1160
- attention_mask: Optional[torch.Tensor] = None,
1161
- position_ids: Optional[torch.LongTensor] = None,
1162
- past_key_values: Optional[Cache] = None,
1163
- inputs_embeds: Optional[torch.FloatTensor] = None,
1164
- start_positions: Optional[torch.LongTensor] = None,
1165
- end_positions: Optional[torch.LongTensor] = None,
1166
- output_attentions: Optional[bool] = None,
1167
- output_hidden_states: Optional[bool] = None,
1168
- **kwargs,
1169
- ) -> QuestionAnsweringModelOutput:
1170
- outputs: BaseModelOutputWithPast = self.transformer(
1171
- input_ids,
1172
- attention_mask=attention_mask,
1173
- position_ids=position_ids,
1174
- past_key_values=past_key_values,
1175
- inputs_embeds=inputs_embeds,
1176
- output_attentions=output_attentions,
1177
- output_hidden_states=output_hidden_states,
1178
- )
1179
-
1180
- sequence_output = outputs.last_hidden_state
1181
-
1182
- logits = self.qa_outputs(sequence_output)
1183
- start_logits, end_logits = logits.split(1, dim=-1)
1184
- start_logits = start_logits.squeeze(-1).contiguous()
1185
- end_logits = end_logits.squeeze(-1).contiguous()
1186
-
1187
- loss = None
1188
- if start_positions is not None and end_positions is not None:
1189
- loss = self.loss_function(
1190
- start_logits, end_logits, start_positions, end_positions, **kwargs)
1191
-
1192
- return QuestionAnsweringModelOutput(
1193
- loss=loss,
1194
- start_logits=start_logits,
1195
- end_logits=end_logits,
1196
- hidden_states=outputs.hidden_states,
1197
- attentions=outputs.attentions,
1198
- )
1199
-
1200
-
1201
  __all__ = [
1202
- "Qwen3ForCausalLM",
1203
- "Qwen3ForQuestionAnswering",
1204
- "Qwen3Model",
1205
- "Qwen3PreTrainedModel",
1206
- "Qwen3ForSequenceClassification",
1207
- "Qwen3ForTokenClassification",
1208
  ]
 
1
+ # This file is modified based on https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3/modeling_qwen3.py.
2
+ #
3
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
4
  # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
5
  # Do NOT edit this file manually as any edits will be overwritten by the generation of
 
44
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
  from transformers.processing_utils import Unpack
46
  from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
47
+ from .configuration_sdar import SDARConfig
48
 
49
  from fla.modules.activations import swiglu_linear
50
  from fla.modules import (
 
83
 
84
 
85
  @use_kernel_forward_from_hub("RMSNorm")
86
+ class SDARRMSNorm(nn.Module):
87
  def __init__(self, hidden_size, eps=1e-6):
88
  """
89
+ SDARRMSNorm is equivalent to T5LayerNorm
90
  """
91
  super().__init__()
92
  self.weight = nn.Parameter(torch.ones(hidden_size))
 
109
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
110
 
111
 
112
+ class SDARMLP(nn.Module):
113
  def __init__(self, config):
114
  super().__init__()
115
  self.config = config
 
207
  return attn_output, attn_weights
208
 
209
 
210
+ class SDARAttention(nn.Module):
211
  """Multi-headed attention from 'Attention Is All You Need' paper"""
212
 
213
+ def __init__(self, config: SDARConfig, layer_idx: int):
214
  super().__init__()
215
  self.config = config
216
  self.layer_idx = layer_idx
 
238
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
239
  )
240
  # unlike olmo, only on the head dim!
241
+ self.q_norm = SDARRMSNorm(self.head_dim, eps=config.rms_norm_eps)
242
  # thus post q_norm does not need reshape
243
+ self.k_norm = SDARRMSNorm(self.head_dim, eps=config.rms_norm_eps)
244
  self.sliding_window = config.sliding_window
245
  if not (
246
  self.config.use_sliding_window
 
277
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
278
  key_states, value_states = past_key_value.update(
279
  key_states, value_states, self.layer_idx)
280
+ elif past_key_value is not None and not kwargs.get("store_kv", False) and len(past_key_value) > self.layer_idx:
281
+ # only retrieve, do not store KV
282
  past_key_states, past_value_states = past_key_value[self.layer_idx]
283
  key_states = torch.cat(
284
  [past_key_states, key_states], dim=-2
 
286
  value_states = torch.cat(
287
  [past_value_states, value_states], dim=-2
288
  )
289
+
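This hunk is one of the two behavioural changes in the rename: when the `store_kv` kwarg is not set and the cache already holds entries for this layer, the cached keys/values are read and concatenated with the fresh states, but nothing is written back. Below is a minimal, PyTorch-only sketch of that read-only merge (my own illustration, not code from this file; a plain indexable list stands in for the real cache object, and `read_only_merge` is an invented name):

```python
import torch

def read_only_merge(past_key_value, layer_idx, k_new, v_new):
    """Concatenate cached K/V with fresh K/V without updating the cache.

    past_key_value acts like a per-layer sequence of (k, v) tensors,
    each shaped (batch, kv_heads, seq_len, head_dim).
    """
    if past_key_value is not None and layer_idx < len(past_key_value):
        k_old, v_old = past_key_value[layer_idx]  # read only, never stored back
        k_new = torch.cat([k_old, k_new], dim=-2)
        v_new = torch.cat([v_old, v_new], dim=-2)
    return k_new, v_new
```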
290
  attention_mask = attention_mask.bool() if attention_mask is not None else None
291
+ if torch.all(attention_mask): # decoding
292
  query_states = query_states.transpose(1, 2)
293
  key_states = key_states.transpose(1, 2)
294
  value_states = value_states.transpose(1, 2)
 
299
  causal=False,
300
  softmax_scale=self.scaling)
301
 
302
+ else: # prefilling
303
  attn_output = F.scaled_dot_product_attention(
304
  query=query_states,
305
  key=key_states,
 
315
  return attn_output, None #, attn_weights
316
 
317
 
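The other behavioural change is the attention dispatch that replaces the deleted commented-out block: a fully-True attention mask (the decoding case) is routed through `flash_attn_func` with `causal=False`, while anything else (prefilling) goes through `F.scaled_dot_product_attention` with the supplied, non-causal mask. A rough sketch of that branch follows; it is my illustration, with SDPA standing in for the flash-attn call so it runs without the `flash_attn` package, and the function name is invented:

```python
import torch
import torch.nn.functional as F

def dispatch_attention(q, k, v, attention_mask):
    # q, k, v: (batch, heads, seq_len, head_dim) with matching head counts
    mask = attention_mask.bool() if attention_mask is not None else None
    if mask is None or torch.all(mask):
        # decoding: every position is visible, so no mask is needed;
        # the real forward calls flash_attn_func(..., causal=False) here
        return F.scaled_dot_product_attention(q, k, v, is_causal=False)
    # prefilling: keep the caller's (non-causal) mask
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
```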
318
+ class SDARDecoderLayer(GradientCheckpointingLayer):
319
+ def __init__(self, config: SDARConfig, layer_idx: int):
320
  super().__init__()
321
  self.hidden_size = config.hidden_size
322
+ self.self_attn = SDARAttention(config=config, layer_idx=layer_idx)
323
+ self.mlp = SDARMLP(config)
324
+ self.input_layernorm = SDARRMSNorm(
325
  config.hidden_size, eps=config.rms_norm_eps)
326
+ self.post_attention_layernorm = SDARRMSNorm(
327
  config.hidden_size, eps=config.rms_norm_eps)
328
  if (
329
  config.sliding_window and config._attn_implementation != "flash_attention_2"
 
380
 
381
 
382
  @auto_docstring
383
+ class SDARPreTrainedModel(PreTrainedModel):
384
+ config_class = SDARConfig
385
  base_model_prefix = "model"
386
  supports_gradient_checkpointing = True
387
+ _no_split_modules = ["SDARDecoderLayer"]
388
  _skip_keys_device_placement = ["past_key_values"]
389
  _supports_flash_attn_2 = True
390
  _supports_sdpa = True
 
404
  module.weight.data.normal_(mean=0.0, std=std)
405
  if module.padding_idx is not None:
406
  module.weight.data[module.padding_idx].zero_()
407
+ elif isinstance(module, SDARRMSNorm):
408
  module.weight.data.fill_(1.0)
409
 
410
 
411
+ class SDARRotaryEmbedding(nn.Module):
412
+ def __init__(self, config: SDARConfig, device=None):
413
  super().__init__()
414
  # BC: "rope_type" was originally "type"
415
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 
449
 
450
 
451
  @auto_docstring
452
+ class SDARModel(SDARPreTrainedModel):
453
+ def __init__(self, config: SDARConfig):
454
  super().__init__(config)
455
  self.padding_idx = config.pad_token_id
456
  self.vocab_size = config.vocab_size
 
458
  self.embed_tokens = nn.Embedding(
459
  config.vocab_size, config.hidden_size, self.padding_idx)
460
  self.layers = nn.ModuleList(
461
+ [SDARDecoderLayer(config, layer_idx)
462
  for layer_idx in range(config.num_hidden_layers)]
463
  )
464
+ self.norm = SDARRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
465
+ self.rotary_emb = SDARRotaryEmbedding(config=config)
466
  self.gradient_checkpointing = False
467
 
468
  # Initialize weights and apply final processing
 
682
  dtype: torch.dtype,
683
  cache_position: torch.Tensor,
684
  batch_size: int,
685
+ config: SDARConfig,
686
  past_key_values: Cache,
687
  ):
688
  """
 
702
  Indices depicting the position of the input sequence tokens in the sequence.
703
  batch_size (`torch.Tensor`):
704
  Batch size.
705
+ config (`SDARConfig`):
706
  The model's configuration class
707
  past_key_values (`Cache`):
708
  The cache class that is being used currently to generate
 
751
 
752
 
753
  @auto_docstring
754
+ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
755
  _tied_weights_keys = ["lm_head.weight"]
756
  _tp_plan = {"lm_head": "colwise_rep"}
757
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
758
 
759
  def __init__(self, config):
760
  super().__init__(config)
761
+ self.model = SDARModel(config)
762
  self.vocab_size = config.vocab_size
763
  self.lm_head = nn.Linear(
764
  config.hidden_size, config.vocab_size, bias=False)
 
810
  Example:
811
 
812
  ```python
813
+ >>> from transformers import AutoTokenizer, SDARForCausalLM
814
 
815
+ >>> model = SDARForCausalLM.from_pretrained("DiffuOpen/SDAR-1.7B-Chat")
816
+ >>> tokenizer = AutoTokenizer.from_pretrained("DiffuOpen/SDAR-1.7B-Chat")
817
 
818
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
819
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
895
  )
896
 
897

898
  __all__ = [
899
+ "SDARForCausalLM",
900
+ "SDARModel",
901
+ "SDARPreTrainedModel",
902
  ]
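Because `SDARForCausalLM` is defined in this repository's `modeling_sdar.py` (note the relative `from .configuration_sdar import SDARConfig` import) rather than inside the `transformers` package, the docstring's `from transformers import SDARForCausalLM` only resolves when the file is loaded as remote code; the Auto classes with `trust_remote_code=True` are the usual entry point. A minimal usage sketch, with the checkpoint id taken from the docstring example above (an assumption about the published repo) and with `fla`/`flash-attn` assumed installed since the file imports their kernels:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "DiffuOpen/SDAR-1.7B-Chat"  # id quoted in the diff's docstring example
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# A plain forward pass through SDARForCausalLM; any custom decoding loop lives
# outside this modeling file.
logits = model(**inputs).logits
print(logits.shape)  # (batch, seq_len, vocab_size)
```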