Schrieffer2sy committed
Commit ec9b1de · 1 Parent(s): 6e891fa
Files changed (3)
  1. app.py +3 -2
  2. sae.py +102 -0
  3. sarm_llama.py +649 -0
app.py CHANGED
@@ -1,6 +1,7 @@
import gradio as gr
import torch
- from transformers import AutoModel, AutoTokenizer
+ from transformers import AutoTokenizer
+ from sarm_llama import LlamaSARM

# --- 1. Load the model and tokenizer ---
# This step automatically downloads your model files from the Hugging Face Hub
@@ -12,7 +13,7 @@ MODEL_ID = "schrieffer/SARM-4B"
print(f"Loading model: {MODEL_ID} on {DEVICE}...")

# Loading the model requires trust_remote_code because SARM uses a custom architecture
- model = AutoModel.from_pretrained(
+ model = LlamaSARM.from_pretrained(
    MODEL_ID,
    device_map=DEVICE,
    trust_remote_code=True,
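
Usage note (illustrative, not part of this commit): a minimal sketch of how app.py's new imports might be used end to end, assuming the rest of app.py follows the usual reward-model pattern. The example conversation and the DEVICE constant are assumptions; the from_pretrained call is abbreviated exactly as the hunk shows it and may take further keyword arguments (for example the SAE hyperparameters that LlamaSARM.__init__ requires) that are truncated out of this hunk. The scalar-reward read-out relies on LlamaSARM.forward (sarm_llama.py below) returning a SequenceClassifierOutputWithPast whose logits are pooled at the final token, and on the repo config using num_labels = 1.

import torch
from transformers import AutoTokenizer
from sarm_llama import LlamaSARM

MODEL_ID = "schrieffer/SARM-4B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = LlamaSARM.from_pretrained(
    MODEL_ID,
    device_map=DEVICE,
    trust_remote_code=True,
    # The hunk above is truncated here; app.py may pass additional kwargs
    # (e.g. torch_dtype or the SAE hyperparameters LlamaSARM.__init__ expects).
)
model.eval()

messages = [
    {"role": "user", "content": "Explain what a sparse autoencoder is."},
    {"role": "assistant", "content": "A sparse autoencoder re-expresses activations using a few active latents..."},
]
# The Llama-3 chat template ends the assistant turn with <|eot_id|>, which
# LlamaSARM.forward asserts is the final token before pooling the score there.
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)

with torch.no_grad():
    reward = model(input_ids=input_ids).logits.squeeze().item()
print(f"Reward: {reward:.4f}")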
sae.py ADDED
@@ -0,0 +1,102 @@
+ import torch
+ from torch import nn
+
+ def get_last_assistant_masks(input_ids):
+     i=len(input_ids)-4
+     while i >= 0:
+         if input_ids[i:i+4] == [128006, 78191, 128007, 271]:
+             pos = i + 4
+             break
+         i -= 1
+
+     assistant_masks = []
+     for i in range(len(input_ids)):
+         if i < pos:
+             assistant_masks.append(0)
+         else:
+             assistant_masks.append(1)
+
+     assert input_ids[-1]==128009
+     return assistant_masks
+
+ def Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
+     return (((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)).mean()
+
+ def Masked_Normalized_MSE_loss(x: torch.Tensor, x_hat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+     mask = mask.to(torch.bfloat16)
+     loss = ((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)
+     assert loss.shape==mask.shape
+     seq_loss = (mask * loss).sum(-1) / (mask.sum(-1))
+     return seq_loss.mean()
+
+ def pre_process(hidden_stats: torch.Tensor, eps: float = 1e-6) -> tuple:
+     '''
+     :param hidden_stats: Hidden states (shape: [batch, max_length, hidden_size]).
+     :param eps: Epsilon value for numerical stability.
+     '''
+     mean = hidden_stats.mean(dim=-1, keepdim=True)
+     std = hidden_stats.std(dim=-1, keepdim=True)
+     x = (hidden_stats - mean) / (std + eps)
+     return x, mean, std
+
+ class TopkSAE(nn.Module):
+     '''
+     TopK Sparse Autoencoder. Implements:
+         z = TopK(encoder(x - pre_bias) + latent_bias)
+         x_hat = decoder(z) + pre_bias
+     '''
+     def __init__(
+         self, hidden_size: int, latent_size: int, k: int
+     ) -> None:
+         '''
+         :param hidden_size: Dimensionality of the input residual stream activation.
+         :param latent_size: Number of latent units.
+         :param k: Number of activated latents.
+         '''
+
+         # 'sae_pre_bias', 'sae_latent_bias', 'sae_encoder.weight', 'sae_decoder.weight'
+
+         assert k <= latent_size, f'k should be less than or equal to {latent_size}'
+         super(TopkSAE, self).__init__()
+         self.pre_bias = nn.Parameter(torch.zeros(hidden_size))
+         self.latent_bias = nn.Parameter(torch.zeros(latent_size))
+         self.encoder = nn.Linear(hidden_size, latent_size, bias=False)
+         self.decoder = nn.Linear(latent_size, hidden_size, bias=False)
+
+         self.k = k
+         self.latent_size = latent_size
+         self.hidden_size = hidden_size
+
+         # "tied" init
+         # self.decoder.weight.data = self.encoder.weight.data.T.clone()
+
+     def pre_acts(self, x: torch.Tensor) -> torch.Tensor:
+         x = x - self.pre_bias
+         return self.encoder(x) + self.latent_bias
+
+     def get_latents(self, pre_acts: torch.Tensor) -> torch.Tensor:
+         topk = torch.topk(pre_acts, self.k, dim=-1)
+         latents = torch.zeros_like(pre_acts)
+         latents.scatter_(-1, topk.indices, topk.values)
+         return latents
+
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         pre_acts = self.pre_acts(x)
+         latents = self.get_latents(pre_acts)
+         return latents
+
+     def decode(self, latents: torch.Tensor) -> torch.Tensor:
+         return self.decoder(latents) + self.pre_bias
+
+     def forward(self, x: torch.Tensor) -> tuple:
+         '''
+         :param x: Input residual stream activation (shape: [batch_size, max_length, hidden_size]).
+         :return: latents (shape: [batch_size, max_length, latent_size]).
+                  x_hat (shape: [batch_size, max_length, hidden_size]).
+         '''
+         latents = self.encode(x)
+         x_hat = self.decode(latents)
+         return latents, x_hat
+
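
Usage note (illustrative, not part of this commit): a small self-contained check of sae.py. The sizes below are placeholders, not the SAE dimensions actually used for SARM-4B, and the toy token list simply reuses the header IDs that get_last_assistant_masks hard-codes.

import torch
from sae import TopkSAE, pre_process, Normalized_MSE_loss, get_last_assistant_masks

sae = TopkSAE(hidden_size=64, latent_size=512, k=8)   # placeholder sizes

hidden_states = torch.randn(2, 10, 64)                # [batch, seq_len, hidden_size]
x, mean, std = pre_process(hidden_states)             # per-token standardization, as in LlamaSARM.forward
latents, x_hat = sae(x)                               # TopK-sparse codes and reconstruction
assert (latents != 0).sum(dim=-1).max() <= sae.k      # at most k latents active per position
print("reconstruction NMSE:", Normalized_MSE_loss(x, x_hat).item())

# Tokens after the last assistant header (the IDs hard-coded in the function) are marked 1.
toy_ids = [128006, 78191, 128007, 271, 11, 22, 128009]
print(get_last_assistant_masks(toy_ids))              # -> [0, 0, 0, 0, 1, 1, 1]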
sarm_llama.py ADDED
@@ -0,0 +1,649 @@
+ import torch
+ import torch.nn as nn
+
+ from typing import List, Optional, Union, Tuple
+ from transformers import LlamaConfig
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
+ from transformers.utils import logging
+ from transformers.modeling_outputs import (
+     SequenceClassifierOutputWithPast,
+     BaseModelOutputWithPast
+ )
+ from transformers.models.llama.modeling_llama import (
+     LlamaDecoderLayer,
+     LlamaRMSNorm,
+     LlamaRotaryEmbedding,
+     LlamaPreTrainedModel
+ )
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ # Local
+ from sae import TopkSAE, pre_process, Normalized_MSE_loss, Masked_Normalized_MSE_loss
+
+
+ logger = logging.get_logger(__name__)
+ #==========================================================================================================================================================================
+ #==========================================================================================================================================================================
+ class MyLlamaModel(LlamaPreTrainedModel):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         hidden_state_source_layer: int=None
+     ):
+         if hidden_state_source_layer==None:
+             # default 1/2
+             hidden_state_source_layer = int(config.num_hidden_layers/2)
+
+         super().__init__(config)
+         self.hidden_state_source_layer = hidden_state_source_layer
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList(
+             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(hidden_state_source_layer)]
+         )
+         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = LlamaRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+         if getattr(config, "pretraining_tp", 1) != 1:
+             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         # kept for BC (non `Cache` `past_key_values` inputs)
+         return_legacy_cache = False
+         if use_cache and not isinstance(past_key_values, Cache):
+             return_legacy_cache = True
+             if past_key_values is None:
+                 past_key_values = DynamicCache()
+             else:
+                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                 logger.warning_once(
+                     "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                     "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                     "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                 )
+
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         causal_mask = self._update_causal_mask(
+             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+         )
+         hidden_states = inputs_embeds
+
+         # create position embeddings to be shared across the decoder layers
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     cache_position,
+                     position_embeddings,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cache_position=cache_position,
+                     position_embeddings=position_embeddings,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         # hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+         if return_legacy_cache:
+             next_cache = next_cache.to_legacy_cache()
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool,
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and 0.0 in attention_mask:
+                 return attention_mask
+             return None
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+         using_static_cache = isinstance(past_key_values, StaticCache)
+
+         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+         if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         sequence_length = input_tensor.shape[1]
+         if using_static_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             device=device,
+             cache_position=cache_position,
+             batch_size=input_tensor.shape[0],
+         )
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+         return causal_mask
+
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         """
+         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+         `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+         Args:
+             attention_mask (`torch.Tensor`):
+                 A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                 `(batch_size, 1, query_length, key_value_length)`.
+             sequence_length (`int`):
+                 The sequence length being processed.
+             target_length (`int`):
+                 The target length: when generating with static cache, the mask should be as long as the static cache,
+                 to account for the 0 padding, the part of the cache that is not filled yet.
+             dtype (`torch.dtype`):
+                 The dtype to use for the 4D attention mask.
+             device (`torch.device`):
+                 The device to place the 4D attention mask on.
+             cache_position (`torch.Tensor`):
+                 Indices depicting the position of the input sequence tokens in the sequence.
+             batch_size (`torch.Tensor`):
+                 Batch size.
+         """
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                     padding_mask, min_dtype
+                 )
+
+         return causal_mask
+
+
+ #==========================================================================================================================================================================
+ #============================================ Adapted from LlamaForSequenceClassification into the SAE-based reward model (SARM / "SAE4RM") form =============================================
+ #==========================================================================================================================================================================
+
+
+ class LlamaSARM(LlamaPreTrainedModel):
+     def __init__(
+         self, config, sae_hidden_state_source_layer, sae_latent_size, sae_k,
+         sae_use_sequence_level=False,
+         sarm_use_topk=False,
+         sarm_train_mode=1
+     ):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = MyLlamaModel(config, hidden_state_source_layer=sae_hidden_state_source_layer)
+
+         self.sae_use_sequence_level = sae_use_sequence_level
+         self.sarm_use_topk = sarm_use_topk
+         self.sarm_train_mode = sarm_train_mode
+
+         self.score = nn.Linear(sae_latent_size, self.num_labels, bias=False)
+         self.sae = TopkSAE(hidden_size=self.model.config.hidden_size, latent_size=sae_latent_size, k=sae_k)
+
+         if self.sarm_train_mode==0:
+             for p in self.model.parameters():
+                 p.requires_grad_(False)
+         if self.sarm_train_mode==0 or self.sarm_train_mode==1:
+             for p in self.sae.parameters():
+                 p.requires_grad_(False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         assistant_masks: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+
+         h, _, _ = pre_process(hidden_states)
+         sae_features = self.sae.pre_acts(h)
+         if self.sarm_use_topk:
+             sae_features = self.sae.get_latents(sae_features)
+
+         logits = self.score(sae_features)
+
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+         # ensure last_token is <|eot_id|>
+         assert ((input_ids[torch.arange(batch_size, device=logits.device), sequence_lengths]!=torch.ones(batch_size, device=logits.device)*128009).sum() == 0).item()
+
+         # joint training
+         rec_loss = None
+         if self.sarm_train_mode==2:
+             if not self.sarm_use_topk:
+                 sae_features_t = self.sae.get_latents(sae_features)
+             h_hat = self.sae.decode(sae_features_t)
+             rec_loss = Masked_Normalized_MSE_loss(h, h_hat, assistant_masks)
+         elif self.sarm_train_mode==3 and not self.sae_use_sequence_level:
+             h_d = h.detach()
+             _, h_hat = self.sae(h_d)
+             rec_loss = Masked_Normalized_MSE_loss(h_d, h_hat, assistant_masks)
+         elif self.sarm_train_mode==3 and self.sae_use_sequence_level:
+             h_d = h.detach()
+             sequence_lengths_t = sequence_lengths.view(-1,1,1)
+             last_token_mask = torch.zeros([h_d.shape[0] ,1 ,h_d.shape[1]], device=h_d.device)
+             last_token_mask.scatter_(-1, sequence_lengths_t, torch.ones_like(sequence_lengths_t, dtype=last_token_mask.dtype))
+
+             # h_d -> (bs, seq_len, d), last_token_mask -> (bs, 1, seq_len)
+             h_d = torch.matmul(last_token_mask.to(h_d.dtype), h_d)
+
+             _, h_hat = self.sae(h_d)
+             rec_loss = Normalized_MSE_loss(h_d, h_hat)
+
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+             if rec_loss is not None:
+                 loss = rec_loss
+
+         if not return_dict:
+             output = (pooled_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
+
+
+ #==========================================================================================================================================================================
+ #================================= Based on LlamaForSequenceClassification; a score head (two-layer MLP) that can read hidden states from any layer ========================================
+ #==========================================================================================================================================================================
+ class LlamaBaseline(LlamaPreTrainedModel):
+     def __init__(
+         self, config, sae_hidden_state_source_layer, sae_latent_size
+     ):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = MyLlamaModel(config, hidden_state_source_layer=sae_hidden_state_source_layer)
+
+         self.untrained_sae_encoder = nn.Linear(self.model.config.hidden_size, sae_latent_size)
+         self.score = nn.Linear(sae_latent_size, self.num_labels, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         assistant_masks: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.score(self.untrained_sae_encoder(hidden_states))
+
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+         if not return_dict:
+             output = (pooled_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
+
+
+ class LlamaBaselineFrozen(LlamaPreTrainedModel):
+     def __init__(
+         self, config, sae_hidden_state_source_layer, sae_latent_size
+     ):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = MyLlamaModel(config, hidden_state_source_layer=sae_hidden_state_source_layer)
+
+         self.untrained_sae_encoder = nn.Linear(self.model.config.hidden_size, sae_latent_size)
+         self.score = nn.Linear(sae_latent_size, self.num_labels, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.score(self.untrained_sae_encoder(hidden_states))
+
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+         if not return_dict:
+             output = (pooled_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
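
Usage note (illustrative, not part of this commit): because LlamaSARM's score head is a bias-free nn.Linear over the SAE features (logits = self.score(sae_features)), the pooled reward decomposes exactly into per-latent contributions. The helper below is a sketch of that attribution; its name, the top_n parameter, and the assumption of an already-loaded model plus a single tokenized sequence input_ids (batch size 1, ending in <|eot_id|>, num_labels = 1) are hypothetical, mirroring the checks in LlamaSARM.forward.

import torch

@torch.no_grad()
def reward_attribution(model, input_ids, top_n=10):
    # Recompute the same feature pipeline as LlamaSARM.forward: truncated
    # backbone -> pre_process standardization -> SAE pre-activations (optionally TopK).
    from sae import pre_process

    hidden_states = model.model(input_ids).last_hidden_state
    h, _, _ = pre_process(hidden_states)
    feats = model.sae.pre_acts(h)
    if model.sarm_use_topk:
        feats = model.sae.get_latents(feats)

    last_feats = feats[0, -1]                 # features at the pooled (last) token
    w = model.score.weight[0]                 # reward head weights, shape [latent_size]
    contributions = last_feats * w            # contributions.sum() equals the pooled reward
    top = torch.topk(contributions.abs(), top_n)
    return [(int(i), float(contributions[i])) for i in top.indices]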