import math
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModel, PreTrainedModel

# from models.layers import MemoryEfficientAttention, SelfAttention
from .config import LUARConfig


class UARPlay(PreTrainedModel):
    """LUAR-style author-embedding model built on a pretrained transformer.

    Each input "comment" is encoded by the backbone transformer and
    mean-pooled into one vector. The comment vectors of an episode are mixed
    with a single parameter-free self-attention pass, max-pooled over the
    episode, and linearly projected to ``config.embedding_size``.
    """

    config_class = LUARConfig

    def __init__(self, config):
        super().__init__(config)
        self.create_transformer()
        # self.hidden_size is set by create_transformer(), so this must follow it.
        self.linear = nn.Linear(self.hidden_size, config.embedding_size)

    def attn_fn(self, k, q, v):
        """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``.

        Args:
            k: keys, shape ``(..., len_k, d)``.
            q: queries, shape ``(..., len_q, d)``.
            v: values, shape ``(..., len_k, d_v)``.

        Returns:
            Tensor of shape ``(..., len_q, d_v)``.
        """
        d_k = q.size(-1)
        # Queries index the rows so each query gets a softmax distribution over
        # the keys. The former ``k @ q^T`` ordering was only correct when k is q
        # (as in get_episode_embeddings, whose result is unchanged by this fix).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, v)

    def create_transformer(self):
        """Load the pretrained backbone and cache its size attributes.

        Sets ``self.transformer``, ``self.hidden_size``,
        ``self.num_attention_heads`` and ``self.dim_head``.
        """
        self.transformer = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
        backbone_config = self.transformer.config
        self.hidden_size = backbone_config.hidden_size
        self.num_attention_heads = backbone_config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings over unmasked positions (SBERT mean pooling).

        Args:
            token_embeddings: ``(batch, seq_len, hidden_size)`` per-token outputs.
            attention_mask: ``(batch, seq_len)`` mask; assumed 1 for real tokens
                and 0 for padding.

        Returns:
            ``(batch, hidden_size)`` tensor. Rows whose mask is entirely zero
            come out as zeros, because the denominator is clamped to 1e-9.
        """
        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
        sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
        sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
        return sum_embeddings / sum_mask

    def get_episode_embeddings(self, data):
        """Compute episode (author) embeddings and the per-comment embeddings.

        Args:
            data: ``(input_ids, attention_mask)`` pair. Each tensor is
                ``(batch, episode_length, seq_len)``; a singleton
                ``num_sample_per_author`` axis is inserted below.

        Returns:
            Tuple ``(episode_embeddings, comment_embeddings)`` with shapes
            ``(batch, embedding_size)`` and ``(batch, episode_length, hidden_size)``.
        """
        # -> (batch_size, num_sample_per_author=1, episode_length, seq_len)
        input_ids, attention_mask = data[0].unsqueeze(1), data[1].unsqueeze(1)
        B, N, E, _ = input_ids.shape

        # Flatten so the backbone sees one row per comment.
        input_ids = rearrange(input_ids, 'b n e l -> (b n e) l')
        attention_mask = rearrange(attention_mask, 'b n e l -> (b n e) l')

        # Only the last hidden state is used, so per-layer hidden states are not
        # requested (saves memory; outputs of this method are unchanged).
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # At this point we are embedding individual "comments".
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b n e) l -> (b n) e l', b=B, n=N, e=E)

        # Aggregate the comment embeddings of each episode into one vector.
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max')
        episode_embeddings = self.linear(episode_embeddings)

        return episode_embeddings, comment_embeddings

    def forward(self, input_ids, attention_mask):
        """Return a fixed-length feature vector for each episode in the batch.

        Args:
            input_ids: token ids, ``(batch, episode_length, seq_len)``.
            attention_mask: same shape as ``input_ids``.

        Returns:
            ``(batch, embedding_size)`` episode embeddings.
        """
        episode_embeddings, _ = self.get_episode_embeddings([input_ids, attention_mask])
        return episode_embeddings

    def _model_forward(self, batch):
        """Pass a batch of data through the model (used by lightning_trainer.py).

        Args:
            batch: 3-tuple whose first item is presumably the
                ``(input_ids, attention_mask)`` pair that get_episode_embeddings
                expects (TODO confirm against the data loader); the other two
                items are ignored here.

        Returns:
            ``(episode_embeddings, comment_embeddings)``.
        """
        data, _, _ = batch
        # forward() takes (input_ids, attention_mask) and returns only the episode
        # embeddings, so the former ``self.forward(data)`` raised a TypeError and
        # could not yield two values. Call the two-output helper directly.
        return self.get_episode_embeddings(data)