Update models.py
Browse files
models.py
CHANGED
@@ -5,7 +5,8 @@ from torch.nn import functional as F
|
|
5 |
|
6 |
from . import commons
|
7 |
from . import modules
|
8 |
-
from . import
|
|
|
9 |
|
10 |
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
11 |
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
@@ -194,7 +195,7 @@ class TextEncoder(nn.Module):
|
|
194 |
|
195 |
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
|
196 |
|
197 |
-
self.encoder_ssl =
|
198 |
hidden_channels,
|
199 |
filter_channels,
|
200 |
n_heads,
|
@@ -203,14 +204,14 @@ class TextEncoder(nn.Module):
|
|
203 |
p_dropout,
|
204 |
)
|
205 |
|
206 |
-
self.encoder_text =
|
207 |
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
208 |
)
|
209 |
self.text_embedding = nn.Embedding(len(symbols.symbols), hidden_channels)
|
210 |
|
211 |
self.mrte = MRTE()
|
212 |
|
213 |
-
self.encoder2 =
|
214 |
hidden_channels,
|
215 |
filter_channels,
|
216 |
n_heads,
|
@@ -757,7 +758,7 @@ class CodePredictor(nn.Module):
|
|
757 |
ssl_dim, style_vector_dim=hidden_channels
|
758 |
)
|
759 |
|
760 |
-
self.encoder =
|
761 |
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
762 |
)
|
763 |
|
|
|
5 |
|
6 |
from . import commons
|
7 |
from . import modules
|
8 |
+
from .modules import Log
|
9 |
+
from .attentions import Encoder
|
10 |
|
11 |
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
12 |
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
|
|
195 |
|
196 |
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
|
197 |
|
198 |
+
self.encoder_ssl = Encoder(
|
199 |
hidden_channels,
|
200 |
filter_channels,
|
201 |
n_heads,
|
|
|
204 |
p_dropout,
|
205 |
)
|
206 |
|
207 |
+
self.encoder_text = Encoder(
|
208 |
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
209 |
)
|
210 |
self.text_embedding = nn.Embedding(len(symbols.symbols), hidden_channels)
|
211 |
|
212 |
self.mrte = MRTE()
|
213 |
|
214 |
+
self.encoder2 = Encoder(
|
215 |
hidden_channels,
|
216 |
filter_channels,
|
217 |
n_heads,
|
|
|
758 |
ssl_dim, style_vector_dim=hidden_channels
|
759 |
)
|
760 |
|
761 |
+
self.encoder = Encoder(
|
762 |
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
763 |
)
|
764 |
|