import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as edict
from xml.model_components import BertAttention, TrainablePositionalEncoding


class TextEncoder(nn.Module):
    def __init__(self, hidden_size, drop, input_drop, nheads, max_position_embeddings):
        super().__init__()
        # Single self-attention block over the token features.
        self.transformer_encoder = BertAttention(edict(
            hidden_size=hidden_size,
            intermediate_size=hidden_size,
            hidden_dropout_prob=drop,
            attention_probs_dropout_prob=drop,
            num_attention_heads=nheads,
        ))
        # Learned absolute position embeddings added to the input features.
        self.pos_embed = TrainablePositionalEncoding(
            max_position_embeddings=max_position_embeddings,
            hidden_size=hidden_size,
            dropout=input_drop,
        )
        # Maps each token to a scalar logit used for attentive pooling.
        self.modular_vector_mapping = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, feat, mask):
        """
        Args:
            feat: (N, L, D=hidden_size)
            mask: (N, L), where 1 indicates a valid position
        Returns:
            (N, D) attention-pooled text representation
        """
        feat = self.pos_embed(feat)  # (N, L, D)
        feat = self.transformer_encoder(feat, mask.unsqueeze(1))  # (N, L, D)
        att_scores = self.modular_vector_mapping(feat)  # (N, L, 1)
        att_scores = F.softmax(mask_logits(att_scores, mask.unsqueeze(2)), dim=1)
        pooled_feat = torch.einsum("blm,bld->bmd", att_scores, feat)  # (N, 1, D)
        return pooled_feat.squeeze(1)  # (N, D)


def mask_logits(target, mask):
    """Push logits at padded positions (mask == 0) toward -inf before softmax."""
    return target * mask + (1 - mask) * (-1e10)
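# Illustration (hypothetical values, not from the original file): a masked-out
# position is pushed to ~-1e10, so the softmax in TextEncoder.forward assigns
# it ~zero weight:
#   mask_logits(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 0.0]))
#   -> tensor([ 1.0000e+00, -1.0000e+10])
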
def build_text_encoder(args):
    return TextEncoder(
        hidden_size=args.hidden_dim,
        drop=args.dropout,
        input_drop=args.input_dropout,
        nheads=args.nheads,
        max_position_embeddings=args.max_q_l,
    )
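

# Minimal usage sketch (not part of the original file): the argument values
# below are hypothetical placeholders, and running this assumes the
# xml.model_components module from this repo is importable.
if __name__ == "__main__":
    args = edict(
        hidden_dim=256,
        dropout=0.1,
        input_dropout=0.1,
        nheads=4,
        max_q_l=32,  # maximum query (text) length
    )
    encoder = build_text_encoder(args)
    feat = torch.randn(2, 10, args.hidden_dim)  # (N=2, L=10, D=256)
    mask = torch.ones(2, 10)                    # all positions valid
    pooled = encoder(feat, mask)
    print(pooled.shape)  # expected: torch.Size([2, 256])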