import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from typing import Optional, Tuple
import torch.nn.functional as F
from transformers import BatchEncoding
from transformers import MPNetTokenizerFast
from transformers.models.roformer.modeling_roformer import (
    RoFormerEmbeddings,
    RoFormerModel,
    RoFormerEncoder,
    RoFormerLayer,
    RoFormerAttention,
    RoFormerIntermediate,
    RoFormerOutput,
    RoFormerSelfAttention,
    RoFormerPreTrainedModel,
)
from transformers.models.mpnet.modeling_mpnet import MPNetModel


class JRoFormerEmbeddings(RoFormerEmbeddings):
    """Construct the embeddings from word and token_type embeddings."""

    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        # Tie the token type embeddings to the word embedding table so both
        # lookups share the same weights.
        self.token_type_embeddings = self.word_embeddings


class JRoFormerSelfAttention(RoFormerSelfAttention):
    def __init__(self, config):
        super().__init__(config)
        # Re-create the query/key/value projections so the bias term can be
        # toggled through config.use_bias.
        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )


class JRoFormerAttention(RoFormerAttention):
    def __init__(self, config):
        super().__init__(config)
        # Swap in the bias-configurable self-attention defined above.
        self.self = JRoFormerSelfAttention(config)


class JRoFormerLayer(RoFormerLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = JRoFormerAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(
                    f"{self} should be used as a decoder model if cross attention is added"
                )
            # Cross-attention keeps the stock RoFormer implementation.
            self.crossattention = RoFormerAttention(config)
        self.intermediate = RoFormerIntermediate(config)
        self.output = RoFormerOutput(config)


class JRoFormerEncoder(RoFormerEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList(
            [JRoFormerLayer(config) for _ in range(config.num_hidden_layers)]
        )


class JRoFormerModel(RoFormerModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = JRoFormerEmbeddings(config)

        # Project embeddings up to the hidden size when the two dimensions differ.
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )

        self.encoder = JRoFormerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()


class AsmEncoder(RoFormerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.roformer = JRoFormerModel(config)
        # Project the pooled hidden state into the output embedding space.
        self.projection = nn.Linear(config.hidden_size, config.bla_dim)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Mean-pool the token embeddings, ignoring padding positions.
        token_embeddings = outputs[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1)
            .expand(token_embeddings.size())
            .to(token_embeddings.dtype)
        )
        asm_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

        # Project and L2-normalize to obtain the final assembly embedding.
        asm_embedding = self.projection(asm_embedding)
        asm_embedding = F.normalize(asm_embedding, p=2, dim=1)

        return asm_embedding
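

if __name__ == "__main__":
    # Illustrative usage sketch only, not part of the original module: it assumes a
    # RoFormer-style config that also carries the extra fields this file reads
    # (`use_bias` for the attention projections and `bla_dim` for the output
    # projection). The values below are placeholders for a quick smoke test.
    from transformers import RoFormerConfig

    config = RoFormerConfig(
        vocab_size=1000,
        embedding_size=128,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
    )
    config.use_bias = True  # assumed extra field consumed by JRoFormerSelfAttention
    config.bla_dim = 64     # assumed extra field consumed by AsmEncoder.projection

    model = AsmEncoder(config)
    model.eval()

    # Random token ids stand in for a tokenized assembly function.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        embeddings = model(input_ids=input_ids, attention_mask=attention_mask)

    # Each row is an L2-normalized embedding of size config.bla_dim.
    print(embeddings.shape)  # torch.Size([2, 64])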