from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
from transformers.utils import ModelOutput


@dataclass
class TransformationModelOutput(ModelOutput):
    """
    Base class for text model outputs that also contains a projection of the last hidden states.

    Args:
        projection_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, project_dim)`):
            The projected text embeddings, obtained by applying the transformation layer to `last_hidden_state`.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    projection_state: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(
        self,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        project_dim=512,
        pooler_fn="cls",
        learn_encoder=False,
        use_attention_mask=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder
        self.use_attention_mask = use_attention_mask


class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    base_model_prefix = "roberta"
    config_class = RobertaSeriesConfig

    def __init__(self, config):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        # Linear layer that projects the encoder's hidden states to `project_dim`.
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project the token-level hidden states into the `project_dim`-sized embedding space.
        projection_state = self.transformation(outputs.last_hidden_state)

        return TransformationModelOutput(
            projection_state=projection_state,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
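

# Illustrative usage sketch (not part of the original module): it shows how
# `RobertaSeriesConfig`, `RobertaSeriesModelWithTransformation`, and
# `TransformationModelOutput` fit together. The hyperparameter values below are
# arbitrary assumptions chosen to keep a randomly initialized model small; a real
# pipeline would load pretrained weights and tokenize actual text instead.
if __name__ == "__main__":
    config = RobertaSeriesConfig(
        vocab_size=1000,  # small toy vocabulary, not the XLM-RoBERTa vocabulary
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        project_dim=32,
    )
    model = RobertaSeriesModelWithTransformation(config)
    model.eval()

    # Dummy batch of token ids standing in for tokenizer output.
    input_ids = torch.randint(low=5, high=config.vocab_size, size=(2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    # projection_state: (batch_size, sequence_length, project_dim)
    # last_hidden_state: (batch_size, sequence_length, hidden_size)
    print(out.projection_state.shape)  # torch.Size([2, 16, 32])
    print(out.last_hidden_state.shape)  # torch.Size([2, 16, 64])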