import torch
from transformers import RoFormerForMaskedLM


class RoFormerForSparseEmbeddingV2(RoFormerForMaskedLM):
    def forward(self, input_ids, attention_mask, return_sparse=False):
        # MLM logits over the vocabulary at every position: [B, L, V]
        logits = super().forward(input_ids=input_ids, attention_mask=attention_mask)['logits']

        # Additive mask: 0 at real tokens, -1e4 at padding positions: [B, L, 1]
        token_mask = (1 - attention_mask.unsqueeze(-1)) * -1e4
        # Mask out the first token ([CLS]) ...
        token_mask[:, 0, :] = -1e4
        # ... and the last non-padding token ([SEP]) of each sequence: [B, 1, 1]
        last_ind = torch.sum(attention_mask, -1, keepdim=True).unsqueeze(-1) - 1
        token_mask = torch.scatter(token_mask, -2, last_ind, -1e4)
        logits = logits + token_mask

        # SPLADE-style pooling: ReLU, max over the sequence dim, log saturation: [B, V]
        emb = torch.log(1 + torch.max(torch.relu(logits), dim=-2).values)

        if return_sparse:
            # Most vocabulary weights are zero, so a sparse COO tensor is compact.
            emb = emb.to_sparse()

        return emb
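

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original file): load a
    # checkpoint, encode two texts, and score them by dot product in vocabulary
    # space. The checkpoint path below is a placeholder; point it at whatever
    # RoFormer sparse-embedding weights this class is meant to load.
    from transformers import AutoTokenizer

    ckpt = "path/to/roformer-sparse-checkpoint"  # placeholder, not a real model id
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = RoFormerForSparseEmbeddingV2.from_pretrained(ckpt)
    model.eval()

    texts = ["how do sparse embeddings work",
             "sparse embeddings assign weights to vocabulary terms"]
    batch = tokenizer(texts, return_tensors="pt", padding=True)
    with torch.no_grad():
        emb = model(batch["input_ids"], batch["attention_mask"])  # dense [2, V]

    # Relevance score = dot product of the two vocabulary-space vectors.
    score = (emb[0] * emb[1]).sum()
    print(f"similarity: {score.item():.4f}")

    # Inspect the highest-weighted vocabulary terms of the first text.
    top = emb[0].topk(5)
    print(tokenizer.convert_ids_to_tokens(top.indices.tolist()), top.values.tolist())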