xcczach committed (verified)
Commit: cec8e53 · Parent(s): a858803

Update attentions.py

Files changed (1):
  1. attentions.py (+12 -12)
attentions.py CHANGED
@@ -3,7 +3,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from . import commons
+from .commons import fused_add_tanh_sigmoid_multiply, subsequent_mask,convert_pad_shape
 from .modules import LayerNorm
 
 
@@ -74,7 +74,7 @@ class Encoder(nn.Module):
                 x = self.cond_pre(x)
                 cond_offset = i * 2 * self.hidden_channels
                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x = commons.fused_add_tanh_sigmoid_multiply(
+                x = fused_add_tanh_sigmoid_multiply(
                     x, g_l, torch.IntTensor([self.hidden_channels])
                 )
             y = self.attn_layers[i](x, x, attn_mask)
@@ -153,7 +153,7 @@ class Decoder(nn.Module):
         x: decoder input
         h: encoder output
         """
-        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+        self_attn_mask = subsequent_mask(x_mask.size(2)).to(
             device=x.device, dtype=x.dtype
         )
         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
@@ -316,7 +316,7 @@ class MultiHeadAttention(nn.Module):
         if pad_length > 0:
             padded_relative_embeddings = F.pad(
                 relative_embeddings,
-                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
             )
         else:
             padded_relative_embeddings = relative_embeddings
@@ -332,12 +332,12 @@ class MultiHeadAttention(nn.Module):
         """
         batch, heads, length, _ = x.size()
         # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
 
         # Concat extra elements so to add up to shape (len+1, 2*len-1).
         x_flat = x.view([batch, heads, length * 2 * length])
         x_flat = F.pad(
-            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+            x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
         )
 
         # Reshape and slice out the padded elements.
@@ -354,11 +354,11 @@ class MultiHeadAttention(nn.Module):
         batch, heads, length, _ = x.size()
         # padd along column
         x = F.pad(
-            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+            x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
         )
         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         return x_final
 
@@ -419,7 +419,7 @@ class FFN(nn.Module):
         pad_l = self.kernel_size - 1
         pad_r = 0
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
     def _same_padding(self, x):
@@ -428,7 +428,7 @@ class FFN(nn.Module):
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
 
@@ -622,7 +622,7 @@ class FFT(nn.Module):
         if g is not None:
             g = self.cond_layer(g)
 
-        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+        self_attn_mask = subsequent_mask(x_mask.size(2)).to(
             device=x.device, dtype=x.dtype
         )
         x = x * x_mask
@@ -631,7 +631,7 @@ class FFT(nn.Module):
                 x = self.cond_pre(x)
                 cond_offset = i * 2 * self.hidden_channels
                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x = commons.fused_add_tanh_sigmoid_multiply(
+                x = fused_add_tanh_sigmoid_multiply(
                     x, g_l, torch.IntTensor([self.hidden_channels])
                 )
             y = self.self_attn_layers[i](x, x, self_attn_mask)
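
Note on the imported helpers: the commit replaces the module-qualified commons.* calls with direct imports of fused_add_tanh_sigmoid_multiply, subsequent_mask, and convert_pad_shape from .commons. For context, these utilities are typically defined in a VITS-style commons.py roughly as in the sketch below. This is a hedged reconstruction of the usual upstream implementations, not a verbatim copy of this repository's commons module:

# Sketch of typical VITS-style commons helpers (assumed; not this repo's exact file).
import torch

def convert_pad_shape(pad_shape):
    # Reverse the per-dimension [left, right] pairs and flatten them into
    # the flat list layout expected by torch.nn.functional.pad.
    layer = pad_shape[::-1]
    return [item for sublist in layer for item in sublist]

def subsequent_mask(length):
    # Lower-triangular causal mask of shape (1, 1, length, length).
    return torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)

def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # WaveNet-style gated activation: add the conditioning tensor, then apply
    # tanh to the first n_channels channels and sigmoid to the rest, and
    # multiply the two halves.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    return t_act * s_act

For example, convert_pad_shape([[0, 0], [0, 0], [pad_l, pad_r]]) returns [pad_l, pad_r, 0, 0, 0, 0], which F.pad interprets as padding only the last dimension by (pad_l, pad_r), matching how the FFN padding helpers in the diff above use it.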