import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from transformers import AutoModel, PreTrainedModel

from .config import LUARConfig


class UARScene(PreTrainedModel):
    """UAR Scene model: an SBERT-style encoder that embeds each comment in an
    episode and pools the comments into a single author embedding.
    """
    config_class = LUARConfig

    def __init__(self, config):
        super().__init__(config)
        self.create_transformer()
        # Projects the pooled episode representation to the output embedding size.
        self.linear = nn.Linear(self.hidden_size, config.embedding_size)

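    # Scaled dot-product attention without learned projections; get_episode_embeddings
    # applies it as self-attention over the comment embeddings of an episode
    # (query = key = value).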
    def attn_fn(self, q, k, v):
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, v)

    def create_transformer(self):
        """Creates the Transformer model.
        """
        self.transformer = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
        self.hidden_size = self.transformer.config.hidden_size
        self.num_attention_heads = self.transformer.config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean Pooling as described in the SBERT paper.
        """
        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
        sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
        sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
        return sum_embeddings / sum_mask

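    # Shape conventions below (descriptive names inferred from the code): B = batch
    # size, N = 1 (singleton dimension added by unsqueeze), E = comments per episode,
    # L = tokens per comment, D = transformer hidden size.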
    def get_episode_embeddings(self, data):
        """Computes the Author Embedding.
        """
        input_ids, attention_mask = data[0].unsqueeze(1), data[1].unsqueeze(1)
        B, N, E, _ = input_ids.shape

        # Flatten the episode structure so every comment is encoded independently.
        input_ids = rearrange(input_ids, 'b n e l -> (b n e) l')
        attention_mask = rearrange(attention_mask, 'b n e l -> (b n e) l')

        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=True,
        )

        # One embedding per comment, then restore the episode structure: (B*N, E, D).
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b n e) d -> (b n) e d', b=B, n=N, e=E)

        # Self-attention across the comments of an episode, then max-pool over comments.
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        episode_embeddings = reduce(episode_embeddings, 'b e d -> b d', 'max')

        episode_embeddings = self.linear(episode_embeddings)
        return episode_embeddings, comment_embeddings

    def forward(self, input_ids, attention_mask):
        """Calculates a fixed-length feature vector for a batch of episode samples.
        """
        data = [input_ids, attention_mask]
        episode_embeddings, _ = self.get_episode_embeddings(data)
        return episode_embeddings

    def _model_forward(self, batch):
        """Passes a batch of data through the model.

        This is used in the lightning_trainer.py file.
        """
        data, _, _ = batch
        episode_embeddings, comment_embeddings = self.get_episode_embeddings(data)
        return episode_embeddings, comment_embeddings
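

# A minimal usage sketch for a quick smoke test, not a definitive entry point. It
# assumes LUARConfig can be constructed with an `embedding_size` argument and that
# the matching tokenizer is "sentence-transformers/all-distilroberta-v1"; adjust
# both to the repo's actual settings. Because of the relative import above, run it
# as a module, e.g. `python -m <package_name>.model` (package name hypothetical).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = LUARConfig(embedding_size=512)  # embedding_size value is illustrative
    model = UARScene(config).eval()
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")

    # One episode of two comments; inputs are shaped (B, E, L) before forward()
    # adds the singleton N dimension.
    episode = ["first comment by the author", "second comment by the author"]
    tokens = tokenizer(
        episode,
        padding="max_length",
        truncation=True,
        max_length=32,
        return_tensors="pt",
    )
    input_ids = tokens["input_ids"].unsqueeze(0)            # (1, 2, 32)
    attention_mask = tokens["attention_mask"].unsqueeze(0)  # (1, 2, 32)

    with torch.no_grad():
        author_embedding = model(input_ids, attention_mask)
    print(author_embedding.shape)  # expected: (1, config.embedding_size)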