Update attentions.py
attentions.py  +12 -12  CHANGED
@@ -3,7 +3,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from . import
+from .commons import fused_add_tanh_sigmoid_multiply, subsequent_mask,convert_pad_shape
 from .modules import LayerNorm


@@ -74,7 +74,7 @@ class Encoder(nn.Module):
                 x = self.cond_pre(x)
                 cond_offset = i * 2 * self.hidden_channels
                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x =
+                x = fused_add_tanh_sigmoid_multiply(
                     x, g_l, torch.IntTensor([self.hidden_channels])
                 )
             y = self.attn_layers[i](x, x, attn_mask)
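The conditioning branch now calls fused_add_tanh_sigmoid_multiply directly instead of through a module prefix. Its definition lives in .commons, which this diff does not show; the sketch below is a hedged stand-in inferred from the call site (a WaveNet-style gated activation that adds the two inputs and splits the sum into a tanh half and a sigmoid half along the channel axis), not necessarily the repository's exact code.

import torch

def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # n_channels arrives as a one-element IntTensor at the call site above.
    n = int(n_channels[0])
    in_act = input_a + input_b                 # add conditioning to the activations
    t_act = torch.tanh(in_act[:, :n, :])       # first n channels: tanh gate
    s_act = torch.sigmoid(in_act[:, n:, :])    # remaining channels: sigmoid gate
    return t_act * s_act                       # gated output with n channels

At this call site both x (after cond_pre) and g_l presumably carry 2 * hidden_channels channels, so the gated result comes back with hidden_channels channels.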
@@ -153,7 +153,7 @@ class Decoder(nn.Module):
         x: decoder input
         h: encoder output
         """
-        self_attn_mask =
+        self_attn_mask = subsequent_mask(x_mask.size(2)).to(
             device=x.device, dtype=x.dtype
         )
         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
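subsequent_mask is likewise imported from .commons, which is outside this diff. From its use here, a causal self-attention mask built from the time dimension of x_mask, a minimal sketch would be a lower-triangular mask with broadcastable batch and head dimensions:

import torch

def subsequent_mask(length):
    # 1 where query position i may attend to key position j (j <= i), 0 elsewhere;
    # shaped (1, 1, length, length) so it broadcasts over batch and heads.
    return torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)

# subsequent_mask(3)[0, 0] ->
# tensor([[1., 0., 0.],
#         [1., 1., 0.],
#         [1., 1., 1.]])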
@@ -316,7 +316,7 @@ class MultiHeadAttention(nn.Module):
         if pad_length > 0:
             padded_relative_embeddings = F.pad(
                 relative_embeddings,
-
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
             )
         else:
             padded_relative_embeddings = relative_embeddings
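convert_pad_shape is the third helper imported at the top of the file; its body is also not part of this diff. Judging by how it is called, it translates a per-dimension [[before, after], ...] list into the flat, last-dimension-first sequence that F.pad expects. A plausible sketch, plus the mapping for the call in this hunk:

def convert_pad_shape(pad_shape):
    # Reverse the per-dimension pairs (F.pad pads the last dimension first),
    # then flatten the nested list into the flat form F.pad takes.
    return [amount for pair in pad_shape[::-1] for amount in pair]

# The call above pads only the middle (relative-position) dimension:
# convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
#   -> [0, 0, pad_length, pad_length, 0, 0]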
@@ -332,12 +332,12 @@ class MultiHeadAttention(nn.Module):
         """
         batch, heads, length, _ = x.size()
         # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x,
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

         # Concat extra elements so to add up to shape (len+1, 2*len-1).
         x_flat = x.view([batch, heads, length * 2 * length])
         x_flat = F.pad(
-            x_flat,
+            x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
         )

         # Reshape and slice out the padded elements.
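For orientation, _relative_position_to_absolute_position takes relative-position scores of shape (batch, heads, length, 2*length - 1) and, through the pad/flatten/pad steps in this hunk plus a reshape-and-slice that falls outside it, realigns them into absolute-position scores of shape (batch, heads, length, length). A small shape check, written with F.pad's flat argument form and with the final reshape following the standard relative-attention construction (an assumption, since those lines are not shown):

import torch
import torch.nn.functional as F

batch, heads, length = 2, 4, 5
x = torch.randn(batch, heads, length, 2 * length - 1)   # relative scores

x = F.pad(x, [0, 1])                      # same effect as the first pad in the hunk
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, [0, length - 1])   # same effect as the second pad in the hunk

# Reshape and slice out the padded elements (outside the hunk; assumed standard form):
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
print(x_final.shape)  # torch.Size([2, 4, 5, 5])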
@@ -354,11 +354,11 @@ class MultiHeadAttention(nn.Module):
         batch, heads, length, _ = x.size()
         # padd along column
         x = F.pad(
-            x,
+            x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
         )
         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat,
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         return x_final

@@ -419,7 +419,7 @@ class FFN(nn.Module):
         pad_l = self.kernel_size - 1
         pad_r = 0
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x,
+        x = F.pad(x, convert_pad_shape(padding))
         return x

     def _same_padding(self, x):
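In _causal_padding the whole pad budget of kernel_size - 1 goes on the left of the time axis, so a following Conv1d with that kernel size keeps the sequence length unchanged while never reading future frames. A quick sketch with hypothetical sizes, using the same flat-list equivalence as above:

import torch
import torch.nn.functional as F

kernel_size = 3
x = torch.randn(1, 8, 10)             # (batch, channels, time)

pad_l, pad_r = kernel_size - 1, 0     # causal: pad on the left only
x_causal = F.pad(x, [pad_l, pad_r])   # matches convert_pad_shape([[0, 0], [0, 0], [pad_l, pad_r]])
y = torch.nn.Conv1d(8, 8, kernel_size)(x_causal)
print(y.shape)  # torch.Size([1, 8, 10]) - length preserved, output at t sees only inputs up to t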
@@ -428,7 +428,7 @@ class FFN(nn.Module):
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x,
+        x = F.pad(x, convert_pad_shape(padding))
         return x


@@ -622,7 +622,7 @@ class FFT(nn.Module):
         if g is not None:
             g = self.cond_layer(g)

-        self_attn_mask =
+        self_attn_mask = subsequent_mask(x_mask.size(2)).to(
             device=x.device, dtype=x.dtype
         )
         x = x * x_mask
@@ -631,7 +631,7 @@ class FFT(nn.Module):
                 x = self.cond_pre(x)
                 cond_offset = i * 2 * self.hidden_channels
                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x =
+                x = fused_add_tanh_sigmoid_multiply(
                     x, g_l, torch.IntTensor([self.hidden_channels])
                 )
             y = self.self_attn_layers[i](x, x, self_attn_mask)