Update attentions.py
attentions.py CHANGED (+12 −12)
@@ -3,7 +3,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from . import commons
+from .commons import fused_add_tanh_sigmoid_multiply, subsequent_mask,convert_pad_shape
 from .modules import LayerNorm
 
 
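Note on the import change: the module-level `from . import commons` is replaced by direct imports of the three helpers used in this file, so every `commons.` prefix below disappears. For reference, these helpers are typically defined as follows in VITS-derived codebases; this is a sketch of the usual definitions, not necessarily this repo's exact commons.py:

import torch

def convert_pad_shape(pad_shape):
    # Flatten a [[left, right], ...] per-dimension pad spec (outermost dim first)
    # into the flat, last-dim-first list that torch.nn.functional.pad expects.
    return [item for sublist in pad_shape[::-1] for item in sublist]

def subsequent_mask(length):
    # Lower-triangular causal mask of shape [1, 1, length, length].
    return torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)

def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # WaveNet-style gated activation: tanh of the first n_channels[0] channels
    # times sigmoid of the remaining channels of the summed inputs.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    return t_act * s_act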
@@ -74,7 +74,7 @@ class Encoder(nn.Module):
                 x = self.cond_pre(x)
                 cond_offset = i * 2 * self.hidden_channels
                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x = commons.fused_add_tanh_sigmoid_multiply(
+                x = fused_add_tanh_sigmoid_multiply(
                     x, g_l, torch.IntTensor([self.hidden_channels])
                 )
             y = self.attn_layers[i](x, x, attn_mask)
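This call is the gated conditioning step: both x (after self.cond_pre) and the per-layer slice g_l carry 2 * hidden_channels channels, and the fused op returns hidden_channels channels. A minimal shape check, assuming the helper sketched above and example sizes:

import torch

B, H, T = 2, 192, 50                  # batch, hidden_channels, frames (example values)
x = torch.randn(B, 2 * H, T)          # output of self.cond_pre(x)
g_l = torch.randn(B, 2 * H, T)        # per-layer slice of the conditioning tensor
out = fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([H]))
print(out.shape)                      # torch.Size([2, 192, 50])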
@@ -153,7 +153,7 @@ class Decoder(nn.Module):
         x: decoder input
         h: encoder output
         """
-        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+        self_attn_mask = subsequent_mask(x_mask.size(2)).to(
             device=x.device, dtype=x.dtype
         )
         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
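subsequent_mask builds the causal self-attention mask for the decoder, with x_mask.size(2) being the time dimension. For illustration, with the sketch above:

print(subsequent_mask(4))
# tensor([[[[1., 0., 0., 0.],
#           [1., 1., 0., 0.],
#           [1., 1., 1., 0.],
#           [1., 1., 1., 1.]]]])   # shape [1, 1, 4, 4]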
@@ -316,7 +316,7 @@ class MultiHeadAttention(nn.Module):
         if pad_length > 0:
             padded_relative_embeddings = F.pad(
                 relative_embeddings,
-                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
             )
         else:
             padded_relative_embeddings = relative_embeddings
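convert_pad_shape turns the per-dimension [[left, right], ...] spec (listed outermost dimension first) into the flat, last-dimension-first list that F.pad expects. For the call above, assuming the sketch of the helper given earlier:

pad_length = 3
print(convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
# [0, 0, 3, 3, 0, 0]  -> pads the middle (relative-position) dimension by 3 on each side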
@@ -332,12 +332,12 @@ class MultiHeadAttention(nn.Module):
         """
         batch, heads, length, _ = x.size()
         # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
 
         # Concat extra elements so to add up to shape (len+1, 2*len-1).
         x_flat = x.view([batch, heads, length * 2 * length])
         x_flat = F.pad(
-            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+            x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
         )
 
         # Reshape and slice out the padded elements.
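The two pads in this hunk implement the usual skewing trick that converts scores indexed by relative position (last dimension 2*length - 1) into scores indexed by absolute position (last dimension length): append one zero column, flatten, pad by length - 1, then reshape so each row is shifted by one. A small shape trace, assuming the helper sketched earlier and that the rest of the method follows the standard VITS implementation:

import torch
from torch.nn import functional as F

batch, heads, length = 1, 1, 3
x = torch.randn(batch, heads, length, 2 * length - 1)                          # relative-position scores
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))              # -> [1, 1, 3, 6]
x_flat = x.view([batch, heads, length * 2 * length])                           # -> [1, 1, 18]
x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))   # -> [1, 1, 20]
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
print(x_final.shape)                                                           # torch.Size([1, 1, 3, 3])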
@@ -354,11 +354,11 @@ class MultiHeadAttention(nn.Module):
         batch, heads, length, _ = x.size()
         # padd along column
         x = F.pad(
-            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+            x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
         )
         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         return x_final
 
@@ -419,7 +419,7 @@ class FFN(nn.Module):
         pad_l = self.kernel_size - 1
         pad_r = 0
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
     def _same_padding(self, x):
@@ -428,7 +428,7 @@ class FFN(nn.Module):
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
 
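The two FFN helpers differ only in how the kernel_size - 1 pad frames are split: _causal_padding puts them all on the left of the time axis so no position sees the future, while _same_padding splits them evenly. For kernel_size = 3, and assuming the convert_pad_shape sketch above:

kernel_size = 3

causal = [[0, 0], [0, 0], [kernel_size - 1, 0]]
same = [[0, 0], [0, 0], [(kernel_size - 1) // 2, kernel_size // 2]]

print(convert_pad_shape(causal))  # [2, 0, 0, 0, 0, 0] -> 2 frames on the left of the time axis
print(convert_pad_shape(same))    # [1, 1, 0, 0, 0, 0] -> 1 frame on each side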
@@ -622,7 +622,7 @@ class FFT(nn.Module):
         if g is not None:
             g = self.cond_layer(g)
 
-        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+        self_attn_mask = subsequent_mask(x_mask.size(2)).to(
             device=x.device, dtype=x.dtype
         )
         x = x * x_mask
@@ -631,7 +631,7 @@ class FFT(nn.Module):
                 x = self.cond_pre(x)
                 cond_offset = i * 2 * self.hidden_channels
                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x = commons.fused_add_tanh_sigmoid_multiply(
+                x = fused_add_tanh_sigmoid_multiply(
                     x, g_l, torch.IntTensor([self.hidden_channels])
                 )
             y = self.self_attn_layers[i](x, x, self_attn_mask)