from transformers import PreTrainedModel, PretrainedConfig
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
import numpy as np


class ZeroShotEmbeddingConfig(PretrainedConfig):
    model_type = "embedding-head"

    def __init__(self, input_size=768, hidden_size=2048, output_size=128,
                 base_embedding_model="all-mpnet-base-v2", **kwargs):
        # input_size must match the dimensionality of the base embedding
        # model (768 for all-mpnet-base-v2).
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.base_embedding_model = base_embedding_model
        super().__init__(**kwargs)


class ZeroShotEmbedding(PreTrainedModel):
    """Projection head that conditions text embeddings on a prompt embedding."""

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        # The head consumes a text embedding concatenated with a prompt
        # embedding, hence the doubled input width.
        self.fc1 = nn.Linear(config.input_size * 2, config.hidden_size)
        self.fc2 = nn.Linear(config.hidden_size, config.output_size)
        self.gelu = nn.GELU()

    def forward(self, prompt_embedding, text_a_embedding, text_b_embedding=None, labels=None, **kwargs):
        # Condition each text embedding on the prompt by concatenation.
        x = torch.cat((text_a_embedding, prompt_embedding), dim=1)
        if text_b_embedding is not None:
            x2 = torch.cat((text_b_embedding, prompt_embedding), dim=1)

        # Two-layer MLP followed by L2 normalization, so outputs lie on the
        # unit hypersphere and dot products are cosine similarities.
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        x = nn.functional.normalize(x, p=2, dim=1)
        if text_b_embedding is not None:
            x2 = self.fc1(x2)
            x2 = self.gelu(x2)
            x2 = self.fc2(x2)
            x2 = nn.functional.normalize(x2, p=2, dim=1)
            # Batched dot product -> one similarity score per pair.
            dot_product = torch.bmm(x.unsqueeze(1), x2.unsqueeze(2)).squeeze()
            if labels is not None:
                # Mean squared error between predicted and target similarity.
                loss = torch.mean((dot_product - labels) ** 2)
                return loss, dot_product
            return dot_product
        return x
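

# --- Illustrative only: a minimal training-step sketch for the pair head. ---
# Not part of the original module; it shows how the (loss, similarity) return
# signature of ZeroShotEmbedding.forward is meant to be consumed. The function
# name, the optimizer, and the precomputed-embedding inputs are assumptions.
def example_training_step(model, optimizer, prompt_emb, text_a_emb, text_b_emb, labels):
    """One gradient step on the MSE-to-target-similarity objective.

    Each *_emb tensor holds precomputed base-model embeddings of shape
    (batch, input_size); labels holds target similarities in [-1, 1].
    """
    optimizer.zero_grad()
    loss, similarity = model(prompt_emb, text_a_emb, text_b_emb, labels=labels)
    loss.backward()
    optimizer.step()
    return loss.item()

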
class ZeroShotEmbeddingForClustering(PreTrainedModel):
    """Pairs a SentenceTransformer base encoder with the prompt-conditioned head."""

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)
        self.base_embedding_model = SentenceTransformer(config.base_embedding_model)
        self.head_model = ZeroShotEmbedding(config)
|
| def forward(self, texts, prompt, **kwargs): |
| text_embeddings = self.base_embedding_model.encode(texts) |
| prompt_embedding = self.base_embedding_model.encode(prompt) |
| prompt_embeddings = np.tile(prompt_embedding, (len(texts), 1)) |
| text_embeddings = torch.tensor(text_embeddings) |
| prompt_embeddings = torch.tensor(prompt_embeddings) |
| prompted_embeddings = self.head_model( |
| prompt_embeddings, text_embeddings) |
| similarity = torch.mm(prompted_embeddings, |
| prompted_embeddings.transpose(0, 1)) |
| return similarity |
|
|
| @classmethod |
| def from_pretrained_base(cls, pretrained_model_name_or_path): |
| head_model = ZeroShotEmbedding.from_pretrained( |
| pretrained_model_name_or_path) |
| model = cls(head_model.config) |
| cls.head_model = head_model |
| return model |


ZeroShotEmbeddingConfig.register_for_auto_class()
ZeroShotEmbedding.register_for_auto_class("AutoModel")
ZeroShotEmbeddingForClustering.register_for_auto_class("AutoModel")
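

# --- Illustrative only: a minimal usage sketch, assuming the
# 'all-mpnet-base-v2' weights can be downloaded. With a randomly initialised
# (untrained) head the similarities are arbitrary; the texts and prompt below
# are placeholders.
if __name__ == "__main__":
    config = ZeroShotEmbeddingConfig()
    model = ZeroShotEmbeddingForClustering(config)
    # To load a trained head instead (path is a placeholder):
    # model = ZeroShotEmbeddingForClustering.from_pretrained_base("path/to/head")
    texts = [
        "The stock market fell today.",
        "Quarterly earnings beat estimates.",
        "The recipe calls for two eggs.",
    ]
    prompt = "Group these texts by topic."
    with torch.no_grad():
        similarity = model(texts, prompt)
    print(similarity)  # (3, 3) matrix of pairwise cosine similarities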