import torch
import torch.nn as nn
from typing import List, Optional, Union, Tuple

from transformers import LlamaConfig
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.utils import logging
from transformers.modeling_outputs import (
    SequenceClassifierOutputWithPast,
    BaseModelOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    LlamaPreTrainedModel,
)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

# Local
from sae import TopkSAE, pre_process, Normalized_MSE_loss, Masked_Normalized_MSE_loss

logger = logging.get_logger(__name__)


#==========================================================================================================================================================
#==========================================================================================================================================================
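# Overview of this module (an SAE-based reward model, "SARM", built on a truncated Llama backbone):
#   * MyLlamaModel              - a Llama model cut down to the first `hidden_state_source_layer` decoder layers,
#                                 returning mid-layer hidden states without the final RMSNorm.
#   * LlamaSARM                 - scores TopkSAE features of those hidden states, with several training modes.
#   * LlamaBaseline / -Frozen   - ablation baselines that replace the SAE encoder with an untrained nn.Linear.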
class MyLlamaModel(LlamaPreTrainedModel):
    def __init__(
        self,
        config: LlamaConfig,
        hidden_state_source_layer: Optional[int] = None,
    ):
        if hidden_state_source_layer is None:
            # default: use the middle layer
            hidden_state_source_layer = config.num_hidden_layers // 2
        super().__init__(config)
        self.hidden_state_source_layer = hidden_state_source_layer
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Keep only the first `hidden_state_source_layer` decoder layers: this truncated backbone
        # stops at the layer whose hidden states feed the SAE / score head.
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(hidden_state_source_layer)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        if getattr(config, "pretraining_tp", 1) != 1:
            logger.warning("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        # NOTE: unlike the full LlamaModel, the final RMSNorm is intentionally not applied here; the truncated
        # backbone returns the raw residual-stream hidden state of its last kept layer.
        # hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, does nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
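

# ----------------------------------------------------------------------------------------------------------------
# Illustrative usage sketch (added; not part of the original training pipeline): MyLlamaModel keeps only the first
# `hidden_state_source_layer` decoder layers and returns that layer's un-normalized hidden states. The tiny config
# below uses made-up sizes purely so the example runs without downloading a checkpoint.
def _example_truncated_backbone():
    tiny_config = LlamaConfig(
        vocab_size=128, hidden_size=64, intermediate_size=128,
        num_hidden_layers=8, num_attention_heads=4, num_key_value_heads=4,
        pad_token_id=0,
    )
    backbone = MyLlamaModel(tiny_config, hidden_state_source_layer=4)  # keep layers 0..3
    input_ids = torch.randint(1, 128, (2, 10))
    out = backbone(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
    return out.last_hidden_state  # shape (2, 10, 64): the mid-layer residual stream that feeds the SAE
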
#==========================================================================================================================================================
#============================================ Based on LlamaForSequenceClassification, modified into the SAE4RM form =====================================
#==========================================================================================================================================================
class LlamaSARM(LlamaPreTrainedModel):
    def __init__(
        self, config, sae_hidden_state_source_layer, sae_latent_size, sae_k,
        sae_use_sequence_level=False,
        sarm_use_topk=False,
        sarm_train_mode=1,
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MyLlamaModel(config, hidden_state_source_layer=sae_hidden_state_source_layer)
        self.sae_use_sequence_level = sae_use_sequence_level
        self.sarm_use_topk = sarm_use_topk
        self.sarm_train_mode = sarm_train_mode
        self.score = nn.Linear(sae_latent_size, self.num_labels, bias=False)
        self.sae = TopkSAE(hidden_size=self.model.config.hidden_size, latent_size=sae_latent_size, k=sae_k)

        # Training modes (see also the reconstruction-loss branch in forward()):
        #   0 - freeze the backbone and the SAE; train only the score head
        #   1 - freeze the SAE; train the backbone and the score head
        #   2 - nothing frozen; an assistant-masked SAE reconstruction loss is computed from the SAE features
        #   3 - nothing frozen; the SAE reconstruction loss is computed on detached hidden states
        #       (token level, or last-token only when sae_use_sequence_level=True)
        if self.sarm_train_mode == 0:
            for p in self.model.parameters():
                p.requires_grad_(False)
        if self.sarm_train_mode == 0 or self.sarm_train_mode == 1:
            for p in self.sae.parameters():
                p.requires_grad_(False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        assistant_masks: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Encode the mid-layer hidden states with the SAE encoder and score the (optionally top-k sparsified) features.
        h, _, _ = pre_process(hidden_states)
        sae_features = self.sae.pre_acts(h)
        if self.sarm_use_topk:
            sae_features = self.sae.get_latents(sae_features)
        logits = self.score(sae_features)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        # Sanity check: the pooled (last non-padding) token must be <|eot_id|> (Llama-3 token id 128009).
        last_tokens = input_ids[torch.arange(batch_size, device=logits.device), sequence_lengths]
        assert bool((last_tokens == 128009).all()), "Expected the pooled token to be <|eot_id|> (id 128009)."

        # joint training: optionally compute an SAE reconstruction objective, depending on sarm_train_mode
        rec_loss = None
        if self.sarm_train_mode == 2:
            if not self.sarm_use_topk:
                # Mode 2: reconstruct from the (non-detached) SAE features, masked to assistant tokens.
                sae_features_t = self.sae.get_latents(sae_features)
                h_hat = self.sae.decode(sae_features_t)
                rec_loss = Masked_Normalized_MSE_loss(h, h_hat, assistant_masks)
        elif self.sarm_train_mode == 3 and not self.sae_use_sequence_level:
            # Mode 3 (token level): train the SAE on detached hidden states, masked to assistant tokens.
            h_d = h.detach()
            _, h_hat = self.sae(h_d)
            rec_loss = Masked_Normalized_MSE_loss(h_d, h_hat, assistant_masks)
        elif self.sarm_train_mode == 3 and self.sae_use_sequence_level:
            # Mode 3 (sequence level): train the SAE only on the last (pooled) token of each sequence.
            h_d = h.detach()
            sequence_lengths_t = sequence_lengths.view(-1, 1, 1)
            last_token_mask = torch.zeros([h_d.shape[0], 1, h_d.shape[1]], device=h_d.device)
            last_token_mask.scatter_(-1, sequence_lengths_t, torch.ones_like(sequence_lengths_t, dtype=last_token_mask.dtype))
            # h_d -> (bs, seq_len, d), last_token_mask -> (bs, 1, seq_len)
            h_d = torch.matmul(last_token_mask.to(h_d.dtype), h_d)
            _, h_hat = self.sae(h_d)
            rec_loss = Normalized_MSE_loss(h_d, h_hat)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
        if rec_loss is not None:
            # As written, the reconstruction loss replaces the classification loss when it is computed.
            loss = rec_loss

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
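

# ----------------------------------------------------------------------------------------------------------------
# Illustrative inference sketch (added). Assumptions: a trained SARM checkpoint exists at the placeholder path, the
# local `sae` module is importable, the SAE hyperparameters below match the ones the checkpoint was trained with,
# and `from_pretrained` forwards the extra SAE kwargs to `__init__` (they are not LlamaConfig fields). The chat
# template must end the assistant turn with <|eot_id|>, which the forward() assert above requires.
def _example_sarm_reward(checkpoint="path/to/sarm-checkpoint"):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = LlamaSARM.from_pretrained(
        checkpoint,
        sae_hidden_state_source_layer=16,  # hypothetical values for illustration only
        sae_latent_size=16384,
        sae_k=64,
        num_labels=1,
    ).eval()

    messages = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ]
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
    with torch.no_grad():
        reward = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids)).logits
    return reward  # shape (1, num_labels): the pooled scalar score of the response
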
#==========================================================================================================================================================
#================================= Based on LlamaForSequenceClassification: a score head (two-layer MLP) that can be attached at any layer ===============
#==========================================================================================================================================================
class LlamaBaseline(LlamaPreTrainedModel):
    def __init__(
        self, config, sae_hidden_state_source_layer, sae_latent_size
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MyLlamaModel(config, hidden_state_source_layer=sae_hidden_state_source_layer)
        self.untrained_sae_encoder = nn.Linear(self.model.config.hidden_size, sae_latent_size)
        self.score = nn.Linear(sae_latent_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        assistant_masks: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(self.untrained_sae_encoder(hidden_states))

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
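

# ----------------------------------------------------------------------------------------------------------------
# Illustrative sketch (added): LlamaBaseline mirrors LlamaSARM's architecture but replaces the pretrained SAE
# encoder with a randomly initialized nn.Linear of the same latent width, making it a natural ablation baseline.
# The tiny config below uses made-up sizes so the snippet runs without downloading weights.
def _example_baseline_head():
    tiny_config = LlamaConfig(
        vocab_size=128, hidden_size=64, intermediate_size=128,
        num_hidden_layers=8, num_attention_heads=4, num_key_value_heads=4,
        pad_token_id=0, num_labels=1,
    )
    baseline = LlamaBaseline(tiny_config, sae_hidden_state_source_layer=4, sae_latent_size=256)
    input_ids = torch.randint(1, 128, (2, 10))
    out = baseline(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
    return out.logits  # shape (2, 1): one scalar score per sequence, pooled at the last non-padding token
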

class LlamaBaselineFrozen(LlamaPreTrainedModel):
    """Same architecture as LlamaBaseline, but forward() does not accept an `assistant_masks` argument."""

    def __init__(
        self, config, sae_hidden_state_source_layer, sae_latent_size
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MyLlamaModel(config, hidden_state_source_layer=sae_hidden_state_source_layer)
        self.untrained_sae_encoder = nn.Linear(self.model.config.hidden_size, sae_latent_size)
        self.score = nn.Linear(sae_latent_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(self.untrained_sae_encoder(hidden_states))

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )