Spaces:
Runtime error
Runtime error
File size: 1,609 Bytes
f4dac30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSA(nn.Module):
    """Location-sensitive attention over a projected encoder sequence.

    The module is stateful: it stores the attention weights of the latest
    step in ``self.attention`` and their running sum over steps in
    ``self.cumulative``.  The running sum is passed through a 1-D convolution
    and fed back into the energy computation of the next step.

    Args:
        attn_dim: Size of the attention space; input and output width of the
            query projection ``W`` and output width of the location
            projection ``L``.
        kernel_size: Width of the location convolution.  The padding is
            ``(kernel_size - 1) // 2``, which keeps the time length unchanged
            only for odd values.
        filters: Number of output channels of the location convolution.
    """

    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
        self.L = nn.Linear(filters, attn_dim, bias=False)
        self.W = nn.Linear(attn_dim, attn_dim, bias=True)  # Include the attention bias in this term
        self.v = nn.Linear(attn_dim, 1, bias=False)
        # Per-sequence state; allocated by init_attention().
        self.cumulative = None
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        """Reset the attention state to zeros for a new sequence.

        Args:
            encoder_seq_proj: 3-D tensor ``(b, t, c)``; only its first two
                sizes and its device are used.
        """
        device = encoder_seq_proj.device  # keep the state on the same device as the input
        b, t, c = encoder_seq_proj.size()
        self.cumulative = torch.zeros(b, t, device=device)
        self.attention = torch.zeros(b, t, device=device)

    def forward(self, encoder_seq_proj, query, times, chars):
        """Compute the attention weights for one step.

        Args:
            encoder_seq_proj: 3-D tensor ``(b, t, c)``.  It is added to the
                query and location terms, so ``c`` presumably equals
                ``attn_dim`` -- confirm against the caller.
            query: Tensor whose last dimension is ``attn_dim``; presumably
                ``(b, attn_dim)`` -- confirm against the caller.
            times: Step index.  The attention state is reset when it is 0.
            chars: Tensor broadcastable to ``(b, t)``; entries equal to 0 are
                treated as padding and receive zero attention weight.

        Returns:
            Tensor of shape ``(b, 1, t)`` holding the attention weights,
            normalised with a softmax over the ``t`` axis.
        """
        if times == 0:
            self.init_attention(encoder_seq_proj)

        processed_query = self.W(query).unsqueeze(1)

        # (b, t) -> (b, 1, t) for the convolution, then back to (b, t, attn_dim).
        location = self.cumulative.unsqueeze(1)
        processed_loc = self.L(self.conv(location).transpose(1, 2))

        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
        u = u.squeeze(-1)

        # Mask zero padding chars.  Multiplying the logits by a 0/1 mask would
        # leave padded positions at logit 0, i.e. weight exp(0) = 1 before
        # normalisation, so they would still attract attention.  Filling with
        # the most negative finite value drives their softmax weight to zero
        # while keeping a fully padded row finite (uniform) rather than NaN.
        u = u.masked_fill(chars == 0, torch.finfo(u.dtype).min)

        # Smooth attention via softmax over the time axis.
        scores = F.softmax(u, dim=1)
        self.attention = scores
        self.cumulative = self.cumulative + self.attention

        return scores.unsqueeze(-1).transpose(1, 2)
|