import torch
from transformers import AutoTokenizer, RobertaModel


class EmbeddingModel(torch.nn.Module):
    """
    A basic wrapper around a Hugging Face transformer model.
    Takes a string as input and produces an embedding vector of size d.
    """
    # Despite the name, this maps the lower-cased "model_class" config value
    # to a transformer model class (not a tokenizer class).
    tokenizers = {'roberta': RobertaModel}

    def __init__(self, config, **kwargs):
        super().__init__()
        self.model_class = self.tokenizers.get(config.get("model_class").lower())
        self.model_name = config.get("model_name")
        self.pooling = config.get("pooling")
        self.normalize = config.get("normalize")
        self.d = config.get("d")
        self.prompt = config.get("prompt")
        self.add_upper = config.get("add_upper")
        self.upper_case = config.get("upper_case")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        try:
            self.transformer = self.model_class.from_pretrained(self.model_name)
        except OSError:
            # Fall back to TensorFlow weights if no PyTorch checkpoint is available
            self.transformer = self.model_class.from_pretrained(self.model_name, from_tf=True)

        self.dropout = torch.nn.Dropout(0.5)
        if self.d:
            # Project embedding to a lower dimension
            # Initialization based on random projection LSH (preserves approximate cosine distances)
            self.projection = torch.nn.Linear(self.transformer.config.hidden_size, self.d)
            torch.nn.init.normal_(self.projection.weight)
            torch.nn.init.constant_(self.projection.bias, 0)

        self.to(config.get("device"))
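
    # Sketch of the config keys read above (the example values are illustrative
    # assumptions, not defaults shipped with this code):
    #   "model_class": "roberta"      -> key into the tokenizers dict above
    #   "model_name":  "roberta-base" -> any Hugging Face checkpoint name
    #   "pooling":     "pooler" | "mean"
    #   "normalize":   True/False     -> L2-normalize the output embedding
    #   "d":           int or None    -> optional projection dimension
    #   "prompt":      str or None    -> prefix prepended to every input string
    #   "add_upper", "upper_case": True/False -> append an upper-cased copy of the input
    #   "device":      e.g. "cpu" or "cuda"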

    def to(self, device):
        # Track the target device so tensors can be moved to it in forward()
        super().to(device)
        self.device = device
        return self

    def encode(self, strings):
        if self.prompt is not None:
            strings = [self.prompt + s for s in strings]
        # Either flag appends an upper-cased copy of the string after a separator token
        if self.add_upper:
            strings = [s + ' </s> ' + s.upper() for s in strings]
        if self.upper_case:
            strings = [s + ' </s> ' + s.upper() for s in strings]
        try:
            encoded = self.tokenizer(strings, padding=True, truncation=True)
        except Exception:
            # Print the offending batch, then re-raise with the original traceback
            print(strings)
            raise
        input_ids = torch.tensor(encoded['input_ids']).long()
        attention_mask = torch.tensor(encoded['attention_mask'])
        return input_ids, attention_mask

    def forward(self, strings):
        # Tokenization produces no gradients, so it can run under no_grad
        with torch.no_grad():
            input_ids, attention_mask = self.encode(strings)
            input_ids = input_ids.to(device=self.device)
            attention_mask = attention_mask.to(device=self.device)

        # with amp.autocast(self.amp):
        batch_out = self.transformer(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=True)

        if self.pooling == 'pooler':
            v = batch_out['pooler_output']
        elif self.pooling == 'mean':
            h = batch_out['last_hidden_state']
            # Compute mean of unmasked token vectors
            h = h * attention_mask[:, :, None]
            v = h.sum(dim=1) / attention_mask.sum(dim=1)[:, None]

        if self.d:
            v = self.projection(v)
        if self.normalize:
            # Scale to unit L2 norm so cosine similarity reduces to a dot product
            v = v / torch.sqrt((v ** 2).sum(dim=1)[:, None])
        return v

    def config_optimizer(self, transformer_lr=1e-5, projection_lr=1e-4):
        parameters = list(self.named_parameters())
        grouped_parameters = [
            {
                # Transformer bias terms: no weight decay
                'params': [param for name, param in parameters if name.startswith('transformer') and name.endswith('bias')],
                'weight_decay': 0.0,
                'lr': transformer_lr,
            },
            {
                # Remaining transformer parameters
                'params': [param for name, param in parameters if name.startswith('transformer') and not name.endswith('bias')],
                'weight_decay': 0.0,
                'lr': transformer_lr,
            },
            {
                # Projection layer, trained at its own (typically higher) learning rate
                'params': [param for name, param in parameters if name.startswith('projection')],
                'weight_decay': 0.0,
                'lr': projection_lr,
            },
        ]
        # Drop groups with a learning rate of 0 so those parameters are not updated
        grouped_parameters = [p for p in grouped_parameters if p['lr']]
        optimizer = torch.optim.AdamW(grouped_parameters)
        return optimizer
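

# A minimal usage sketch, assuming a "roberta-base" checkpoint can be loaded;
# the config values below are illustrative assumptions, not project defaults.
if __name__ == "__main__":
    config = {
        "model_class": "roberta",
        "model_name": "roberta-base",
        "pooling": "mean",
        "normalize": True,
        "d": 128,
        "prompt": None,
        "add_upper": False,
        "upper_case": False,
        "device": "cpu",
    }
    model = EmbeddingModel(config)
    # Embed a small batch of strings; with d=128 the result has shape (2, 128)
    embeddings = model(["first example string", "second example string"])
    print(embeddings.shape)
    # Build an optimizer with separate learning rates for transformer and projection
    optimizer = model.config_optimizer(transformer_lr=1e-5, projection_lr=1e-4)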