import torch
from sentence_transformers.models import Transformer
from transformers.utils.import_utils import is_peft_available


class LayerSpecificTransformer(Transformer):
    def forward(
        self, features: dict[str, torch.Tensor], layer_idx: int = -1, **kwargs
    ) -> dict[str, torch.Tensor]:
        """Like Transformer.forward, but returns the token embeddings from the hidden
        layer selected by ``layer_idx`` (-1: last layer, -2: second-to-last, ...)."""
        trans_features = {
            key: value
            for key, value in features.items()
            if key in ["input_ids", "attention_mask", "token_type_ids", "inputs_embeds"]
        }

        output_states = self.auto_model(
            **trans_features, **kwargs, return_dict=True, output_hidden_states=True
        )
        output_tokens = output_states.hidden_states[layer_idx]  # -1: last layer, -2: second from last

        # If the AutoModel is wrapped with a PeftModelForFeatureExtraction, then it may have added virtual tokens.
        # We need to extend the attention mask to include these virtual tokens, or the pooling will fail.
        if is_peft_available():
            from peft import PeftModelForFeatureExtraction

            if (
                isinstance(self.auto_model, PeftModelForFeatureExtraction)
                and self.auto_model.active_peft_config.is_prompt_learning
            ):
                batch_size = output_tokens.size(0)
                attention_mask = features["attention_mask"]
                prefix_attention_mask = torch.ones(
                    batch_size,
                    self.auto_model.active_peft_config.num_virtual_tokens,
                    device=attention_mask.device,
                )
                features["attention_mask"] = torch.cat(
                    (prefix_attention_mask, attention_mask), dim=1
                )

        features["token_embeddings"] = output_tokens

        if self.auto_model.config.output_hidden_states and len(output_states) > 2:
            all_layer_idx = 2  # I.e. after last_hidden_states and pooler_output
            if len(output_states) < 3:  # Some models only output last_hidden_states and all_hidden_states
                all_layer_idx = 1

            hidden_states = output_states[all_layer_idx]
            features["all_layer_embeddings"] = hidden_states

        return features
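
# Minimal usage sketch (an assumption, not part of the class above): wrap the custom
# Transformer in a SentenceTransformer and call the modules directly so that
# ``layer_idx`` can be passed through. The checkpoint "sentence-transformers/all-MiniLM-L6-v2"
# is only an example; any Hugging Face encoder should work the same way.
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling

transformer = LayerSpecificTransformer("sentence-transformers/all-MiniLM-L6-v2")
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])

# model.encode() does not forward extra arguments, so run the modules by hand.
features = model.tokenize(["An example sentence to embed."])
features = {key: value.to(model.device) for key, value in features.items()}
with torch.no_grad():
    features = model[0](features, layer_idx=-2)  # token embeddings from the second-to-last layer
    features = model[1](features)  # mean pooling adds "sentence_embedding"
print(features["sentence_embedding"].shape)  # e.g. torch.Size([1, 384]) for MiniLM-L6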