from sentence_transformers.models import Transformer
import torch
from transformers.utils.import_utils import is_peft_available


class LayerSpecificTransformer(Transformer):
    def forward(
        self, features: dict[str, torch.Tensor], layer_idx: int = -1, **kwargs
    ) -> dict[str, torch.Tensor]:
        """Returns token_embeddings taken from the hidden layer selected by ``layer_idx``."""
        # Keep only the inputs the underlying Hugging Face model accepts.
        trans_features = {
            key: value
            for key, value in features.items()
            if key in ["input_ids", "attention_mask", "token_type_ids", "inputs_embeds"]
        }

        # Request all hidden states so that any layer can be selected, not just the last one.
        output_states = self.auto_model(
            **trans_features, **kwargs, return_dict=True, output_hidden_states=True
        )
        # hidden_states holds the embedding output plus one entry per transformer layer;
        # layer_idx=-1 (the default) selects the final layer.
        output_tokens = output_states.hidden_states[layer_idx]
        # If the AutoModel is wrapped with a PeftModelForFeatureExtraction, prompt learning
        # may have prepended virtual tokens. Extend the attention mask to cover them so
        # that pooling stays aligned with the token embeddings.
        if is_peft_available():
            from peft import PeftModelForFeatureExtraction

            if (
                isinstance(self.auto_model, PeftModelForFeatureExtraction)
                and self.auto_model.active_peft_config.is_prompt_learning
            ):
                batch_size = output_tokens.size(0)
                attention_mask = features["attention_mask"]
                prefix_attention_mask = torch.ones(
                    batch_size,
                    self.auto_model.active_peft_config.num_virtual_tokens,
                    device=attention_mask.device,
                )
                features["attention_mask"] = torch.cat(
                    (prefix_attention_mask, attention_mask), dim=1
                )

        features["token_embeddings"] = output_tokens
        # Mirror the base Transformer module: expose all hidden states when the model
        # config requests them. Since output_hidden_states=True is forced above, the
        # return_dict output always carries hidden_states, so it can be read by name
        # instead of guessing its position in the output tuple.
        if self.auto_model.config.output_hidden_states:
            features["all_layer_embeddings"] = output_states.hidden_states

        return features
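

# Usage sketch (illustrative, not part of the original module): builds a
# SentenceTransformer around LayerSpecificTransformer and reads embeddings from an
# intermediate layer. The checkpoint name, pooling mode, and layer index are
# placeholder assumptions; substitute the model and layer you actually want to probe.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Pooling

    transformer = LayerSpecificTransformer("sentence-transformers/all-MiniLM-L6-v2")
    pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
    model = SentenceTransformer(modules=[transformer, pooling])

    # model.encode() uses the default layer_idx=-1 (final layer). To inspect an
    # intermediate layer, tokenize manually and call the module directly.
    features = model.tokenize(["An example sentence"])
    features = {key: value.to(model.device) for key, value in features.items()}
    with torch.no_grad():
        out = transformer(features, layer_idx=3)  # hidden state after the 3rd layer
    print(out["token_embeddings"].shape)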