# UAR_scene/.ipynb_checkpoints/model-checkpoint.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModel
# from models.layers import MemoryEfficientAttention, SelfAttention
from .config import LUARConfig
class UARScene(
nn.Module,
PyTorchModelHubMixin,
):
"""Defines the SBERT model.
"""
config_class = LUARConfig
def __init__(self, config):
super().__init__()
self.config = config
self.create_transformer()
self.linear = nn.Linear(self.hidden_size, self.config.embedding_size)
    def attn_fn(self, k, q, v):
        """Scaled dot-product attention, applied as self-attention over the
        comment embeddings of an episode.
        """
        d_k = q.size(-1)
        scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, v)
def create_transformer(self):
"""Creates the Transformer model.
"""
self.transformer = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
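        # NOTE: for "sentence-transformers/all-distilroberta-v1" (a distilroberta-base
        # backbone) this typically yields hidden_size=768 and 12 attention heads.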
self.hidden_size = self.transformer.config.hidden_size
self.num_attention_heads = self.transformer.config.num_attention_heads
self.dim_head = self.hidden_size // self.num_attention_heads
def mean_pooling(self, token_embeddings, attention_mask):
"""Mean Pooling as described in the SBERT paper.
"""
input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
return sum_embeddings / sum_mask
def get_episode_embeddings(self, data):
"""Computes the Author Embedding.
"""
        # data[0] / data[1] have shape (batch_size, episode_length, token_length);
        # unsqueeze(1) adds a singleton num_samples_per_author dimension.
        input_ids, attention_mask = data[0].unsqueeze(1), data[1].unsqueeze(1)
        B, N, E, _ = input_ids.shape
input_ids = rearrange(input_ids, 'b n e l -> (b n e) l')
attention_mask = rearrange(attention_mask, 'b n e l -> (b n e) l')
outputs = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
output_hidden_states=True
)
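        # outputs["last_hidden_state"]: ((b * n * e), token_length, hidden_size)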
        # at this point, we're embedding individual "comments"
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b n e) l -> (b n) e l', b=B, n=N, e=E)
        # aggregate the individual comment embeddings into an episode embedding:
        # self-attention over the comments, max-pooling over the episode dimension,
        # and a linear projection down to the output embedding size
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max')
        episode_embeddings = self.linear(episode_embeddings)
return episode_embeddings, comment_embeddings
def forward(self, input_ids, attention_mask):
"""Calculates a fixed-length feature vector for a batch of episode samples.
"""
data = [input_ids, attention_mask]
        episode_embeddings, _ = self.get_episode_embeddings(data)
return episode_embeddings
    def _model_forward(self, batch):
        """Passes a batch of data through the model.
        This is used in the lightning_trainer.py file.
        """
        data, _, _ = batch
        # forward() only returns the episode embeddings, so call
        # get_episode_embeddings() directly to also retrieve the comment embeddings.
        episode_embeddings, comment_embeddings = self.get_episode_embeddings(data)
        # labels = torch.flatten(labels)
        return episode_embeddings, comment_embeddings
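

# Minimal usage sketch (not part of the original checkpoint): it assumes LUARConfig
# accepts an `embedding_size` argument and that episodes are tokenized with the
# backbone tokenizer. Because of the relative import above, run this from within
# the package; names and shapes below are illustrative.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = LUARConfig(embedding_size=512)  # hypothetical embedding size
    model = UARScene(config).eval()

    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
    episode = ["first comment by an author", "a second comment", "a third comment"]
    tokens = tokenizer(
        episode, padding=True, truncation=True, max_length=64, return_tensors="pt"
    )

    # forward() expects (batch_size, episode_length, token_length) tensors.
    input_ids = tokens["input_ids"].unsqueeze(0)
    attention_mask = tokens["attention_mask"].unsqueeze(0)

    with torch.no_grad():
        episode_embedding = model(input_ids, attention_mask)
    print(episode_embedding.shape)  # torch.Size([1, config.embedding_size])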