import torch

from transformers import AutoTokenizer, RobertaModel


class EmbeddingModel(torch.nn.Module):
    """
    A basic wrapper around a Hugging Face transformer model.

    Takes a batch of strings as input and produces an embedding vector of
    size d for each string.
    """

    # Maps the 'model_class' config value to the transformer class to load.
    model_classes = {'roberta': RobertaModel}

    def __init__(self, config, **kwargs):
        super().__init__()

        # Read settings from the config dict.
        self.model_class = self.model_classes.get(config.get("model_class").lower())
        self.model_name = config.get("model_name")
        self.pooling = config.get("pooling")
        self.normalize = config.get("normalize")
        self.d = config.get("d")
        self.prompt = config.get("prompt")
        self.add_upper = config.get("add_upper")
        self.upper_case = config.get("upper_case")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Load pretrained weights, falling back to TensorFlow weights if no
        # PyTorch checkpoint is available.
        try:
            self.transformer = self.model_class.from_pretrained(self.model_name)
        except OSError:
            self.transformer = self.model_class.from_pretrained(self.model_name, from_tf=True)

        # Dropout layer (not applied in forward()).
        self.dropout = torch.nn.Dropout(0.5)

        # Optional projection from the transformer's hidden size down to d.
        if self.d:
            self.projection = torch.nn.Linear(self.transformer.config.hidden_size, self.d)
            torch.nn.init.normal_(self.projection.weight)
            torch.nn.init.constant_(self.projection.bias, 0)

        self.to(config.get("device"))

    def to(self, device):
        # Move the module and remember the device so inputs can be moved in forward().
        super().to(device)
        self.device = device
        return self

    def encode(self, strings):
        # Apply optional preprocessing before tokenization.
        if self.prompt is not None:
            strings = [self.prompt + s for s in strings]
        if self.add_upper or self.upper_case:
            # Append an upper-cased copy of each string after a separator token.
            strings = [s + ' </s> ' + s.upper() for s in strings]

        try:
            encoded = self.tokenizer(strings, padding=True, truncation=True)
        except Exception:
            print(strings)
            raise

        input_ids = torch.tensor(encoded['input_ids']).long()
        attention_mask = torch.tensor(encoded['attention_mask'])

        return input_ids, attention_mask

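    # Illustrative note (not from the original source, assumes an instance
    # named `model`): for a batch of two strings, encode() returns two integer
    # tensors of shape (batch_size, max_seq_len), e.g.
    #
    #     input_ids, attention_mask = model.encode(['Acme Corp', 'Acme Corporation'])
    #     input_ids.shape        # torch.Size([2, max_seq_len])
    #     attention_mask.shape   # torch.Size([2, max_seq_len])
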
    def forward(self, strings):
        # Tokenization does not require gradients.
        with torch.no_grad():
            input_ids, attention_mask = self.encode(strings)

        input_ids = input_ids.to(device=self.device)
        attention_mask = attention_mask.to(device=self.device)

        batch_out = self.transformer(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=True)

        # Pool the transformer output into one vector per string.
        if self.pooling == 'pooler':
            v = batch_out['pooler_output']
        elif self.pooling == 'mean':
            # Mean of the last hidden states over non-padding tokens.
            h = batch_out['last_hidden_state']
            h = h * attention_mask[:, :, None]
            v = h.sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        else:
            raise ValueError(f'Unrecognized pooling option: {self.pooling}')

        # Optionally project down to d dimensions.
        if self.d:
            v = self.projection(v)

        # Optionally scale each embedding to unit L2 norm.
        if self.normalize:
            v = v / torch.sqrt((v**2).sum(dim=1)[:, None])

        return v

    def config_optimizer(self, transformer_lr=1e-5, projection_lr=1e-4):
        # Build parameter groups so the transformer and the projection layer
        # can use different learning rates, with no weight decay on any group.
        parameters = list(self.named_parameters())
        grouped_parameters = [
            {
                'params': [param for name, param in parameters
                           if name.startswith('transformer') and name.endswith('bias')],
                'weight_decay': 0.0,
                'lr': transformer_lr,
            },
            {
                'params': [param for name, param in parameters
                           if name.startswith('transformer') and not name.endswith('bias')],
                'weight_decay': 0.0,
                'lr': transformer_lr,
            },
            {
                'params': [param for name, param in parameters
                           if name.startswith('projection')],
                'weight_decay': 0.0,
                'lr': projection_lr,
            },
        ]

        # Drop any group with a zero or missing learning rate, which freezes
        # those parameters.
        grouped_parameters = [p for p in grouped_parameters if p['lr']]

        optimizer = torch.optim.AdamW(grouped_parameters)

        return optimizer
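

# Minimal usage sketch (an assumption, not part of the original module): the
# config keys below mirror the ones read in __init__; 'roberta-base', d=128,
# and the CPU device are placeholder choices.
if __name__ == '__main__':
    config = {
        'model_class': 'roberta',
        'model_name': 'roberta-base',
        'pooling': 'mean',
        'normalize': True,
        'd': 128,
        'prompt': None,
        'add_upper': False,
        'upper_case': False,
        'device': 'cpu',
    }

    model = EmbeddingModel(config)
    optimizer = model.config_optimizer(transformer_lr=1e-5, projection_lr=1e-4)

    embeddings = model(['Acme Corp', 'Acme Corporation'])
    print(embeddings.shape)  # expected: torch.Size([2, 128])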