from math import sqrt, log
import sys

import torch
import torch.nn as nn
from torch.nn.functional import softmax, relu, linear, gelu
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.cuda.amp import autocast

import yaml

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutput, SequenceClassifierOutput

from common import PositionalEncoding
from hopfield import HopfieldLayer, HopfieldMHA, HopfieldReLU, HopfieldSoftmax
from configuration_energy import BertEnergyConfig

ACT2FN = {'relu': relu, 'gelu': gelu, 'softmax': softmax}


class BertModel(PreTrainedModel):
    """Backbone of a standard BERT model.

    Outputs: last hidden state, history of hidden states (recorded in eval mode only)."""
    config_class = BertEnergyConfig

    def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs):
        super().__init__(config)
        self.Emb_in = nn.Embedding(config.vocabulary_size, config.embedding_dim, padding_idx=pad_idx)
        self.posn = PositionalEncoding(config.embedding_dim, max_len=config.block_size, dropout=config.dropout) if config.positional else None
        if config.share_layers:  # ALBERT-style configuration
            # ALBERT uses two matrices instead of one for the embeddings (see section 3.1 of the ALBERT paper).
            self.embedding_hidden_in = nn.Linear(config.embedding_dim, config.forward_memories)
            # ALBERT normalises and regularises the embeddings.
            self.embed_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm)
            self.embed_dropout = nn.Dropout(config.dropout)
        self.num_layers = config.num_layers
        self.share_layers = config.share_layers
        if config.share_layers:
            # A single encoder layer reused at every depth (ALBERT-style weight sharing).
            layer = nn.TransformerEncoderLayer(config.forward_memories,
                                               config.num_heads,
                                               activation=config.activation,
                                               dim_feedforward=config.forward_memories * 4,
                                               dropout=config.dropout,
                                               layer_norm_eps=config.layer_norm,
                                               batch_first=True,
                                               norm_first=True)
            self.layers = nn.ModuleList([layer])
        else:
            self.layers = nn.ModuleList([nn.TransformerEncoderLayer(config.embedding_dim,
                                                                    config.num_heads,
                                                                    dim_feedforward=config.forward_memories * 4,
                                                                    dropout=config.dropout,
                                                                    layer_norm_eps=config.layer_norm,
                                                                    batch_first=True,
                                                                    norm_first=True)
                                         for _ in range(config.num_layers)])

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Warning: expects an HF-style attention mask (1 = token, 0 = padding);
        it is inverted below to match the PyTorch ``src_key_padding_mask`` convention."""
        xbatch = self.Emb_in(input_ids)
        if self.posn:
            X = xbatch + self.posn(xbatch)
        else:
            X = xbatch
        if self.share_layers:
            X = self.embed_norm(X)
            X = self.embed_dropout(X)
            X = self.embedding_hidden_in(X)
        history = None if self.training else [X]
        if attention_mask is not None:
            # WARNING: mismatch between the HF tokenizer mask and the Torch key padding mask
            # https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html#torch.nn.Transformer
            attention_mask = ~attention_mask.bool()
        for i in range(self.num_layers):
            if self.share_layers:
                layer = self.layers[0]
            else:
                layer = self.layers[i]
            X = layer(X, src_key_padding_mask=attention_mask)
            if not self.training:
                history.append(X)
        # TODO add return attention
        return BaseModelOutput(last_hidden_state=X, hidden_states=history, attentions=None)
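
# Illustrative usage sketch (not part of the original module): runs the backbone on a
# dummy batch to show the expected tensor shapes and the HF-style attention mask
# convention documented in ``BertModel.forward``. The function name and the dummy
# inputs are assumptions; ``config`` is expected to expose the fields used above
# (vocabulary_size, embedding_dim, num_layers, ...).
def _demo_bert_backbone(config, pad_idx=0):
    model = BertModel(config, pad_idx=pad_idx)
    model.eval()  # eval mode, so the per-layer history is recorded
    input_ids = torch.randint(1, config.vocabulary_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)  # HF convention: 1 = real token, 0 = padding
    with torch.no_grad():
        out = model(input_ids, attention_mask=attention_mask)
    # last_hidden_state: (batch, seq_len, hidden); hidden_states: embeddings + one entry per layer
    return out.last_hidden_state.shape, len(out.hidden_states)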

class BertModelForMaskedLM(PreTrainedModel):
    """BERT model to be trained on the MLM task.

    Based on the backbone BERT model + a norm and a projection onto the vocabulary.
    Outputs: cross entropy loss / logits / hidden states."""
    config_class = BertEnergyConfig
    ignore_index = -100
    _tied_weights_keys = ["Emb_out.weight", "Emb_out.bias"]

    def __init__(self, config, add_pooling_layer=True, pad_idx=None):
        super().__init__(config)
        self.config = config
        self.model = BertModel(config, pad_idx=pad_idx)
        self.norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm)
        self.dense = nn.Linear(config.forward_memories, config.embedding_dim)
        self.activation = ACT2FN[config.activation]
        # Alternative weight-tying path, kept for reference:
        # if config.tie_weights:
        #     self.Emb_out = nn.Linear(config.embedding_dim, config.vocabulary_size, bias=False)
        #     self.tie_weights()
        # else:
        #     self.Emb_out = nn.Linear(config.embedding_dim, config.vocabulary_size)
        #     self.bias = nn.Parameter(torch.zeros(config.vocabulary_size))
        #     self.Emb_out.bias = self.bias
        # Output projection over the vocabulary; its input dimension matches the output of
        # self.dense / self.norm (embedding_dim).
        self.Emb_out = nn.Linear(config.embedding_dim, config.vocabulary_size)
        self.bias = nn.Parameter(torch.zeros(config.vocabulary_size))
        self.Emb_out.bias = self.bias

    def get_input_embeddings(self):
        return self.model.Emb_in

    def set_output_embeddings(self, new_embeddings):
        self.Emb_out = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        outputs = self.model(input_ids, attention_mask, **kwargs)
        last_hidden_state = outputs.last_hidden_state
        hidden_states = outputs.hidden_states
        attentions = outputs.attentions
        last_hidden_state = self.dense(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        last_hidden_state = self.norm(last_hidden_state)
        # Alternative weight-tying path, kept for reference:
        # if self.config.tie_weights:
        #     logits = last_hidden_state @ self.Emb_out.weight.transpose(-1, -2)
        # else:
        #     logits = self.Emb_out(last_hidden_state)
        logits = self.Emb_out(last_hidden_state)
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocabulary_size), labels.view(-1))
        return MaskedLMOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=attentions)
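
# Illustrative usage sketch (not part of the original module): shows how the MLM head is
# typically driven, with labels set to ``ignore_index`` (-100) everywhere except at the
# masked positions so that the cross entropy only scores the masked tokens. The function
# name, the mask token id and the masking scheme are assumptions.
def _demo_masked_lm(config, mask_token_id=4, pad_idx=0):
    model = BertModelForMaskedLM(config, pad_idx=pad_idx)
    input_ids = torch.randint(5, config.vocabulary_size, (2, 8))   # assumes vocabulary_size > 5
    labels = torch.full_like(input_ids, BertModelForMaskedLM.ignore_index)
    masked = torch.zeros_like(input_ids, dtype=torch.bool)
    masked[:, 1] = True                                            # mask one position per sequence (BERT masks ~15% at random)
    labels[masked] = input_ids[masked]                             # only masked positions are scored
    input_ids = input_ids.masked_fill(masked, mask_token_id)
    attention_mask = torch.ones_like(input_ids)                    # HF convention: 1 = real token
    out = model(input_ids, attention_mask=attention_mask, labels=labels)
    return out.loss, out.logits.shape                              # logits: (batch, seq_len, vocabulary_size)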

class BertModelForSequenceClassification(PreTrainedModel):
    """BERT model to be trained on sequence classification tasks.

    Based on the backbone BERT model + a pooled classification head.
    Outputs: loss / logits / hidden states."""
    config_class = BertEnergyConfig
    ignore_index = -100

    def __init__(self, config, add_pooling_layer=True, pad_idx=None, num_labels=2, classifier_dropout=None, return_dict=True):
        super().__init__(config)
        self.config = config
        self.num_labels = num_labels
        self.classifier_dropout = classifier_dropout
        self.return_dict = return_dict
        self.model = BertModel(config, pad_idx=pad_idx)
        self.dense = nn.Linear(config.forward_memories, config.forward_memories)
        classifier_dropout = (
            classifier_dropout if classifier_dropout is not None else config.dropout
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.forward_memories, num_labels)
        # NOTE: the head below implicitly assumes config.embedding_dim == config.forward_memories.
        self.norm = nn.LayerNorm(config.embedding_dim)
        # self.Emb_out = nn.Linear(config.embedding_dim, config.vocabulary_size, bias=False)
        # self.Emb_out.weight = self.model.Emb_in.weight  # weight tying

    def forward(self, input_ids, labels=None, return_dict=False, **kwargs):
        outputs = self.model(input_ids, **kwargs)
        last_hidden_state = self.norm(outputs.last_hidden_state)
        # Code from RoBERTa: https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/roberta/modeling_roberta.py#L1426
        x = last_hidden_state[:, 0, :]  # take the first token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        logits = self.classifier(x)
        hidden_states = outputs.hidden_states
        attentions = outputs.attentions
        loss = None
        if labels is not None:
            # move labels to the correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def compute_loss(self, logits, labels):
        # Code from https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L494
        log_probs = -nn.functional.log_softmax(logits, dim=-1)
        if labels.dim() == log_probs.dim() - 1:
            labels = labels.unsqueeze(-1)
        padding_mask = labels.eq(self.ignore_index)
        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0.
        # The padding_mask will ignore them in any case.
        labels = torch.clamp(labels, min=0)
        nll_loss = log_probs.gather(dim=-1, index=labels)
        nll_loss.masked_fill_(padding_mask, 0.0)
        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
        nll_loss = nll_loss.sum() / num_active_elements
        return nll_loss
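
# Illustrative sketch (not part of the original module): the classification head picks its
# loss from ``config.problem_type`` (regression -> MSELoss, single-label -> CrossEntropyLoss,
# multi-label -> BCEWithLogitsLoss), inferring it from ``num_labels`` and the label dtype when
# unset. The function name and the dummy labels are assumptions.
def _demo_loss_selection(config, pad_idx=0):
    model = BertModelForSequenceClassification(config, pad_idx=pad_idx, num_labels=3)
    input_ids = torch.randint(1, config.vocabulary_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)
    # Integer labels with num_labels > 1 -> "single_label_classification" -> CrossEntropyLoss.
    labels = torch.tensor([0, 2])
    out = model(input_ids, labels=labels, return_dict=True, attention_mask=attention_mask)
    # Float multi-hot labels would instead select "multi_label_classification" (BCEWithLogitsLoss);
    # note that the inferred problem_type is cached on the config after the first call.
    return model.config.problem_type, out.loss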

class BertEnergyModel(PreTrainedModel):
    """Energy-based backbone: a single (weight-shared) Hopfield layer applied iteratively
    for ``num_layers`` steps with the update X <- X - alpha * layer(X)."""
    config_class = BertEnergyConfig

    def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs):
        super().__init__(config)
        self.Emb_in = nn.Embedding(config.vocabulary_size, config.embedding_dim, padding_idx=pad_idx)
        self.posn = PositionalEncoding(config.embedding_dim, max_len=config.block_size, dropout=config.dropout) if config.positional else None
        self.num_layers = config.num_layers
        self.layer = HopfieldLayer(config.embedding_dim,
                                   config.num_heads,
                                   forward_memories=config.forward_memories,
                                   forward_activation=config.activation,
                                   bias=config.bias,
                                   beta=config.beta,
                                   dropout=config.dropout)
        self.alpha = config.alpha

    def forward(self, input_ids, attention_mask=None, **kwargs):
        xbatch = self.Emb_in(input_ids)
        if self.posn:
            X = xbatch + self.posn(xbatch)
        else:
            X = xbatch
        history = None if self.training else [X]
        for _ in range(self.num_layers):
            # TODO add src_key padding attention mask handling
            X = X - self.alpha * self.layer(X, src_key_padding_mask=attention_mask, is_causal=False)
            if not self.training:
                history.append(X)
        return BaseModelOutput(last_hidden_state=X, hidden_states=history, attentions=None)


class BertEnergyModelForMaskedLM(PreTrainedModel):
    """Energy-based BERT model for the MLM task: energy backbone + projection onto the vocabulary."""
    config_class = BertEnergyConfig
    ignore_index = -100
    _tied_weights_keys = ["Emb_out.weight", "Emb_out.bias"]

    def __init__(self, config, add_pooling_layer=True, pad_idx=None):
        super().__init__(config)
        self.config = config
        self.model = BertEnergyModel(config, pad_idx=pad_idx)
        self.norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm)
        self.dense = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.activation = ACT2FN[config.activation]
        self.Emb_out = nn.Linear(config.embedding_dim, config.vocabulary_size)
        self.bias = nn.Parameter(torch.zeros(config.vocabulary_size))
        self.Emb_out.bias = self.bias

    def get_input_embeddings(self):
        return self.model.Emb_in

    def set_output_embeddings(self, new_embeddings):
        self.Emb_out = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        outputs = self.model(input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        hidden_states = outputs.hidden_states
        attentions = outputs.attentions
        last_hidden_state = self.dense(last_hidden_state)
        last_hidden_state = gelu(last_hidden_state)  # XXX hard-coded GELU instead of self.activation
        last_hidden_state = self.norm(last_hidden_state)
        # logits = self.norm(last_hidden_state) @ self.Emb_out.weight.transpose(-1, -2)
        if self.config.tie_weights:
            logits = last_hidden_state @ self.Emb_out.weight.transpose(-1, -2)
        else:
            logits = self.Emb_out(last_hidden_state)
        loss = None
        # if labels is not None:
        #     loss = self.compute_loss(logits, labels)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocabulary_size), labels.view(-1))
        return MaskedLMOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=attentions)
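
# Illustrative sketch (not part of the original module): in eval mode the energy backbone
# records every iterate of the update X <- X - alpha * layer(X), so the trajectory of the
# hidden states across the refinement steps can be inspected. The function name and the
# dummy inputs are assumptions; ``config`` must provide the Hopfield hyper-parameters used
# above (alpha, beta, bias, ...).
def _demo_energy_trajectory(config, pad_idx=0):
    model = BertEnergyModel(config, pad_idx=pad_idx)
    model.eval()  # the history of iterates is only kept outside training mode
    input_ids = torch.randint(1, config.vocabulary_size, (1, 6))
    with torch.no_grad():
        out = model(input_ids)
    # hidden_states holds the initial embeddings plus one entry per step; the norm of each
    # update gives a rough idea of how quickly the iteration settles.
    deltas = [(b - a).norm().item() for a, b in zip(out.hidden_states[:-1], out.hidden_states[1:])]
    return deltas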

if __name__ == '__main__':

    def grads(f, x):
        """Autograd used for the energy."""
        return torch.func.jacrev(f)(x)

    # from test import *
    x = torch.randn(1, 10)
    input_ids = torch.tensor([[3, 12, 44, 2]])

    # test relu
    # print('relu')
    # hrelu = HopfieldReLU(10, 4, bias=False)
    # print(hrelu(x), hrelu.energy(x))
    # print(grads(hrelu.energy, x))

    # test softmax
    # print('softmax')
    # hsoftmax = HopfieldSoftmax(10, 4, bias=None)
    # print(hsoftmax(x), hsoftmax.energy(x))
    # print(grads(hsoftmax.energy, x))

    # test MHA
    # print('mha')
    # mha = HopfieldMHA(15, 3)
    # X = torch.randn(2, 4, 15)
    # causal = True
    # print(mha(X, is_causal=causal), mha.energy(X, is_causal=causal))
    # print()
    # print('=== Ref ===')
    # for x in X:  # autograd breaks with higher-order tensors
    #     print(grads(lambda y: mha.energy(y, is_causal=causal), x))

    # Smoke test: instantiate, run, save and reload the energy model.
    # NOTE: the original script referenced HopfieldConfig / HFHopfieldModel, which are not
    # defined in this module; BertEnergyConfig / BertEnergyModel (defined above) are assumed
    # to be their current names, and ../lmconfig.yaml to hold the hyper-parameters.
    config = BertEnergyConfig(path="../lmconfig.yaml")
    print(config)
    # exit()
    mdl = BertEnergyModel(config)
    mdl.eval()
    # print(mdl)
    out = mdl(input_ids)
    print(out[0].mean())
    mdl.save_pretrained("test_checkpoint")
    reloaded = BertEnergyModel.from_pretrained("test_checkpoint")
    out_reloaded = reloaded(input_ids)
    print(out_reloaded[0].mean())
    reloaded.to("cuda:0")
    print(reloaded(input_ids.to("cuda:0"))[0])