import torch
from transformers import XLMRobertaModel as XLMRobertaModelBase


class XLMRobertaModel(XLMRobertaModelBase):
    """XLM-R dual encoder with separate projection heads for questions and answers."""

    def __init__(self, config):
        super().__init__(config)
        # Project the 768-dim XLM-R hidden states down to 512-dim embeddings.
        self.question_projection = torch.nn.Linear(768, 512)
        self.answer_projection = torch.nn.Linear(768, 512)

    def _embed(self, input_ids, attention_mask, projection):
        # Run the base XLM-R encoder.
        outputs = super().__call__(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        # Mean-pool the token embeddings, ignoring padded positions via the attention mask.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
        embeddings = torch.sum(sequence_output * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
        return torch.tanh(projection(embeddings))

    def question(self, input_ids, attention_mask):
        return self._embed(input_ids, attention_mask, self.question_projection)

    def answer(self, input_ids, attention_mask):
        return self._embed(input_ids, attention_mask, self.answer_projection)