xcczach commited on
Commit
adb8833
·
verified ·
1 Parent(s): 6bc80e7

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +6 -5
models.py CHANGED
@@ -5,7 +5,8 @@ from torch.nn import functional as F
5
 
6
  from . import commons
7
  from . import modules
8
- from . import attentions
 
9
 
10
  from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
@@ -194,7 +195,7 @@ class TextEncoder(nn.Module):
194
 
195
  self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
196
 
197
- self.encoder_ssl = attentions.Encoder(
198
  hidden_channels,
199
  filter_channels,
200
  n_heads,
@@ -203,14 +204,14 @@ class TextEncoder(nn.Module):
203
  p_dropout,
204
  )
205
 
206
- self.encoder_text = attentions.Encoder(
207
  hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
208
  )
209
  self.text_embedding = nn.Embedding(len(symbols.symbols), hidden_channels)
210
 
211
  self.mrte = MRTE()
212
 
213
- self.encoder2 = attentions.Encoder(
214
  hidden_channels,
215
  filter_channels,
216
  n_heads,
@@ -757,7 +758,7 @@ class CodePredictor(nn.Module):
757
  ssl_dim, style_vector_dim=hidden_channels
758
  )
759
 
760
- self.encoder = attentions.Encoder(
761
  hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
762
  )
763
 
 
5
 
6
  from . import commons
7
  from . import modules
8
+ from .modules import Log
9
+ from .attentions import Encoder
10
 
11
  from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 
195
 
196
  self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
197
 
198
+ self.encoder_ssl = Encoder(
199
  hidden_channels,
200
  filter_channels,
201
  n_heads,
 
204
  p_dropout,
205
  )
206
 
207
+ self.encoder_text = Encoder(
208
  hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
209
  )
210
  self.text_embedding = nn.Embedding(len(symbols.symbols), hidden_channels)
211
 
212
  self.mrte = MRTE()
213
 
214
+ self.encoder2 = Encoder(
215
  hidden_channels,
216
  filter_channels,
217
  n_heads,
 
758
  ssl_dim, style_vector_dim=hidden_channels
759
  )
760
 
761
+ self.encoder = Encoder(
762
  hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
763
  )
764