SFMTS committed
Commit 497b0f7 · verified · 1 Parent(s): a07122f

Upload 9 files

DUR_0.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9ce46bc0fdc9188d555f17432387a1336aec7511e59a5b50d1e98de1d6c2c09d
+ size 1124100
D_0.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0150bb3e70144be31a4faa57a7d2e80ed6427e5cb6c30dc39b23debf53c6fcf2
+ size 187270328
G_0.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:83ee428db6803c667c067f8d98c0db42c4c3e6711fa5ce584789eb84793c738e
+ size 116087820
WD_0.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b7279846e957cfd392fcd4ebabf0668ced2cd9526b4d811308b4629130858ce
+ size 4695736
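
The four .safetensors entries above are Git LFS pointer files: the diff records only the LFS spec version, the object's SHA-256, and its size in bytes, while the actual weight blobs live in LFS storage. A minimal sketch for fetching one of the real files, assuming huggingface_hub is installed and with the repository id left as a placeholder:

# Sketch: download the real G_0.safetensors blob referenced by the LFS pointer above.
# "user/repo-name" is a placeholder; substitute the actual model repository id.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="user/repo-name", filename="G_0.safetensors")
print(ckpt_path)  # local cache path of the ~116 MB generator checkpoint
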
config.json ADDED
@@ -0,0 +1,114 @@
+ {
+   "model_name": "v1",
+   "train": {
+     "log_interval": 200,
+     "eval_interval": 1000,
+     "seed": 42,
+     "epochs": 10000,
+     "learning_rate": 0.0001,
+     "betas": [
+       0.8,
+       0.99
+     ],
+     "eps": 1e-09,
+     "batch_size": 1,
+     "bf16_run": false,
+     "fp16_run": false,
+     "lr_decay": 0.99996,
+     "segment_size": 16384,
+     "init_lr_ratio": 1,
+     "warmup_epochs": 0,
+     "c_mel": 45,
+     "c_kl": 1.0,
+     "c_commit": 100,
+     "skip_optimizer": false,
+     "freeze_ZH_bert": false,
+     "freeze_JP_bert": false,
+     "freeze_EN_bert": false,
+     "freeze_emo": false,
+     "freeze_style": false,
+     "freeze_decoder": false
+   },
+   "data": {
+     "use_jp_extra": false,
+     "training_files": "Data\\v1\\train.list",
+     "validation_files": "Data\\v1\\val.list",
+     "max_wav_value": 32768.0,
+     "sampling_rate": 44100,
+     "filter_length": 2048,
+     "hop_length": 512,
+     "win_length": 1024,
+     "n_mel_channels": 64,
+     "mel_fmin": 0.0,
+     "mel_fmax": null,
+     "add_blank": true,
+     "n_speakers": 1,
+     "cleaned_text": true,
+     "spk2id": {
+       "test": 0
+     }
+   },
+   "model": {
+     "use_spk_conditioned_encoder": true,
+     "use_noise_scaled_mas": true,
+     "use_mel_posterior_encoder": true,
+     "use_duration_discriminator": true,
+     "use_wavlm_discriminator": true,
+     "inter_channels": 128,
+     "hidden_channels": 128,
+     "filter_channels": 512,
+     "n_heads": 2,
+     "n_layers": 4,
+     "kernel_size": 3,
+     "p_dropout": 0.1,
+     "resblock": "1",
+     "resblock_kernel_sizes": [
+       3,
+       7,
+       11
+     ],
+     "resblock_dilation_sizes": [
+       [
+         1,
+         3,
+         5
+       ],
+       [
+         1,
+         3,
+         5
+       ],
+       [
+         1,
+         3,
+         5
+       ]
+     ],
+     "upsample_rates": [
+       8,
+       8,
+       2,
+       2,
+       2
+     ],
+     "upsample_initial_channel": 256,
+     "upsample_kernel_sizes": [
+       16,
+       16,
+       8,
+       2,
+       2
+     ],
+     "n_layers_q": 3,
+     "use_spectral_norm": true,
+     "gin_channels": 256,
+     "slm": {
+       "model": "./slm/wavlm-base-plus",
+       "sr": 16000,
+       "hidden": 768,
+       "nlayers": 13,
+       "initial_channel": 64
+     }
+   },
+   "version": "2.6.1"
+ }
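
config.json above is plain JSON, so it can be inspected directly. One useful sanity check (a sketch, assuming the file sits in the working directory): the decoder's total upsampling factor should equal data.hop_length, so that each spectrogram frame expands to exactly one hop of waveform samples.

import json
import math

# Load the uploaded training/model configuration.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# 8 * 8 * 2 * 2 * 2 == 512 == data.hop_length, so frames and samples line up.
assert math.prod(cfg["model"]["upsample_rates"]) == cfg["data"]["hop_length"]
print(cfg["model_name"], cfg["version"], cfg["data"]["sampling_rate"])  # v1 2.6.1 44100
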
models/models.py ADDED
@@ -0,0 +1,1189 @@
1
+ import math
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
9
+
10
+ from style_bert_vits2.models import attentions, commons, modules, monotonic_alignment
11
+ from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS
12
+
13
+
14
+ class DurationDiscriminator(nn.Module): # vits2
15
+ def __init__(
16
+ self,
17
+ in_channels: int,
18
+ filter_channels: int,
19
+ kernel_size: int,
20
+ p_dropout: float,
21
+ gin_channels: int = 0,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ self.in_channels = in_channels
26
+ self.filter_channels = filter_channels
27
+ self.kernel_size = kernel_size
28
+ self.p_dropout = p_dropout
29
+ self.gin_channels = gin_channels
30
+
31
+ self.drop = nn.Dropout(p_dropout)
32
+ self.conv_1 = nn.Conv1d(
33
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
34
+ )
35
+ self.norm_1 = modules.LayerNorm(filter_channels)
36
+ self.conv_2 = nn.Conv1d(
37
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
38
+ )
39
+ self.norm_2 = modules.LayerNorm(filter_channels)
40
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
41
+
42
+ self.pre_out_conv_1 = nn.Conv1d(
43
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
44
+ )
45
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
46
+ self.pre_out_conv_2 = nn.Conv1d(
47
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
48
+ )
49
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
50
+
51
+ if gin_channels != 0:
52
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
53
+
54
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
55
+
56
+ def forward_probability(
57
+ self,
58
+ x: torch.Tensor,
59
+ x_mask: torch.Tensor,
60
+ dur: torch.Tensor,
61
+ g: Optional[torch.Tensor] = None,
62
+ ) -> torch.Tensor:
63
+ dur = self.dur_proj(dur)
64
+ x = torch.cat([x, dur], dim=1)
65
+ x = self.pre_out_conv_1(x * x_mask)
66
+ x = torch.relu(x)
67
+ x = self.pre_out_norm_1(x)
68
+ x = self.drop(x)
69
+ x = self.pre_out_conv_2(x * x_mask)
70
+ x = torch.relu(x)
71
+ x = self.pre_out_norm_2(x)
72
+ x = self.drop(x)
73
+ x = x * x_mask
74
+ x = x.transpose(1, 2)
75
+ output_prob = self.output_layer(x)
76
+ return output_prob
77
+
78
+ def forward(
79
+ self,
80
+ x: torch.Tensor,
81
+ x_mask: torch.Tensor,
82
+ dur_r: torch.Tensor,
83
+ dur_hat: torch.Tensor,
84
+ g: Optional[torch.Tensor] = None,
85
+ ) -> list[torch.Tensor]:
86
+ x = torch.detach(x)
87
+ if g is not None:
88
+ g = torch.detach(g)
89
+ x = x + self.cond(g)
90
+ x = self.conv_1(x * x_mask)
91
+ x = torch.relu(x)
92
+ x = self.norm_1(x)
93
+ x = self.drop(x)
94
+ x = self.conv_2(x * x_mask)
95
+ x = torch.relu(x)
96
+ x = self.norm_2(x)
97
+ x = self.drop(x)
98
+
99
+ output_probs = []
100
+ for dur in [dur_r, dur_hat]:
101
+ output_prob = self.forward_probability(x, x_mask, dur, g)
102
+ output_probs.append(output_prob)
103
+
104
+ return output_probs
105
+
106
+
107
+ class TransformerCouplingBlock(nn.Module):
108
+ def __init__(
109
+ self,
110
+ channels: int,
111
+ hidden_channels: int,
112
+ filter_channels: int,
113
+ n_heads: int,
114
+ n_layers: int,
115
+ kernel_size: int,
116
+ p_dropout: float,
117
+ n_flows: int = 4,
118
+ gin_channels: int = 0,
119
+ share_parameter: bool = False,
120
+ ) -> None:
121
+ super().__init__()
122
+ self.channels = channels
123
+ self.hidden_channels = hidden_channels
124
+ self.kernel_size = kernel_size
125
+ self.n_layers = n_layers
126
+ self.n_flows = n_flows
127
+ self.gin_channels = gin_channels
128
+
129
+ self.flows = nn.ModuleList()
130
+
131
+ self.wn = (
132
+ # attentions.FFT(
133
+ # hidden_channels,
134
+ # filter_channels,
135
+ # n_heads,
136
+ # n_layers,
137
+ # kernel_size,
138
+ # p_dropout,
139
+ # isflow=True,
140
+ # gin_channels=self.gin_channels,
141
+ # )
142
+ None
143
+ if share_parameter
144
+ else None
145
+ )
146
+
147
+ for i in range(n_flows):
148
+ self.flows.append(
149
+ modules.TransformerCouplingLayer(
150
+ channels,
151
+ hidden_channels,
152
+ kernel_size,
153
+ n_layers,
154
+ n_heads,
155
+ p_dropout,
156
+ filter_channels,
157
+ mean_only=True,
158
+ wn_sharing_parameter=self.wn,
159
+ gin_channels=self.gin_channels,
160
+ )
161
+ )
162
+ self.flows.append(modules.Flip())
163
+
164
+ def forward(
165
+ self,
166
+ x: torch.Tensor,
167
+ x_mask: torch.Tensor,
168
+ g: Optional[torch.Tensor] = None,
169
+ reverse: bool = False,
170
+ ) -> torch.Tensor:
171
+ if not reverse:
172
+ for flow in self.flows:
173
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
174
+ else:
175
+ for flow in reversed(self.flows):
176
+ x = flow(x, x_mask, g=g, reverse=reverse)
177
+ return x
178
+
179
+
180
+ class StochasticDurationPredictor(nn.Module):
181
+ def __init__(
182
+ self,
183
+ in_channels: int,
184
+ filter_channels: int,
185
+ kernel_size: int,
186
+ p_dropout: float,
187
+ n_flows: int = 4,
188
+ gin_channels: int = 0,
189
+ ) -> None:
190
+ super().__init__()
191
+ filter_channels = in_channels  # this needs to be removed in a future version
192
+ self.in_channels = in_channels
193
+ self.filter_channels = filter_channels
194
+ self.kernel_size = kernel_size
195
+ self.p_dropout = p_dropout
196
+ self.n_flows = n_flows
197
+ self.gin_channels = gin_channels
198
+
199
+ self.log_flow = modules.Log()
200
+ self.flows = nn.ModuleList()
201
+ self.flows.append(modules.ElementwiseAffine(2))
202
+ for i in range(n_flows):
203
+ self.flows.append(
204
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
205
+ )
206
+ self.flows.append(modules.Flip())
207
+
208
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
209
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
210
+ self.post_convs = modules.DDSConv(
211
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
212
+ )
213
+ self.post_flows = nn.ModuleList()
214
+ self.post_flows.append(modules.ElementwiseAffine(2))
215
+ for i in range(4):
216
+ self.post_flows.append(
217
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
218
+ )
219
+ self.post_flows.append(modules.Flip())
220
+
221
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
222
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
223
+ self.convs = modules.DDSConv(
224
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
225
+ )
226
+ if gin_channels != 0:
227
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
228
+
229
+ def forward(
230
+ self,
231
+ x: torch.Tensor,
232
+ x_mask: torch.Tensor,
233
+ w: Optional[torch.Tensor] = None,
234
+ g: Optional[torch.Tensor] = None,
235
+ reverse: bool = False,
236
+ noise_scale: float = 1.0,
237
+ ) -> torch.Tensor:
238
+ x = torch.detach(x)
239
+ x = self.pre(x)
240
+ if g is not None:
241
+ g = torch.detach(g)
242
+ x = x + self.cond(g)
243
+ x = self.convs(x, x_mask)
244
+ x = self.proj(x) * x_mask
245
+
246
+ if not reverse:
247
+ flows = self.flows
248
+ assert w is not None
249
+
250
+ logdet_tot_q = 0
251
+ h_w = self.post_pre(w)
252
+ h_w = self.post_convs(h_w, x_mask)
253
+ h_w = self.post_proj(h_w) * x_mask
254
+ e_q = (
255
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
256
+ * x_mask
257
+ )
258
+ z_q = e_q
259
+ for flow in self.post_flows:
260
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
261
+ logdet_tot_q += logdet_q
262
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
263
+ u = torch.sigmoid(z_u) * x_mask
264
+ z0 = (w - u) * x_mask
265
+ logdet_tot_q += torch.sum(
266
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
267
+ )
268
+ logq = (
269
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
270
+ - logdet_tot_q
271
+ )
272
+
273
+ logdet_tot = 0
274
+ z0, logdet = self.log_flow(z0, x_mask)
275
+ logdet_tot += logdet
276
+ z = torch.cat([z0, z1], 1)
277
+ for flow in flows:
278
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
279
+ logdet_tot = logdet_tot + logdet
280
+ nll = (
281
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
282
+ - logdet_tot
283
+ )
284
+ return nll + logq # [b]
285
+ else:
286
+ flows = list(reversed(self.flows))
287
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
288
+ z = (
289
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
290
+ * noise_scale
291
+ )
292
+ for flow in flows:
293
+ z = flow(z, x_mask, g=x, reverse=reverse)
294
+ z0, z1 = torch.split(z, [1, 1], 1)
295
+ logw = z0
296
+ return logw
297
+
298
+
299
+ class DurationPredictor(nn.Module):
300
+ def __init__(
301
+ self,
302
+ in_channels: int,
303
+ filter_channels: int,
304
+ kernel_size: int,
305
+ p_dropout: float,
306
+ gin_channels: int = 0,
307
+ ) -> None:
308
+ super().__init__()
309
+
310
+ self.in_channels = in_channels
311
+ self.filter_channels = filter_channels
312
+ self.kernel_size = kernel_size
313
+ self.p_dropout = p_dropout
314
+ self.gin_channels = gin_channels
315
+
316
+ self.drop = nn.Dropout(p_dropout)
317
+ self.conv_1 = nn.Conv1d(
318
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
319
+ )
320
+ self.norm_1 = modules.LayerNorm(filter_channels)
321
+ self.conv_2 = nn.Conv1d(
322
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
323
+ )
324
+ self.norm_2 = modules.LayerNorm(filter_channels)
325
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
326
+
327
+ if gin_channels != 0:
328
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
329
+
330
+ def forward(
331
+ self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
332
+ ) -> torch.Tensor:
333
+ x = torch.detach(x)
334
+ if g is not None:
335
+ g = torch.detach(g)
336
+ x = x + self.cond(g)
337
+ x = self.conv_1(x * x_mask)
338
+ x = torch.relu(x)
339
+ x = self.norm_1(x)
340
+ x = self.drop(x)
341
+ x = self.conv_2(x * x_mask)
342
+ x = torch.relu(x)
343
+ x = self.norm_2(x)
344
+ x = self.drop(x)
345
+ x = self.proj(x * x_mask)
346
+ return x * x_mask
347
+
348
+
349
+ class Bottleneck(nn.Sequential):
350
+ def __init__(self, in_dim: int, hidden_dim: int) -> None:
351
+ c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
352
+ c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
353
+ super().__init__(c_fc1, c_fc2)
354
+
355
+
356
+ class Block(nn.Module):
357
+ def __init__(self, in_dim: int, hidden_dim: int) -> None:
358
+ super().__init__()
359
+ self.norm = nn.LayerNorm(in_dim)
360
+ self.mlp = MLP(in_dim, hidden_dim)
361
+
362
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
363
+ x = x + self.mlp(self.norm(x))
364
+ return x
365
+
366
+
367
+ class MLP(nn.Module):
368
+ def __init__(self, in_dim: int, hidden_dim: int) -> None:
369
+ super().__init__()
370
+ self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
371
+ self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
372
+ self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
373
+
374
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
375
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
376
+ x = self.c_proj(x)
377
+ return x
378
+
379
+ class TextEncoder(nn.Module):
380
+ def __init__(
381
+ self,
382
+ n_vocab: int,
383
+ out_channels: int,
384
+ hidden_channels: int,
385
+ filter_channels: int,
386
+ n_heads: int,
387
+ n_layers: int,
388
+ kernel_size: int,
389
+ p_dropout: float,
390
+ n_speakers: int,
391
+ gin_channels: int = 0,
392
+ ) -> None:
393
+ super().__init__()
394
+ self.n_vocab = n_vocab
395
+ self.out_channels = out_channels
396
+ self.hidden_channels = hidden_channels
397
+ self.filter_channels = filter_channels
398
+ self.n_heads = n_heads
399
+ self.n_layers = n_layers
400
+ self.kernel_size = kernel_size
401
+ self.p_dropout = p_dropout
402
+ self.gin_channels = gin_channels
403
+ self.emb = nn.Embedding(len(SYMBOLS), hidden_channels)
404
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
405
+ self.tone_emb = nn.Embedding(NUM_TONES, hidden_channels)
406
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
407
+ self.language_emb = nn.Embedding(NUM_LANGUAGES, hidden_channels)
408
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
409
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
410
+ self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
411
+ self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
412
+ self.style_proj = nn.Linear(256, hidden_channels)
413
+
414
+ self.encoder = attentions.Encoder(
415
+ hidden_channels,
416
+ filter_channels,
417
+ n_heads,
418
+ n_layers,
419
+ kernel_size,
420
+ p_dropout,
421
+ gin_channels=self.gin_channels,
422
+ )
423
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
424
+
425
+ def forward(
426
+ self,
427
+ x: torch.Tensor,
428
+ x_lengths: torch.Tensor,
429
+ tone: torch.Tensor,
430
+ language: torch.Tensor,
431
+ bert: torch.Tensor,
432
+ ja_bert: torch.Tensor,
433
+ en_bert: torch.Tensor,
434
+ style_vec: torch.Tensor,
435
+ sid: torch.Tensor,
436
+ g: Optional[torch.Tensor] = None,
437
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
438
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
439
+ ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
440
+ en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
441
+ style_emb = self.style_proj(style_vec.unsqueeze(1))
442
+
443
+ x = (
444
+ self.emb(x)
445
+ + self.tone_emb(tone)
446
+ + self.language_emb(language)
447
+ + bert_emb
448
+ + ja_bert_emb
449
+ + en_bert_emb
450
+ + style_emb
451
+ ) * math.sqrt(
452
+ self.hidden_channels
453
+ ) # [b, t, h]
454
+ x = torch.transpose(x, 1, -1) # [b, h, t]
455
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
456
+ x.dtype
457
+ )
458
+
459
+ x = self.encoder(x * x_mask, x_mask, g=g)
460
+ stats = self.proj(x) * x_mask
461
+
462
+ m, logs = torch.split(stats, self.out_channels, dim=1)
463
+ return x, m, logs, x_mask
464
+
465
+
466
+ class ResidualCouplingBlock(nn.Module):
467
+ def __init__(
468
+ self,
469
+ channels: int,
470
+ hidden_channels: int,
471
+ kernel_size: int,
472
+ dilation_rate: int,
473
+ n_layers: int,
474
+ n_flows: int = 4,
475
+ gin_channels: int = 0,
476
+ ) -> None:
477
+ super().__init__()
478
+ self.channels = channels
479
+ self.hidden_channels = hidden_channels
480
+ self.kernel_size = kernel_size
481
+ self.dilation_rate = dilation_rate
482
+ self.n_layers = n_layers
483
+ self.n_flows = n_flows
484
+ self.gin_channels = gin_channels
485
+
486
+ self.flows = nn.ModuleList()
487
+ for i in range(n_flows):
488
+ self.flows.append(
489
+ modules.ResidualCouplingLayer(
490
+ channels,
491
+ hidden_channels,
492
+ kernel_size,
493
+ dilation_rate,
494
+ n_layers,
495
+ gin_channels=gin_channels,
496
+ mean_only=True,
497
+ )
498
+ )
499
+ self.flows.append(modules.Flip())
500
+
501
+ def forward(
502
+ self,
503
+ x: torch.Tensor,
504
+ x_mask: torch.Tensor,
505
+ g: Optional[torch.Tensor] = None,
506
+ reverse: bool = False,
507
+ ) -> torch.Tensor:
508
+ if not reverse:
509
+ for flow in self.flows:
510
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
511
+ else:
512
+ for flow in reversed(self.flows):
513
+ x = flow(x, x_mask, g=g, reverse=reverse)
514
+ return x
515
+
516
+
517
+ class PosteriorEncoder(nn.Module):
518
+ def __init__(
519
+ self,
520
+ in_channels: int,
521
+ out_channels: int,
522
+ hidden_channels: int,
523
+ kernel_size: int,
524
+ dilation_rate: int,
525
+ n_layers: int,
526
+ gin_channels: int = 0,
527
+ ) -> None:
528
+ super().__init__()
529
+ self.in_channels = in_channels
530
+ self.out_channels = out_channels
531
+ self.hidden_channels = hidden_channels
532
+ self.kernel_size = kernel_size
533
+ self.dilation_rate = dilation_rate
534
+ self.n_layers = n_layers
535
+ self.gin_channels = gin_channels
536
+
537
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
538
+ self.enc = modules.WN(
539
+ hidden_channels,
540
+ kernel_size,
541
+ dilation_rate,
542
+ n_layers,
543
+ gin_channels=gin_channels,
544
+ )
545
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
546
+
547
+ def forward(
548
+ self,
549
+ x: torch.Tensor,
550
+ x_lengths: torch.Tensor,
551
+ g: Optional[torch.Tensor] = None,
552
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
553
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
554
+ x.dtype
555
+ )
556
+ x = self.pre(x) * x_mask
557
+ x = self.enc(x, x_mask, g=g)
558
+ stats = self.proj(x) * x_mask
559
+ m, logs = torch.split(stats, self.out_channels, dim=1)
560
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
561
+ return z, m, logs, x_mask
562
+
563
+
564
+ class Generator(torch.nn.Module):
565
+ def __init__(
566
+ self,
567
+ initial_channel: int,
568
+ resblock_str: str,
569
+ resblock_kernel_sizes: list[int],
570
+ resblock_dilation_sizes: list[list[int]],
571
+ upsample_rates: list[int],
572
+ upsample_initial_channel: int,
573
+ upsample_kernel_sizes: list[int],
574
+ gin_channels: int = 0,
575
+ ) -> None:
576
+ super(Generator, self).__init__()
577
+ self.num_kernels = len(resblock_kernel_sizes)
578
+ self.num_upsamples = len(upsample_rates)
579
+ self.conv_pre = Conv1d(
580
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
581
+ )
582
+ resblock = modules.ResBlock1 if resblock_str == "1" else modules.ResBlock2
583
+
584
+ self.ups = nn.ModuleList()
585
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
586
+ self.ups.append(
587
+ weight_norm(
588
+ ConvTranspose1d(
589
+ upsample_initial_channel // (2**i),
590
+ upsample_initial_channel // (2 ** (i + 1)),
591
+ k,
592
+ u,
593
+ padding=(k - u) // 2,
594
+ )
595
+ )
596
+ )
597
+
598
+ self.resblocks = nn.ModuleList()
599
+ ch = None
600
+ for i in range(len(self.ups)):
601
+ ch = upsample_initial_channel // (2 ** (i + 1))
602
+ for j, (k, d) in enumerate(
603
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
604
+ ):
605
+ self.resblocks.append(resblock(ch, k, d)) # type: ignore
606
+
607
+ assert ch is not None
608
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
609
+ self.ups.apply(commons.init_weights)
610
+
611
+ if gin_channels != 0:
612
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
613
+
614
+ def forward(
615
+ self, x: torch.Tensor, g: Optional[torch.Tensor] = None
616
+ ) -> torch.Tensor:
617
+ x = self.conv_pre(x)
618
+ if g is not None:
619
+ x = x + self.cond(g)
620
+
621
+ for i in range(self.num_upsamples):
622
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
623
+ x = self.ups[i](x)
624
+ xs = None
625
+ for j in range(self.num_kernels):
626
+ if xs is None:
627
+ xs = self.resblocks[i * self.num_kernels + j](x)
628
+ else:
629
+ xs += self.resblocks[i * self.num_kernels + j](x)
630
+ assert xs is not None
631
+ x = xs / self.num_kernels
632
+ x = F.leaky_relu(x)
633
+ x = self.conv_post(x)
634
+ x = torch.tanh(x)
635
+
636
+ return x
637
+
638
+ def remove_weight_norm(self) -> None:
639
+ print("Removing weight norm...")
640
+ for layer in self.ups:
641
+ remove_weight_norm(layer)
642
+ for layer in self.resblocks:
643
+ layer.remove_weight_norm()
644
+
645
+
646
+ class DiscriminatorP(torch.nn.Module):
647
+ def __init__(
648
+ self,
649
+ period: int,
650
+ kernel_size: int = 5,
651
+ stride: int = 3,
652
+ use_spectral_norm: bool = False,
653
+ ) -> None:
654
+ super(DiscriminatorP, self).__init__()
655
+ self.period = period
656
+ self.use_spectral_norm = use_spectral_norm
657
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
658
+ self.convs = nn.ModuleList(
659
+ [
660
+ norm_f(
661
+ Conv2d(
662
+ 1,
663
+ 32,
664
+ (kernel_size, 1),
665
+ (stride, 1),
666
+ padding=(commons.get_padding(kernel_size, 1), 0),
667
+ )
668
+ ),
669
+ norm_f(
670
+ Conv2d(
671
+ 32,
672
+ 128,
673
+ (kernel_size, 1),
674
+ (stride, 1),
675
+ padding=(commons.get_padding(kernel_size, 1), 0),
676
+ )
677
+ ),
678
+ norm_f(
679
+ Conv2d(
680
+ 128,
681
+ 512,
682
+ (kernel_size, 1),
683
+ (stride, 1),
684
+ padding=(commons.get_padding(kernel_size, 1), 0),
685
+ )
686
+ ),
687
+ norm_f(
688
+ Conv2d(
689
+ 512,
690
+ 1024,
691
+ (kernel_size, 1),
692
+ (stride, 1),
693
+ padding=(commons.get_padding(kernel_size, 1), 0),
694
+ )
695
+ ),
696
+ norm_f(
697
+ Conv2d(
698
+ 1024,
699
+ 1024,
700
+ (kernel_size, 1),
701
+ 1,
702
+ padding=(commons.get_padding(kernel_size, 1), 0),
703
+ )
704
+ ),
705
+ ]
706
+ )
707
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
708
+
709
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
710
+ fmap = []
711
+
712
+ # 1d to 2d
713
+ b, c, t = x.shape
714
+ if t % self.period != 0: # pad first
715
+ n_pad = self.period - (t % self.period)
716
+ x = F.pad(x, (0, n_pad), "reflect")
717
+ t = t + n_pad
718
+ x = x.view(b, c, t // self.period, self.period)
719
+
720
+ for layer in self.convs:
721
+ x = layer(x)
722
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
723
+ fmap.append(x)
724
+ x = self.conv_post(x)
725
+ fmap.append(x)
726
+ x = torch.flatten(x, 1, -1)
727
+
728
+ return x, fmap
729
+
730
+
731
+ class DiscriminatorS(torch.nn.Module):
732
+ def __init__(self, use_spectral_norm: bool = False) -> None:
733
+ super(DiscriminatorS, self).__init__()
734
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
735
+ self.convs = nn.ModuleList(
736
+ [
737
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
738
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
739
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
740
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
741
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
742
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
743
+ ]
744
+ )
745
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
746
+
747
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
748
+ fmap = []
749
+
750
+ for layer in self.convs:
751
+ x = layer(x)
752
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
753
+ fmap.append(x)
754
+ x = self.conv_post(x)
755
+ fmap.append(x)
756
+ x = torch.flatten(x, 1, -1)
757
+
758
+ return x, fmap
759
+
760
+
761
+ class MultiPeriodDiscriminator(torch.nn.Module):
762
+ def __init__(self, use_spectral_norm: bool = False) -> None:
763
+ super(MultiPeriodDiscriminator, self).__init__()
764
+ periods = [2, 3, 5, 7, 11]
765
+
766
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
767
+ discs = discs + [
768
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
769
+ ]
770
+ self.discriminators = nn.ModuleList(discs)
771
+
772
+ def forward(
773
+ self,
774
+ y: torch.Tensor,
775
+ y_hat: torch.Tensor,
776
+ ) -> tuple[
777
+ list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]
778
+ ]:
779
+ y_d_rs = []
780
+ y_d_gs = []
781
+ fmap_rs = []
782
+ fmap_gs = []
783
+ for i, d in enumerate(self.discriminators):
784
+ y_d_r, fmap_r = d(y)
785
+ y_d_g, fmap_g = d(y_hat)
786
+ y_d_rs.append(y_d_r)
787
+ y_d_gs.append(y_d_g)
788
+ fmap_rs.append(fmap_r)
789
+ fmap_gs.append(fmap_g)
790
+
791
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
792
+
793
+
794
+ class WavLMDiscriminator(nn.Module):
795
+ """docstring for Discriminator."""
796
+
797
+ def __init__(
798
+ self,
799
+ slm_hidden: int = 768,
800
+ slm_layers: int = 13,
801
+ initial_channel: int = 64,
802
+ use_spectral_norm: bool = False,
803
+ ) -> None:
804
+ super(WavLMDiscriminator, self).__init__()
805
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
806
+ self.pre = norm_f(
807
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
808
+ )
809
+
810
+ self.convs = nn.ModuleList(
811
+ [
812
+ norm_f(
813
+ nn.Conv1d(
814
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
815
+ )
816
+ ),
817
+ norm_f(
818
+ nn.Conv1d(
819
+ initial_channel * 2,
820
+ initial_channel * 4,
821
+ kernel_size=5,
822
+ padding=2,
823
+ )
824
+ ),
825
+ norm_f(
826
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
827
+ ),
828
+ ]
829
+ )
830
+
831
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
832
+
833
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
834
+ x = self.pre(x)
835
+
836
+ fmap = []
837
+ for l in self.convs:
838
+ x = l(x)
839
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
840
+ fmap.append(x)
841
+ x = self.conv_post(x)
842
+ x = torch.flatten(x, 1, -1)
843
+
844
+ return x
845
+
846
+
847
+
848
+ class ReferenceEncoder(nn.Module):
849
+ """
850
+ inputs --- [N, Ty/r, n_mels*r] mels
851
+ outputs --- [N, ref_enc_gru_size]
852
+ """
853
+
854
+ def __init__(self, spec_channels: int, gin_channels: int = 0) -> None:
855
+ super().__init__()
856
+ self.spec_channels = spec_channels
857
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
858
+ K = len(ref_enc_filters)
859
+ filters = [1] + ref_enc_filters
860
+ convs = [
861
+ weight_norm(
862
+ nn.Conv2d(
863
+ in_channels=filters[i],
864
+ out_channels=filters[i + 1],
865
+ kernel_size=(3, 3),
866
+ stride=(2, 2),
867
+ padding=(1, 1),
868
+ )
869
+ )
870
+ for i in range(K)
871
+ ]
872
+ self.convs = nn.ModuleList(convs)
873
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
874
+
875
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
876
+ self.gru = nn.GRU(
877
+ input_size=ref_enc_filters[-1] * out_channels,
878
+ hidden_size=256 // 2,
879
+ batch_first=True,
880
+ )
881
+ self.proj = nn.Linear(128, gin_channels)
882
+
883
+ def forward(
884
+ self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
885
+ ) -> torch.Tensor:
886
+ N = inputs.size(0)
887
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
888
+ for conv in self.convs:
889
+ out = conv(out)
890
+ # out = wn(out)
891
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
892
+
893
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
894
+ T = out.size(1)
895
+ N = out.size(0)
896
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
897
+
898
+ self.gru.flatten_parameters()
899
+ memory, out = self.gru(out) # out --- [1, N, 128]
900
+
901
+ return self.proj(out.squeeze(0))
902
+
903
+ def calculate_channels(
904
+ self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
905
+ ) -> int:
906
+ for i in range(n_convs):
907
+ L = (L - kernel_size + 2 * pad) // stride + 1
908
+ return L
909
+
910
+
911
+ class SynthesizerTrn(nn.Module):
912
+ """
913
+ Synthesizer for Training
914
+ """
915
+
916
+ def __init__(
917
+ self,
918
+ n_vocab: int,
919
+ spec_channels: int,
920
+ segment_size: int,
921
+ inter_channels: int,
922
+ hidden_channels: int,
923
+ filter_channels: int,
924
+ n_heads: int,
925
+ n_layers: int,
926
+ kernel_size: int,
927
+ p_dropout: float,
928
+ resblock: str,
929
+ resblock_kernel_sizes: list[int],
930
+ resblock_dilation_sizes: list[list[int]],
931
+ upsample_rates: list[int],
932
+ upsample_initial_channel: int,
933
+ upsample_kernel_sizes: list[int],
934
+ n_speakers: int = 256,
935
+ gin_channels: int = 256,
936
+ use_sdp: bool = True,
937
+ n_flow_layer: int = 4,
938
+ n_layers_trans_flow: int = 6,
939
+ flow_share_parameter: bool = False,
940
+ use_transformer_flow: bool = True,
941
+ **kwargs: Any,
942
+ ) -> None:
943
+ super().__init__()
944
+ self.n_vocab = n_vocab
945
+ self.spec_channels = spec_channels
946
+ self.inter_channels = inter_channels
947
+ self.hidden_channels = hidden_channels
948
+ self.filter_channels = filter_channels
949
+ self.n_heads = n_heads
950
+ self.n_layers = n_layers
951
+ self.kernel_size = kernel_size
952
+ self.p_dropout = p_dropout
953
+ self.resblock = resblock
954
+ self.resblock_kernel_sizes = resblock_kernel_sizes
955
+ self.resblock_dilation_sizes = resblock_dilation_sizes
956
+ self.upsample_rates = upsample_rates
957
+ self.upsample_initial_channel = upsample_initial_channel
958
+ self.upsample_kernel_sizes = upsample_kernel_sizes
959
+ self.segment_size = segment_size
960
+ self.n_speakers = n_speakers
961
+ self.gin_channels = gin_channels
962
+ self.n_layers_trans_flow = n_layers_trans_flow
963
+ self.use_spk_conditioned_encoder = kwargs.get(
964
+ "use_spk_conditioned_encoder", True
965
+ )
966
+ self.use_sdp = use_sdp
967
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
968
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
969
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
970
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
971
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
972
+ self.enc_gin_channels = gin_channels
973
+ self.enc_p = TextEncoder(
974
+ n_vocab,
975
+ inter_channels,
976
+ hidden_channels,
977
+ filter_channels,
978
+ n_heads,
979
+ n_layers,
980
+ kernel_size,
981
+ p_dropout,
982
+ self.n_speakers,
983
+ gin_channels=self.enc_gin_channels,
984
+ )
985
+ self.dec = Generator(
986
+ inter_channels,
987
+ resblock,
988
+ resblock_kernel_sizes,
989
+ resblock_dilation_sizes,
990
+ upsample_rates,
991
+ upsample_initial_channel,
992
+ upsample_kernel_sizes,
993
+ gin_channels=gin_channels,
994
+ )
995
+ self.enc_q = PosteriorEncoder(
996
+ spec_channels,
997
+ inter_channels,
998
+ hidden_channels,
999
+ 5,
1000
+ 1,
1001
+ 16,
1002
+ gin_channels=gin_channels,
1003
+ )
1004
+ if use_transformer_flow:
1005
+ self.flow = TransformerCouplingBlock(
1006
+ inter_channels,
1007
+ hidden_channels,
1008
+ filter_channels,
1009
+ n_heads,
1010
+ n_layers_trans_flow,
1011
+ 5,
1012
+ p_dropout,
1013
+ n_flow_layer,
1014
+ gin_channels=gin_channels,
1015
+ share_parameter=flow_share_parameter,
1016
+ )
1017
+ else:
1018
+ self.flow = ResidualCouplingBlock(
1019
+ inter_channels,
1020
+ hidden_channels,
1021
+ 5,
1022
+ 1,
1023
+ n_flow_layer,
1024
+ gin_channels=gin_channels,
1025
+ )
1026
+ self.sdp = StochasticDurationPredictor(
1027
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
1028
+ )
1029
+ self.dp = DurationPredictor(
1030
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
1031
+ )
1032
+
1033
+ if n_speakers >= 1:
1034
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
1035
+ else:
1036
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
1037
+
1038
+ def forward(
1039
+ self,
1040
+ x: torch.Tensor,
1041
+ x_lengths: torch.Tensor,
1042
+ y: torch.Tensor,
1043
+ y_lengths: torch.Tensor,
1044
+ sid: torch.Tensor,
1045
+ tone: torch.Tensor,
1046
+ language: torch.Tensor,
1047
+ bert: torch.Tensor,
1048
+ ja_bert: torch.Tensor,
1049
+ en_bert: torch.Tensor,
1050
+ style_vec: torch.Tensor,
1051
+ ) -> tuple[
1052
+ torch.Tensor,
1053
+ torch.Tensor,
1054
+ torch.Tensor,
1055
+ torch.Tensor,
1056
+ torch.Tensor,
1057
+ torch.Tensor,
1058
+ torch.Tensor,
1059
+ tuple[torch.Tensor, ...],
1060
+ tuple[torch.Tensor, ...],
1061
+ ]:
1062
+ if self.n_speakers > 0:
1063
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1064
+ else:
1065
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1066
+ x, m_p, logs_p, x_mask = self.enc_p(
1067
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g
1068
+ )
1069
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
1070
+ z_p = self.flow(z, y_mask, g=g)
1071
+
1072
+ with torch.no_grad():
1073
+ # negative cross-entropy
1074
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
1075
+ neg_cent1 = torch.sum(
1076
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
1077
+ ) # [b, 1, t_s]
1078
+ neg_cent2 = torch.matmul(
1079
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
1080
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
1081
+ neg_cent3 = torch.matmul(
1082
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
1083
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
1084
+ neg_cent4 = torch.sum(
1085
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
1086
+ ) # [b, 1, t_s]
1087
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
1088
+ if self.use_noise_scaled_mas:
1089
+ epsilon = (
1090
+ torch.std(neg_cent)
1091
+ * torch.randn_like(neg_cent)
1092
+ * self.current_mas_noise_scale
1093
+ )
1094
+ neg_cent = neg_cent + epsilon
1095
+
1096
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1097
+ attn = (
1098
+ monotonic_alignment.maximum_path(neg_cent, attn_mask.squeeze(1))
1099
+ .unsqueeze(1)
1100
+ .detach()
1101
+ )
1102
+
1103
+ w = attn.sum(2)
1104
+
1105
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
1106
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
1107
+
1108
+ logw_ = torch.log(w + 1e-6) * x_mask
1109
+ logw = self.dp(x, x_mask, g=g)
1110
+ # logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
1111
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
1112
+ x_mask
1113
+ ) # for averaging
1114
+ # l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
1115
+
1116
+ l_length = l_length_dp + l_length_sdp
1117
+
1118
+ # expand prior
1119
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
1120
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
1121
+
1122
+ z_slice, ids_slice = commons.rand_slice_segments(
1123
+ z, y_lengths, self.segment_size
1124
+ )
1125
+ o = self.dec(z_slice, g=g)
1126
+ return (
1127
+ o,
1128
+ l_length,
1129
+ attn,
1130
+ ids_slice,
1131
+ x_mask,
1132
+ y_mask,
1133
+ (z, z_p, m_p, logs_p, m_q, logs_q), # type: ignore
1134
+ (x, logw, logw_), # , logw_sdp),
1135
+ g,
1136
+ )
1137
+
1138
+
1139
+ def infer(
1140
+ self,
1141
+ x: torch.Tensor,
1142
+ x_lengths: torch.Tensor,
1143
+ sid: torch.Tensor,
1144
+ tone: torch.Tensor,
1145
+ language: torch.Tensor,
1146
+ bert: torch.Tensor,
1147
+ ja_bert: torch.Tensor,
1148
+ en_bert: torch.Tensor,
1149
+ style_vec: torch.Tensor,
1150
+ noise_scale: float = 0.667,
1151
+ length_scale: float = 1.0,
1152
+ noise_scale_w: float = 0.8,
1153
+ max_len: Optional[int] = None,
1154
+ sdp_ratio: float = 0.0,
1155
+ y: Optional[torch.Tensor] = None,
1156
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...]]:
1157
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
1158
+ # g = self.gst(y)
1159
+ if self.n_speakers > 0:
1160
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1161
+ else:
1162
+ assert y is not None
1163
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1164
+ x, m_p, logs_p, x_mask = self.enc_p(
1165
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g
1166
+ )
1167
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1168
+ sdp_ratio
1169
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
1170
+ w = torch.exp(logw) * x_mask * length_scale
1171
+ w_ceil = torch.ceil(w)
1172
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1173
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1174
+ x_mask.dtype
1175
+ )
1176
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1177
+ attn = commons.generate_path(w_ceil, attn_mask)
1178
+
1179
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1180
+ 1, 2
1181
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1182
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1183
+ 1, 2
1184
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1185
+
1186
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1187
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1188
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1189
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
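
models/models.py defines the complete VITS2-style stack: the BERT- and style-conditioned TextEncoder, the coupling flows, the stochastic and deterministic duration predictors, the HiFi-GAN Generator, and the period/scale/WavLM discriminators. A hedged sketch of constructing SynthesizerTrn from the config above, following the usual VITS convention for spec_channels and the segment size in frames; the exact call in train_ms.py may differ, and the checkpoint-loading hint is commented out because the layout inside G_0.safetensors is not documented here.

import json

from style_bert_vits2.models.models import SynthesizerTrn
from style_bert_vits2.nlp.symbols import SYMBOLS

with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

net_g = SynthesizerTrn(
    len(SYMBOLS),                            # n_vocab
    cfg["data"]["filter_length"] // 2 + 1,   # spec_channels: 2048 // 2 + 1 = 1025
    cfg["train"]["segment_size"] // cfg["data"]["hop_length"],  # segment size in frames: 32
    n_speakers=cfg["data"]["n_speakers"],
    **cfg["model"],                          # unknown keys are absorbed by **kwargs
)

# If G_0.safetensors stores a plain state_dict, the weights could be loaded with:
# from safetensors.torch import load_file
# net_g.load_state_dict(load_file("G_0.safetensors"), strict=False)
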
nlp/japanese/normalizer.py ADDED
@@ -0,0 +1,176 @@
1
+ """
2
+ 記号類の正規化変換マップの; : 「 」 括弧全般の扱いを変更
3
+ 記号類の正規化変換マップに、= < > # ^ *を追加
4
+
5
+
6
+ """
7
+
8
+ import re
9
+ import unicodedata
10
+
11
+ from num2words import num2words
12
+
13
+ from style_bert_vits2.nlp.symbols import PUNCTUATIONS
14
+
15
+
16
+ def normalize_text(text: str) -> str:
17
+ """
18
+ 日本語のテキストを正規化する。
19
+ 結果は、ちょうど次の文字のみからなる:
20
+ - ひらがな
21
+ - カタカナ(全角長音記号「ー」が入る!)
22
+ - 漢字
23
+ - 半角アルファベット(大文字と小文字)
24
+ - ギリシャ文字
25
+ - `.` (句点`。`や`…`の一部や改行等)
26
+ - `,` (読点`、`や`:`等)
27
+ - `?` (疑問符`?`)
28
+ - `!` (感嘆符`!`)
29
+ - `'` (`「`や`」`等)
30
+ - `-` (`―`(ダッシュ、長音記号ではない)や`-`等)
31
+
32
+ 注意点:
33
+ - 三点リーダー`…`は`...`に変換される(`なるほど…。` → `なるほど....`)
34
+ - 数字は漢字に変換される(`1,100円` → `千百円`、`52.34` → `五十二点三四`)
35
+ - 読点や疑問符等の位置・個数等は保持される(`??あ、、!!!` → `??あ,,!!!`)
36
+
37
+ Args:
38
+ text (str): 正規化するテキスト
39
+
40
+ Returns:
41
+ str: 正規化されたテキスト
42
+ """
43
+
44
+ res = unicodedata.normalize("NFKC", text) # ここでアルファベットは半角になる
45
+ res = __convert_numbers_to_words(res) # 「100円」→「百円」等
46
+ # 「~」と「〜」と「~」も長音記号として扱う
47
+ res = res.replace("~", "ー")
48
+ res = res.replace("~", "ー")
49
+ res = res.replace("〜", "ー")
50
+
51
+ res = replace_punctuation(res) # 句読点等正規化、読めない文字を削除
52
+
53
+ # 結合文字の濁点・半濁点を削除
54
+ # 通常の「ば」等はそのままのこされる、「あ゛」は上で「あ゙」になりここで「あ」になる
55
+ res = res.replace("\u3099", "") # 結合文字の濁点を削除、る゙ → る
56
+ res = res.replace("\u309A", "") # 結合文字の半濁点を削除、な゚ → な
57
+ return res
58
+
59
+
60
+ def replace_punctuation(text: str) -> str:
61
+ """
62
+ 句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalk で読みが取得できるもののみ残す:
63
+ 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字
64
+
65
+ Args:
66
+ text (str): 正規化するテキスト
67
+
68
+ Returns:
69
+ str: 正規化されたテキスト
70
+ """
71
+
72
+ # 記号類の正規化変換マップ
73
+ REPLACE_MAP = {
74
+ ":": ":",
75
+ ";": ";",
76
+ ",": ",",
77
+ "。": ".",
78
+ "!": "!",
79
+ "?": "?",
80
+ "\n": ".",
81
+ ".": ".",
82
+ "…": "...",
83
+ "···": "...",
84
+ "・・・": "...",
85
+ "·": ",",
86
+ "・": ",",
87
+ "、": ",",
88
+ "$": ".",
89
+ "“": "'",
90
+ "”": "'",
91
+ '"': "'",
92
+ "‘": "'",
93
+ "’": "'",
94
+ "(": "(",
95
+ ")": ")",
96
+ "(": "(",
97
+ ")": ")",
98
+ "《": "(",
99
+ "》": ")",
100
+ "【": "(",
101
+ "】": ")",
102
+ "[": "(",
103
+ "]": ")",
104
+ # NFKC 正規化後のハイフン・ダッシュの変種を全て通常半角ハイフン - \u002d に変換
105
+ "\u02d7": "\u002d", # ˗, Modifier Letter Minus Sign
106
+ "\u2010": "\u002d", # ‐, Hyphen,
107
+ # "\u2011": "\u002d", # ‑, Non-Breaking Hyphen, NFKC により \u2010 に変換される
108
+ "\u2012": "\u002d", # ‒, Figure Dash
109
+ "\u2013": "\u002d", # –, En Dash
110
+ "\u2014": "\u002d", # —, Em Dash
111
+ "\u2015": "\u002d", # ―, Horizontal Bar
112
+ "\u2043": "\u002d", # ⁃, Hyphen Bullet
113
+ "\u2212": "\u002d", # −, Minus Sign
114
+ "\u23af": "\u002d", # ⎯, Horizontal Line Extension
115
+ "\u23e4": "\u002d", # ⏤, Straightness
116
+ "\u2500": "\u002d", # ─, Box Drawings Light Horizontal
117
+ "\u2501": "\u002d", # ━, Box Drawings Heavy Horizontal
118
+ "\u2e3a": "\u002d", # ⸺, Two-Em Dash
119
+ "\u2e3b": "\u002d", # ⸻, Three-Em Dash
120
+ # "~": "-", # これは長音記号「ー」として扱うよう変更
121
+ # "~": "-", # これも長音記号「ー」として扱うよう変更
122
+ "「": "'",
123
+ "」": "'",
124
+ "=": "=",
125
+ "<": "<",
126
+ ">": ">",
127
+ "#": "#",
128
+ "^": "^",
129
+ "*": "*",
130
+ }
131
+
132
+ pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys()))
133
+
134
+ # 句読点を辞書で置換
135
+ replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text)
136
+
137
+ replaced_text = re.sub(
138
+ # ↓ ひらがな、カタカナ、漢字
139
+ r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
140
+ # ↓ 半角アルファベット(大文字と小文字)
141
+ + r"\u0041-\u005A\u0061-\u007A"
142
+ # ↓ 全角アルファベット(大文字と小文字)
143
+ + r"\uFF21-\uFF3A\uFF41-\uFF5A"
144
+ # ↓ ギリシャ文字
145
+ + r"\u0370-\u03FF\u1F00-\u1FFF"
146
+ # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている
147
+ + "".join(PUNCTUATIONS) + r"]+",
148
+ # 上述以外の文字を削除
149
+ "",
150
+ replaced_text,
151
+ )
152
+
153
+ return replaced_text
154
+
155
+
156
+ def __convert_numbers_to_words(text: str) -> str:
157
+ """
158
+ 記号や数字を日本語の文字表現に変換する。
159
+
160
+ Args:
161
+ text (str): 変換するテキスト
162
+
163
+ Returns:
164
+ str: 変換されたテキスト
165
+ """
166
+
167
+ NUMBER_WITH_SEPARATOR_PATTERN = re.compile("[0-9]{1,3}(,[0-9]{3})+")
168
+ CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
169
+ CURRENCY_PATTERN = re.compile(r"([$¥£€])([0-9.]*[0-9])")
170
+ NUMBER_PATTERN = re.compile(r"[0-9]+(\.[0-9]+)?")
171
+
172
+ res = NUMBER_WITH_SEPARATOR_PATTERN.sub(lambda m: m[0].replace(",", ""), text)
173
+ res = CURRENCY_PATTERN.sub(lambda m: m[2] + CURRENCY_MAP.get(m[1], m[1]), res)
174
+ res = NUMBER_PATTERN.sub(lambda m: num2words(m[0], lang="ja"), res)
175
+
176
+ return res
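
normalize_text() documents its behaviour in its own docstring; a short usage sketch follows (the import path is assumed from the style_bert_vits2 package layout used elsewhere in this commit, and the expected outputs follow the docstring's examples):

from style_bert_vits2.nlp.japanese.normalizer import normalize_text  # assumed import path

print(normalize_text("なるほど…。"))   # -> なるほど....   (ellipsis and 。 become ASCII dots)
print(normalize_text("1,100円です!"))  # -> 千百円です!    (digits to kanji, full-width ! to !)
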
nlp/symbols.py ADDED
@@ -0,0 +1,259 @@
1
+ """
2
+ PUNCTUATIONSに ":", ";", "=", "#", "<", ">", "^", "(", ")", "*"を追加
3
+
4
+ """
5
+
6
+ # Punctuations
7
+ PUNCTUATIONS = ["!", "?", "…", ",", ".", "'", "-", ":", ";", "=", "#", "<", ">", "^", "(", ")", "*"]
8
+
9
+ # Punctuations and special tokens
10
+ PUNCTUATION_SYMBOLS = PUNCTUATIONS + ["SP", "UNK"]
11
+
12
+ # Padding
13
+ PAD = "_"
14
+
15
+ # Chinese symbols
16
+ ZH_SYMBOLS = [
17
+ "E",
18
+ "En",
19
+ "a",
20
+ "ai",
21
+ "an",
22
+ "ang",
23
+ "ao",
24
+ "b",
25
+ "c",
26
+ "ch",
27
+ "d",
28
+ "e",
29
+ "ei",
30
+ "en",
31
+ "eng",
32
+ "er",
33
+ "f",
34
+ "g",
35
+ "h",
36
+ "i",
37
+ "i0",
38
+ "ia",
39
+ "ian",
40
+ "iang",
41
+ "iao",
42
+ "ie",
43
+ "in",
44
+ "ing",
45
+ "iong",
46
+ "ir",
47
+ "iu",
48
+ "j",
49
+ "k",
50
+ "l",
51
+ "m",
52
+ "n",
53
+ "o",
54
+ "ong",
55
+ "ou",
56
+ "p",
57
+ "q",
58
+ "r",
59
+ "s",
60
+ "sh",
61
+ "t",
62
+ "u",
63
+ "ua",
64
+ "uai",
65
+ "uan",
66
+ "uang",
67
+ "ui",
68
+ "un",
69
+ "uo",
70
+ "v",
71
+ "van",
72
+ "ve",
73
+ "vn",
74
+ "w",
75
+ "x",
76
+ "y",
77
+ "z",
78
+ "zh",
79
+ "AA",
80
+ "EE",
81
+ "OO",
82
+ ]
83
+ NUM_ZH_TONES = 6
84
+
85
+ # Japanese
86
+ JP_SYMBOLS = [
87
+ "N",
88
+ "a",
89
+ "a:",
90
+ "b",
91
+ "by",
92
+ "ch",
93
+ "d",
94
+ "dy",
95
+ "e",
96
+ "e:",
97
+ "f",
98
+ "g",
99
+ "gy",
100
+ "h",
101
+ "hy",
102
+ "i",
103
+ "i:",
104
+ "j",
105
+ "k",
106
+ "ky",
107
+ "m",
108
+ "my",
109
+ "n",
110
+ "ny",
111
+ "o",
112
+ "o:",
113
+ "p",
114
+ "py",
115
+ "q",
116
+ "r",
117
+ "ry",
118
+ "s",
119
+ "sh",
120
+ "t",
121
+ "ts",
122
+ "ty",
123
+ "u",
124
+ "u:",
125
+ "w",
126
+ "y",
127
+ "z",
128
+ "zy",
129
+ "aa",
130
+ "ae",
131
+ "ah",
132
+ "ao",
133
+ "aw",
134
+ "ay",
135
+ "dh",
136
+ "eh",
137
+ "er",
138
+ "ey",
139
+ "hh",
140
+ "ih",
141
+ "iy",
142
+ "jh",
143
+ "l",
144
+ "ng",
145
+ "ow",
146
+ "oy",
147
+ "sh",
148
+ "th",
149
+ "uh",
150
+ "uw",
151
+ "V",
152
+ "zh",
153
+ "E",
154
+ "En",
155
+ "ai",
156
+ "an",
157
+ "ang",
158
+ "c",
159
+ "ei",
160
+ "en",
161
+ "eng",
162
+ "i0",
163
+ "ia",
164
+ "ian",
165
+ "iang",
166
+ "iao",
167
+ "ie",
168
+ "in",
169
+ "ing",
170
+ "iong",
171
+ "ir",
172
+ "iu",
173
+ "ong",
174
+ "ou",
175
+ "ua",
176
+ "uai",
177
+ "uan",
178
+ "uang",
179
+ "ui",
180
+ "un",
181
+ "uo",
182
+ "v",
183
+ "van",
184
+ "ve",
185
+ "vn",
186
+ "AA",
187
+ "EE",
188
+ "OO",
189
+ ]
190
+ NUM_JP_TONES = 12
191
+
192
+ # English
193
+ EN_SYMBOLS = [
194
+ "aa",
195
+ "ae",
196
+ "ah",
197
+ "ao",
198
+ "aw",
199
+ "ay",
200
+ "b",
201
+ "ch",
202
+ "d",
203
+ "dh",
204
+ "eh",
205
+ "er",
206
+ "ey",
207
+ "f",
208
+ "g",
209
+ "hh",
210
+ "ih",
211
+ "iy",
212
+ "jh",
213
+ "k",
214
+ "l",
215
+ "m",
216
+ "n",
217
+ "ng",
218
+ "ow",
219
+ "oy",
220
+ "p",
221
+ "r",
222
+ "s",
223
+ "sh",
224
+ "t",
225
+ "th",
226
+ "uh",
227
+ "uw",
228
+ "V",
229
+ "w",
230
+ "y",
231
+ "z",
232
+ "zh",
233
+ ]
234
+ NUM_EN_TONES = 4
235
+
236
+ # Combine all symbols
237
+ NORMAL_SYMBOLS = sorted(set(ZH_SYMBOLS + JP_SYMBOLS + EN_SYMBOLS))
238
+ SYMBOLS = [PAD] + NORMAL_SYMBOLS + PUNCTUATION_SYMBOLS
239
+ SIL_PHONEMES_IDS = [SYMBOLS.index(i) for i in PUNCTUATION_SYMBOLS]
240
+
241
+ # Combine all tones
242
+ NUM_TONES = NUM_ZH_TONES + NUM_JP_TONES + NUM_EN_TONES
243
+
244
+ # Language maps
245
+ LANGUAGE_ID_MAP = {"ZH": 0, "JP": 1, "EN": 2}
246
+ NUM_LANGUAGES = len(LANGUAGE_ID_MAP.keys())
247
+
248
+ # Language tone start map
249
+ LANGUAGE_TONE_START_MAP = {
250
+ "ZH": 0,
251
+ "JP": NUM_ZH_TONES,
252
+ "EN": NUM_ZH_TONES + NUM_JP_TONES,
253
+ }
254
+
255
+
256
+ if __name__ == "__main__":
257
+ a = set(ZH_SYMBOLS)
258
+ b = set(EN_SYMBOLS)
259
+ print(sorted(a & b))
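
nlp/symbols.py only defines constants, so the bookkeeping between them can be verified directly (a quick sketch; the import path is assumed from the package layout):

from style_bert_vits2.nlp.symbols import (  # assumed import path
    LANGUAGE_TONE_START_MAP,
    NUM_EN_TONES,
    NUM_JP_TONES,
    NUM_TONES,
    NUM_ZH_TONES,
    PAD,
    SYMBOLS,
)

assert SYMBOLS[0] == PAD                                        # padding sits at index 0
assert NUM_TONES == NUM_ZH_TONES + NUM_JP_TONES + NUM_EN_TONES  # 6 + 12 + 4 == 22
assert LANGUAGE_TONE_START_MAP["EN"] == NUM_ZH_TONES + NUM_JP_TONES
print(f"{len(SYMBOLS)} symbols, {NUM_TONES} tone classes")
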
train_ms.py ADDED
@@ -0,0 +1,1128 @@
1
+ import argparse
2
+ import datetime
3
+ import gc
4
+ import os
5
+ import platform
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from huggingface_hub import HfApi
10
+ from torch.cuda.amp import GradScaler, autocast
11
+ from torch.nn import functional as F
12
+ from torch.nn.parallel import DistributedDataParallel as DDP
13
+ from torch.utils.data import DataLoader
14
+ from torch.utils.tensorboard import SummaryWriter
15
+ from tqdm import tqdm
16
+
17
+ # logging.getLogger("numba").setLevel(logging.WARNING)
18
+ import default_style
19
+ from config import get_config
20
+ from data_utils import (
21
+ DistributedBucketSampler,
22
+ TextAudioSpeakerCollate,
23
+ TextAudioSpeakerLoader,
24
+ )
25
+ from losses import WavLMLoss, discriminator_loss, feature_loss, generator_loss, kl_loss
26
+ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
27
+ from style_bert_vits2.logging import logger
28
+ from style_bert_vits2.models import commons, utils
29
+ from style_bert_vits2.models.hyper_parameters import HyperParameters
30
+ from style_bert_vits2.models.models import (
31
+ DurationDiscriminator,
32
+ MultiPeriodDiscriminator,
33
+ SynthesizerTrn,
34
+ WavLMDiscriminator,
35
+ )
36
+ from style_bert_vits2.nlp.symbols import SYMBOLS
37
+ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT
38
+
39
+
40
+ torch.backends.cuda.matmul.allow_tf32 = True
41
+ torch.backends.cudnn.allow_tf32 = (
42
+ True  # If you encounter training problems, try disabling TF32.
43
+ )
44
+ torch.set_float32_matmul_precision("medium")
45
+ torch.backends.cuda.sdp_kernel("flash")
46
+ torch.backends.cuda.enable_flash_sdp(True)
47
+ torch.backends.cuda.enable_mem_efficient_sdp(
48
+ True
49
+ ) # Not available if torch version is lower than 2.0
50
+ torch.backends.cuda.enable_math_sdp(True)
51
+
52
+ config = get_config()
53
+ global_step = 0
54
+
55
+ api = HfApi()
56
+
57
+
58
+ def run():
59
+ # Command line configuration is not recommended unless necessary, use config.yml
60
+ parser = argparse.ArgumentParser()
61
+ parser.add_argument(
62
+ "-c",
63
+ "--config",
64
+ type=str,
65
+ default=config.train_ms_config.config_path,
66
+ help="JSON file for configuration",
67
+ )
68
+ parser.add_argument(
69
+ "-m",
70
+ "--model",
71
+ type=str,
72
+ help="Dataset folder path. Note that data is no longer placed under /logs by default; if you configure this from the command line, give the path relative to the project root.",
73
+ default=config.dataset_path,
74
+ )
75
+ parser.add_argument(
76
+ "--assets_root",
77
+ type=str,
78
+ help="Root directory of model assets needed for inference.",
79
+ default=config.assets_root,
80
+ )
81
+ parser.add_argument(
82
+ "--skip_default_style",
83
+ action="store_true",
84
+ help="Skip saving default style config and mean vector.",
85
+ )
86
+ parser.add_argument(
87
+ "--no_progress_bar",
88
+ action="store_true",
89
+ help="Do not show the progress bar while training.",
90
+ )
91
+ parser.add_argument(
92
+ "--speedup",
93
+ action="store_true",
94
+ help="Speed up training by disabling logging and evaluation.",
95
+ )
96
+ parser.add_argument(
97
+ "--repo_id",
98
+ help="Huggingface model repo id to backup the model.",
99
+ default=None,
100
+ )
101
+ parser.add_argument(
102
+ "--not_use_custom_batch_sampler",
103
+ help="Don't use custom batch sampler for training, which was used in the version < 2.5",
104
+ action="store_true",
105
+ )
106
+ args = parser.parse_args()
107
+
108
+ # Set log file
109
+ model_dir = os.path.join(args.model, config.train_ms_config.model_dir)
110
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
111
+ logger.add(os.path.join(args.model, f"train_{timestamp}.log"))
112
+
113
+ # Parsing environment variables
114
+ envs = config.train_ms_config.env
115
+ for env_name, env_value in envs.items():
116
+ if env_name not in os.environ.keys():
117
+ logger.info(f"Loading configuration from config {env_value!s}")
118
+ os.environ[env_name] = str(env_value)
119
+ logger.info(
120
+ "Loading environment variables \nMASTER_ADDR: {},\nMASTER_PORT: {},\nWORLD_SIZE: {},\nRANK: {},\nLOCAL_RANK: {}".format(
121
+ os.environ["MASTER_ADDR"],
122
+ os.environ["MASTER_PORT"],
123
+ os.environ["WORLD_SIZE"],
124
+ os.environ["RANK"],
125
+ os.environ["LOCAL_RANK"],
126
+ )
127
+ )
128
+
129
+ backend = "nccl"
130
+ if platform.system() == "Windows":
131
+ backend = "gloo" # On Windows, switch to the gloo backend.
132
+ dist.init_process_group(
133
+ backend=backend,
134
+ init_method="env://",
135
+ timeout=datetime.timedelta(seconds=300),
136
+ ) # Use torchrun instead of mp.spawn
137
+ rank = dist.get_rank()
138
+ local_rank = int(os.environ["LOCAL_RANK"])
139
+ n_gpus = dist.get_world_size()
140
+
141
+ hps = HyperParameters.load_from_json(args.config)
142
+ # This is needed because we have to pass values to `train_and_evaluate()`
143
+ hps.model_dir = model_dir
144
+ hps.speedup = args.speedup
145
+ hps.repo_id = args.repo_id
146
+
147
+ # Check whether the two config paths are the same
148
+ if os.path.realpath(args.config) != os.path.realpath(
149
+ config.train_ms_config.config_path
150
+ ):
151
+ with open(args.config, encoding="utf-8") as f:
152
+ data = f.read()
153
+ os.makedirs(os.path.dirname(config.train_ms_config.config_path), exist_ok=True)
154
+ with open(config.train_ms_config.config_path, "w", encoding="utf-8") as f:
155
+ f.write(data)
156
+
157
+ """
158
+ Path constants are a bit complicated...
159
+ TODO: Refactor or rename these?
160
+ (Both `config.yml` and `config.json` are used, which is confusing I think.)
161
+
162
+ args.model: For saving all info needed for training.
163
+ default: `Data/{model_name}`.
164
+ hps.model_dir := model_dir: For saving checkpoints (for resuming training).
165
+ default: `Data/{model_name}/models`.
166
+ (Use `hps` since we have to pass `model_dir` to `train_and_evaluate()`.)
167
+
168
+ args.assets_root: The root directory of model assets needed for inference.
169
+ default: config.assets_root == `model_assets`.
170
+
171
+ config.out_dir: The directory for model assets of this model (for inference).
172
+ default: `model_assets/{model_name}`.
173
+ """
174
+
175
+ if args.repo_id is not None:
176
+ # First try to upload config.json to check if the repo exists
177
+ try:
178
+ api.upload_file(
179
+ path_or_fileobj=args.config,
180
+ path_in_repo=f"Data/{config.model_name}/config.json",
181
+ repo_id=hps.repo_id,
182
+ )
183
+ except Exception as e:
184
+ logger.error(e)
185
+ logger.error(
186
+ f"Failed to upload files to the repo {hps.repo_id}. Please check if the repo exists and you have logged in using `huggingface-cli login`."
187
+ )
188
+ raise e
189
+ # Upload Data dir for resuming training
190
+ api.upload_folder(
191
+ repo_id=hps.repo_id,
192
+ folder_path=config.dataset_path,
193
+ path_in_repo=f"Data/{config.model_name}",
194
+ delete_patterns="*.pth", # Only keep the latest checkpoint
195
+ run_as_future=True,
196
+ )
197
+
198
+ os.makedirs(config.out_dir, exist_ok=True)
199
+
200
+ if not args.skip_default_style:
201
+ default_style.save_styles_by_dirs(
202
+ os.path.join(args.model, "wavs"),
203
+ config.out_dir,
204
+ config_path=args.config,
205
+ config_output_path=os.path.join(config.out_dir, "config.json"),
206
+ )
207
+
208
+ torch.manual_seed(hps.train.seed)
209
+ torch.cuda.set_device(local_rank)
210
+
211
+ global global_step
212
+ writer = None
213
+ writer_eval = None
214
+ if rank == 0 and not args.speedup:
215
+ # logger = utils.get_logger(hps.model_dir)
216
+ # logger.info(hps)
217
+ utils.check_git_hash(model_dir)
218
+ writer = SummaryWriter(log_dir=model_dir)
219
+ writer_eval = SummaryWriter(log_dir=os.path.join(model_dir, "eval"))
220
+ train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
221
+ collate_fn = TextAudioSpeakerCollate()
222
+ if not args.not_use_custom_batch_sampler:
223
+ train_sampler = DistributedBucketSampler(
224
+ train_dataset,
225
+ hps.train.batch_size,
226
+ [32, 300, 400, 500, 600, 700, 800, 900, 1000],
227
+ num_replicas=n_gpus,
228
+ rank=rank,
229
+ shuffle=True,
230
+ )
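+ # Note: DistributedBucketSampler presumably groups utterances into the length
+ # buckets listed above so that each batch holds similarly sized samples (less
+ # padding), which is why batch_size is given to the sampler rather than to the DataLoader below.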
231
+ train_loader = DataLoader(
232
+ train_dataset,
233
+ # Try setting num_workers to 1 to reduce memory consumption
234
+ # num_workers=min(config.train_ms_config.num_workers, os.cpu_count() // 2),
235
+ num_workers=1,
236
+ shuffle=False,
237
+ pin_memory=True,
238
+ collate_fn=collate_fn,
239
+ batch_sampler=train_sampler,
240
+ # batch_size=hps.train.batch_size,
241
+ persistent_workers=True,
242
+ # This is also commented out to try to reduce memory consumption
243
+ # prefetch_factor=6,
244
+ )
245
+ else:
246
+ train_loader = DataLoader(
247
+ train_dataset,
248
+ # Try setting num_workers to 1 to reduce memory consumption
249
+ # num_workers=min(config.train_ms_config.num_workers, os.cpu_count() // 2),
250
+ num_workers=1,
251
+ shuffle=True,
252
+ pin_memory=True,
253
+ collate_fn=collate_fn,
254
+ # batch_sampler=train_sampler,
255
+ batch_size=hps.train.batch_size,
256
+ persistent_workers=True,
257
+ # This is also commented out to try to reduce memory consumption
258
+ # prefetch_factor=6,
259
+ )
260
+ eval_dataset = None
261
+ eval_loader = None
262
+ if rank == 0 and not args.speedup:
263
+ eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
264
+ eval_loader = DataLoader(
265
+ eval_dataset,
266
+ num_workers=0,
267
+ shuffle=False,
268
+ batch_size=1,
269
+ pin_memory=True,
270
+ drop_last=False,
271
+ collate_fn=collate_fn,
272
+ )
273
+ if hps.model.use_noise_scaled_mas is True:
274
+ logger.info("Using noise scaled MAS for VITS2")
275
+ mas_noise_scale_initial = 0.01
276
+ noise_scale_delta = 2e-6
277
+ else:
278
+ logger.info("Using normal MAS for VITS1")
279
+ mas_noise_scale_initial = 0.0
280
+ noise_scale_delta = 0.0
281
+ if hps.model.use_duration_discriminator is True:
282
+ logger.info("Using duration discriminator for VITS2")
283
+ net_dur_disc = DurationDiscriminator(
284
+ hps.model.hidden_channels,
285
+ hps.model.hidden_channels,
286
+ 3,
287
+ 0.1,
288
+ gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
289
+ ).cuda(local_rank)
290
+
291
+ else:
292
+ net_dur_disc = None
293
+ if hps.model.use_wavlm_discriminator is True:
294
+ net_wd = WavLMDiscriminator(
295
+ hps.model.slm.hidden, hps.model.slm.nlayers, hps.model.slm.initial_channel
296
+ ).cuda(local_rank)
297
+ else:
298
+ net_wd = None
299
+
300
+
301
+
302
+ if hps.model.use_spk_conditioned_encoder is True:
303
+ if hps.data.n_speakers == 0:
304
+ raise ValueError(
305
+ "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
306
+ )
307
+ else:
308
+ logger.info("Using normal encoder for VITS1")
309
+
310
+ net_g = SynthesizerTrn(
311
+ len(SYMBOLS),
312
+ hps.data.filter_length // 2 + 1,
313
+ hps.train.segment_size // hps.data.hop_length,
314
+ n_speakers=hps.data.n_speakers,
315
+ mas_noise_scale_initial=mas_noise_scale_initial,
316
+ noise_scale_delta=noise_scale_delta,
317
+ # Pass every value under hps.model as an argument
318
+ use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder,
319
+ use_noise_scaled_mas=hps.model.use_noise_scaled_mas,
320
+ use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder,
321
+ use_duration_discriminator=hps.model.use_duration_discriminator,
322
+ use_wavlm_discriminator=hps.model.use_wavlm_discriminator,
323
+ inter_channels=hps.model.inter_channels,
324
+ hidden_channels=hps.model.hidden_channels,
325
+ filter_channels=hps.model.filter_channels,
326
+ n_heads=hps.model.n_heads,
327
+ n_layers=hps.model.n_layers,
328
+ kernel_size=hps.model.kernel_size,
329
+ p_dropout=hps.model.p_dropout,
330
+ resblock=hps.model.resblock,
331
+ resblock_kernel_sizes=hps.model.resblock_kernel_sizes,
332
+ resblock_dilation_sizes=hps.model.resblock_dilation_sizes,
333
+ upsample_rates=hps.model.upsample_rates,
334
+ upsample_initial_channel=hps.model.upsample_initial_channel,
335
+ upsample_kernel_sizes=hps.model.upsample_kernel_sizes,
336
+ n_layers_q=hps.model.n_layers_q,
337
+ use_spectral_norm=hps.model.use_spectral_norm,
338
+ gin_channels=hps.model.gin_channels,
339
+ slm=hps.model.slm,
340
+ ).cuda(local_rank)
341
+
342
+ if getattr(hps.train, "freeze_ZH_bert", False):
343
+ logger.info("Freezing ZH bert encoder !!!")
344
+ for param in net_g.enc_p.bert_proj.parameters():
345
+ param.requires_grad = False
346
+
347
+ if getattr(hps.train, "freeze_EN_bert", False):
348
+ logger.info("Freezing EN bert encoder !!!")
349
+ for param in net_g.enc_p.en_bert_proj.parameters():
350
+ param.requires_grad = False
351
+
352
+ if getattr(hps.train, "freeze_JP_bert", False):
353
+ logger.info("Freezing JP bert encoder !!!")
354
+ for param in net_g.enc_p.ja_bert_proj.parameters():
355
+ param.requires_grad = False
356
+ if getattr(hps.train, "freeze_style", False):
357
+ logger.info("Freezing style encoder !!!")
358
+ for param in net_g.enc_p.style_proj.parameters():
359
+ param.requires_grad = False
360
+
361
+ if getattr(hps.train, "freeze_decoder", False):
362
+ logger.info("Freezing decoder !!!")
363
+ for param in net_g.dec.parameters():
364
+ param.requires_grad = False
365
+
366
+ net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
367
+ optim_g = torch.optim.AdamW(
368
+ filter(lambda p: p.requires_grad, net_g.parameters()),
369
+ hps.train.learning_rate,
370
+ betas=hps.train.betas,
371
+ eps=hps.train.eps,
372
+ )
373
+ optim_d = torch.optim.AdamW(
374
+ net_d.parameters(),
375
+ hps.train.learning_rate,
376
+ betas=hps.train.betas,
377
+ eps=hps.train.eps,
378
+ )
379
+ if net_dur_disc is not None:
380
+ optim_dur_disc = torch.optim.AdamW(
381
+ net_dur_disc.parameters(),
382
+ hps.train.learning_rate,
383
+ betas=hps.train.betas,
384
+ eps=hps.train.eps,
385
+ )
386
+ else:
387
+ optim_dur_disc = None
388
+
389
+
390
+
391
+ if net_wd is not None:
392
+ optim_wd = torch.optim.AdamW(
393
+ net_wd.parameters(),
394
+ hps.train.learning_rate,
395
+ betas=hps.train.betas,
396
+ eps=hps.train.eps,
397
+ )
398
+ else:
399
+ optim_wd = None
400
+
401
+
402
+ net_g = DDP(net_g, device_ids=[local_rank])
403
+ net_d = DDP(net_d, device_ids=[local_rank])
404
+ dur_resume_lr = None
405
+ if net_dur_disc is not None:
406
+ net_dur_disc = DDP(
407
+ net_dur_disc, device_ids=[local_rank], find_unused_parameters=True
408
+ )
409
+
410
+ if net_wd is not None:
411
+ net_wd = DDP(
412
+ net_wd,
413
+ device_ids=[local_rank],
414
+ # bucket_cap_mb=512
415
+ )
416
+
417
+
418
+
419
+ if utils.is_resuming(model_dir):
420
+ if net_dur_disc is not None:
421
+ _, _, dur_resume_lr, epoch_str = utils.checkpoints.load_checkpoint(
422
+ utils.checkpoints.get_latest_checkpoint_path(model_dir, "DUR_*.pth"),
423
+ net_dur_disc,
424
+ optim_dur_disc,
425
+ skip_optimizer=hps.train.skip_optimizer,
426
+ )
427
+ if not optim_dur_disc.param_groups[0].get("initial_lr"):
428
+ optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
429
+
430
+ if net_wd is not None:
431
+ try:
432
+ _, optim_wd, wd_resume_lr, epoch_str = (
433
+ utils.checkpoints.load_checkpoint(
434
+ utils.checkpoints.get_latest_checkpoint_path(
435
+ model_dir, "WD_*.pth"
436
+ ),
437
+ net_wd,
438
+ optim_wd,
439
+ skip_optimizer=hps.train.skip_optimizer,
440
+ )
441
+ )
442
+ if not optim_wd.param_groups[0].get("initial_lr"):
443
+ optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr
444
+ except:
445
+ if not optim_wd.param_groups[0].get("initial_lr"):
446
+ optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr
447
+ logger.info("Initialize wavlm")
448
+
449
+
450
+ _, optim_g, g_resume_lr, epoch_str = utils.checkpoints.load_checkpoint(
451
+ utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth"),
452
+ net_g,
453
+ optim_g,
454
+ skip_optimizer=hps.train.skip_optimizer,
455
+ )
456
+ _, optim_d, d_resume_lr, epoch_str = utils.checkpoints.load_checkpoint(
457
+ utils.checkpoints.get_latest_checkpoint_path(model_dir, "D_*.pth"),
458
+ net_d,
459
+ optim_d,
460
+ skip_optimizer=hps.train.skip_optimizer,
461
+ )
462
+ if not optim_g.param_groups[0].get("initial_lr"):
463
+ optim_g.param_groups[0]["initial_lr"] = g_resume_lr
464
+ if not optim_d.param_groups[0].get("initial_lr"):
465
+ optim_d.param_groups[0]["initial_lr"] = d_resume_lr
466
+
467
+ epoch_str = max(epoch_str, 1)
468
+ # global_step = (epoch_str - 1) * len(train_loader)
469
+ global_step = int(
470
+ utils.get_steps(
471
+ utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth")
472
+ )
473
+ )
474
+ logger.info(
475
+ f"******************Found the model. Current epoch is {epoch_str}, gloabl step is {global_step}*********************"
476
+ )
477
+ else:
478
+ try:
479
+ _ = utils.safetensors.load_safetensors(
480
+ os.path.join(model_dir, "G_0.safetensors"), net_g
481
+ )
482
+ _ = utils.safetensors.load_safetensors(
483
+ os.path.join(model_dir, "D_0.safetensors"), net_d
484
+ )
485
+ if net_dur_disc is not None:
486
+ _ = utils.safetensors.load_safetensors(
487
+ os.path.join(model_dir, "DUR_0.safetensors"), net_dur_disc
488
+ )
489
+
490
+ if net_wd is not None:
491
+ _ = utils.safetensors.load_safetensors(
492
+ os.path.join(model_dir, "WD_0.safetensors"), net_wd
493
+ )
494
+
495
+ logger.info("Loaded the pretrained models.")
496
+ except Exception as e:
497
+ logger.warning(e)
498
+ logger.warning(
499
+ "It seems that you are not using the pretrained models, so we will train from scratch."
500
+ )
501
+ finally:
502
+ epoch_str = 1
503
+ global_step = 0
504
+
505
+ def lr_lambda(epoch):
506
+ """
507
+ Learning rate scheduler for warmup and exponential decay.
508
+ - During the warmup period, the learning rate increases linearly.
509
+ - After the warmup period, the learning rate decreases exponentially.
510
+ """
511
+ if epoch < hps.train.warmup_epochs:
512
+ return float(epoch) / float(max(1, hps.train.warmup_epochs))
513
+ else:
514
+ return hps.train.lr_decay ** (epoch - hps.train.warmup_epochs)
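+ # Worked example with hypothetical values warmup_epochs=5 and lr_decay=0.999:
+ # lr_lambda(2) = 2 / 5 = 0.4 (linear warmup),
+ # lr_lambda(5) = 0.999 ** 0 = 1.0,
+ # lr_lambda(10) = 0.999 ** 5 ≈ 0.995 (exponential decay of the base learning rate).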
515
+
516
+ scheduler_last_epoch = epoch_str - 2
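+ # Presumably: LambdaLR counts epochs from 0 while epoch_str starts at 1, so the
+ # last completed epoch index is epoch_str - 2 (equal to -1, LambdaLR's fresh-start
+ # default, when epoch_str == 1).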
517
+ scheduler_g = torch.optim.lr_scheduler.LambdaLR(
518
+ optim_g, lr_lambda=lr_lambda, last_epoch=scheduler_last_epoch
519
+ )
520
+ scheduler_d = torch.optim.lr_scheduler.LambdaLR(
521
+ optim_d, lr_lambda=lr_lambda, last_epoch=scheduler_last_epoch
522
+ )
523
+ if net_dur_disc is not None:
524
+ scheduler_dur_disc = torch.optim.lr_scheduler.LambdaLR(
525
+ optim_dur_disc, lr_lambda=lr_lambda, last_epoch=scheduler_last_epoch
526
+ )
527
+ else:
528
+ scheduler_dur_disc = None
529
+
530
+
531
+ if net_wd is not None:
532
+ scheduler_wd = torch.optim.lr_scheduler.LambdaLR(
533
+ optim_wd, lr_lambda=lr_lambda, last_epoch=scheduler_last_epoch
534
+ )
535
+ wl = WavLMLoss(
536
+ hps.model.slm.model,
537
+ net_wd,
538
+ hps.data.sampling_rate,
539
+ hps.model.slm.sr,
540
+ ).to(local_rank)
541
+ else:
542
+ scheduler_wd = None
543
+ wl = None
544
+
545
+
546
+
547
+ scaler = GradScaler(enabled=hps.train.bf16_run)
548
+ logger.info("Start training.")
549
+
550
+ diff = abs(
551
+ epoch_str * len(train_loader) - (hps.train.epochs + 1) * len(train_loader)
552
+ )
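+ # diff simplifies to (hps.train.epochs + 1 - epoch_str) * len(train_loader), i.e.
+ # the number of remaining training steps; together with the already completed
+ # global_step it gives the progress-bar total below.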
553
+ pbar = None
554
+ if not args.no_progress_bar:
555
+ pbar = tqdm(
556
+ total=global_step + diff,
557
+ initial=global_step,
558
+ smoothing=0.05,
559
+ file=SAFE_STDOUT,
560
+ )
561
+ initial_step = global_step
562
+
563
+ for epoch in range(epoch_str, hps.train.epochs + 1):
564
+ if rank == 0:
565
+ train_and_evaluate(
566
+ rank,
567
+ local_rank,
568
+ epoch,
569
+ hps,
570
+ [net_g, net_d, net_dur_disc, net_wd, wl],
571
+ [optim_g, optim_d, optim_dur_disc, optim_wd],
572
+ [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
573
+ scaler,
574
+ [train_loader, eval_loader],
575
+ logger,
576
+ [writer, writer_eval],
577
+ pbar,
578
+ initial_step,
579
+ )
580
+ else:
581
+ train_and_evaluate(
582
+ rank,
583
+ local_rank,
584
+ epoch,
585
+ hps,
586
+ [net_g, net_d, net_dur_disc, net_wd, wl],
587
+ [optim_g, optim_d, optim_dur_disc, optim_wd],
588
+ [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
589
+ scaler,
590
+ [train_loader, None],
591
+ None,
592
+ None,
593
+ pbar,
594
+ initial_step,
595
+ )
596
+ scheduler_g.step()
597
+ scheduler_d.step()
598
+ if net_dur_disc is not None:
599
+ scheduler_dur_disc.step()
600
+ if net_wd is not None:
601
+ scheduler_wd.step()
602
+ if epoch == hps.train.epochs:
603
+ # Save the final models
604
+ assert optim_g is not None
605
+ utils.checkpoints.save_checkpoint(
606
+ net_g,
607
+ optim_g,
608
+ hps.train.learning_rate,
609
+ epoch,
610
+ os.path.join(model_dir, f"G_{global_step}.pth"),
611
+ )
612
+ assert optim_d is not None
613
+ utils.checkpoints.save_checkpoint(
614
+ net_d,
615
+ optim_d,
616
+ hps.train.learning_rate,
617
+ epoch,
618
+ os.path.join(model_dir, f"D_{global_step}.pth"),
619
+ )
620
+ if net_dur_disc is not None:
621
+ assert optim_dur_disc is not None
622
+ utils.checkpoints.save_checkpoint(
623
+ net_dur_disc,
624
+ optim_dur_disc,
625
+ hps.train.learning_rate,
626
+ epoch,
627
+ os.path.join(model_dir, f"DUR_{global_step}.pth"),
628
+ )
629
+
630
+
631
+ if net_wd is not None:
632
+ assert optim_wd is not None
633
+ utils.checkpoints.save_checkpoint(
634
+ net_wd,
635
+ optim_wd,
636
+ hps.train.learning_rate,
637
+ epoch,
638
+ os.path.join(model_dir, f"WD_{global_step}.pth"),
639
+ )
640
+
641
+
642
+ utils.safetensors.save_safetensors(
643
+ net_g,
644
+ epoch,
645
+ os.path.join(
646
+ config.out_dir,
647
+ f"{config.model_name}_e{epoch}_s{global_step}.safetensors",
648
+ ),
649
+ for_infer=True,
650
+ )
651
+ if hps.repo_id is not None:
652
+ future1 = api.upload_folder(
653
+ repo_id=hps.repo_id,
654
+ folder_path=config.dataset_path,
655
+ path_in_repo=f"Data/{config.model_name}",
656
+ delete_patterns="*.pth", # Only keep the latest checkpoint
657
+ run_as_future=True,
658
+ )
659
+ future2 = api.upload_folder(
660
+ repo_id=hps.repo_id,
661
+ folder_path=config.out_dir,
662
+ path_in_repo=f"model_assets/{config.model_name}",
663
+ run_as_future=True,
664
+ )
665
+ try:
666
+ future1.result()
667
+ future2.result()
668
+ except Exception as e:
669
+ logger.error(e)
670
+
671
+ if pbar is not None:
672
+ pbar.close()
673
+
674
+
675
+ def train_and_evaluate(
676
+ rank,
677
+ local_rank,
678
+ epoch,
679
+ hps: HyperParameters,
680
+ nets,
681
+ optims,
682
+ schedulers,
683
+ scaler,
684
+ loaders,
685
+ logger,
686
+ writers,
687
+ pbar: tqdm,
688
+ initial_step: int,
689
+ ):
690
+ net_g, net_d, net_dur_disc, net_wd, wl = nets
691
+ optim_g, optim_d, optim_dur_disc, optim_wd = optims
692
+ scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd = schedulers
693
+ train_loader, eval_loader = loaders
694
+ if writers is not None:
695
+ writer, writer_eval = writers
696
+
697
+ train_loader.batch_sampler.set_epoch(epoch)
698
+ global global_step
699
+
700
+ net_g.train()
701
+ net_d.train()
702
+ if net_dur_disc is not None:
703
+ net_dur_disc.train()
704
+ if net_wd is not None:
705
+ net_wd.train()
706
+ for batch_idx, (
707
+ x,
708
+ x_lengths,
709
+ spec,
710
+ spec_lengths,
711
+ y,
712
+ y_lengths,
713
+ speakers,
714
+ tone,
715
+ language,
716
+ bert,
717
+ ja_bert,
718
+ en_bert,
719
+ style_vec,
720
+ ) in enumerate(train_loader):
721
+ if net_g.module.use_noise_scaled_mas:
722
+ current_mas_noise_scale = (
723
+ net_g.module.mas_noise_scale_initial
724
+ - net_g.module.noise_scale_delta * global_step
725
+ )
726
+ net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
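+ # With the defaults set in run() (initial 0.01, delta 2e-6 per step), the MAS
+ # noise scale decays linearly and is clamped to 0 after 0.01 / 2e-6 = 5000 steps.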
727
+ x, x_lengths = x.cuda(local_rank, non_blocking=True), x_lengths.cuda(
728
+ local_rank, non_blocking=True
729
+ )
730
+ spec, spec_lengths = spec.cuda(
731
+ local_rank, non_blocking=True
732
+ ), spec_lengths.cuda(local_rank, non_blocking=True)
733
+ y, y_lengths = y.cuda(local_rank, non_blocking=True), y_lengths.cuda(
734
+ local_rank, non_blocking=True
735
+ )
736
+ speakers = speakers.cuda(local_rank, non_blocking=True)
737
+ tone = tone.cuda(local_rank, non_blocking=True)
738
+ language = language.cuda(local_rank, non_blocking=True)
739
+ bert = bert.cuda(local_rank, non_blocking=True)
740
+ ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
741
+ en_bert = en_bert.cuda(local_rank, non_blocking=True)
742
+ style_vec = style_vec.cuda(local_rank, non_blocking=True)
743
+
744
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
745
+ (
746
+ y_hat,
747
+ l_length,
748
+ attn,
749
+ ids_slice,
750
+ x_mask,
751
+ z_mask,
752
+ (z, z_p, m_p, logs_p, m_q, logs_q),
753
+ (hidden_x, logw, logw_), # , logw_sdp),
754
+ g,
755
+ ) = net_g(
756
+ x,
757
+ x_lengths,
758
+ spec,
759
+ spec_lengths,
760
+ speakers,
761
+ tone,
762
+ language,
763
+ bert,
764
+ ja_bert,
765
+ en_bert,
766
+ style_vec,
767
+ )
768
+ mel = spec_to_mel_torch(
769
+ spec,
770
+ hps.data.filter_length,
771
+ hps.data.n_mel_channels,
772
+ hps.data.sampling_rate,
773
+ hps.data.mel_fmin,
774
+ hps.data.mel_fmax,
775
+ )
776
+ y_mel = commons.slice_segments(
777
+ mel, ids_slice, hps.train.segment_size // hps.data.hop_length
778
+ )
779
+ y_hat_mel = mel_spectrogram_torch(
780
+ y_hat.squeeze(1).float(),
781
+ hps.data.filter_length,
782
+ hps.data.n_mel_channels,
783
+ hps.data.sampling_rate,
784
+ hps.data.hop_length,
785
+ hps.data.win_length,
786
+ hps.data.mel_fmin,
787
+ hps.data.mel_fmax,
788
+ )
789
+
790
+ y = commons.slice_segments(
791
+ y, ids_slice * hps.data.hop_length, hps.train.segment_size
792
+ ) # slice
793
+
794
+ # Discriminator
795
+ y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
796
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
797
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
798
+ y_d_hat_r, y_d_hat_g
799
+ )
800
+ loss_disc_all = loss_disc
801
+ if net_dur_disc is not None:
802
+ y_dur_hat_r, y_dur_hat_g = net_dur_disc(
803
+ hidden_x.detach(), x_mask.detach(), logw.detach(), logw_.detach()
804
+ )
805
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
806
+ # TODO: The mean should probably be taken using the mask, but for now just take the mean over everything
807
+ (
808
+ loss_dur_disc,
809
+ losses_dur_disc_r,
810
+ losses_dur_disc_g,
811
+ ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
812
+ loss_dur_disc_all = loss_dur_disc
813
+ optim_dur_disc.zero_grad()
814
+ scaler.scale(loss_dur_disc_all).backward()
815
+ scaler.unscale_(optim_dur_disc)
816
+ commons.clip_grad_value_(net_dur_disc.parameters(), None)
817
+ scaler.step(optim_dur_disc)
818
+
819
+
820
+ if net_wd is not None:
821
+ # logger.debug(f"y.shape: {y.shape}, y_hat.shape: {y_hat.shape}")
822
+ # shape: (batch, 1, time)
823
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
824
+ loss_slm = wl.discriminator(
825
+ y.detach().squeeze(1), y_hat.detach().squeeze(1)
826
+ ).mean()
827
+ optim_wd.zero_grad()
828
+ scaler.scale(loss_slm).backward()
829
+ scaler.unscale_(optim_wd)
830
+ # torch.nn.utils.clip_grad_norm_(parameters=net_wd.parameters(), max_norm=200)
831
+ grad_norm_wd = commons.clip_grad_value_(net_wd.parameters(), None)
832
+ scaler.step(optim_wd)
833
+
834
+
835
+
836
+ optim_d.zero_grad()
837
+ scaler.scale(loss_disc_all).backward()
838
+ scaler.unscale_(optim_d)
839
+ if getattr(hps.train, "bf16_run", False):
840
+ torch.nn.utils.clip_grad_norm_(parameters=net_d.parameters(), max_norm=200)
841
+ grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
842
+ scaler.step(optim_d)
843
+
844
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
845
+ # Generator
846
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
847
+ if net_dur_disc is not None:
848
+ y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw, logw_)
849
+ if net_wd is not None:
850
+ loss_lm = wl(y.detach().squeeze(1), y_hat.squeeze(1)).mean()
851
+ loss_lm_gen = wl.generator(y_hat.squeeze(1))
852
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
853
+ loss_dur = torch.sum(l_length.float())
854
+ loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
855
+ loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
856
+
857
+ loss_fm = feature_loss(fmap_r, fmap_g)
858
+ loss_gen, losses_gen = generator_loss(y_d_hat_g)
859
+ loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
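+ # Note: loss_mel and loss_kl above are already weighted by hps.train.c_mel and
+ # hps.train.c_kl respectively.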
860
+ if net_dur_disc is not None:
+ loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
+ loss_gen_all += loss_dur_gen
+ if net_wd is not None:
+ loss_gen_all += loss_lm + loss_lm_gen
+
868
+ optim_g.zero_grad()
869
+ scaler.scale(loss_gen_all).backward()
870
+ scaler.unscale_(optim_g)
871
+ if getattr(hps.train, "bf16_run", False):
872
+ torch.nn.utils.clip_grad_norm_(parameters=net_g.parameters(), max_norm=500)
873
+ grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
874
+ scaler.step(optim_g)
875
+ scaler.update()
876
+
877
+ if rank == 0:
878
+ if global_step % hps.train.log_interval == 0 and not hps.speedup:
879
+ lr = optim_g.param_groups[0]["lr"]
880
+ losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
881
+ # logger.info(
882
+ # "Train Epoch: {} [{:.0f}%]".format(
883
+ # epoch, 100.0 * batch_idx / len(train_loader)
884
+ # )
885
+ # )
886
+ # logger.info([x.item() for x in losses] + [global_step, lr])
887
+
888
+ scalar_dict = {
889
+ "loss/g/total": loss_gen_all,
890
+ "loss/d/total": loss_disc_all,
891
+ "learning_rate": lr,
892
+ "grad_norm_d": grad_norm_d,
893
+ "grad_norm_g": grad_norm_g,
894
+ }
895
+ scalar_dict.update(
896
+ {
897
+ "loss/g/fm": loss_fm,
898
+ "loss/g/mel": loss_mel,
899
+ "loss/g/dur": loss_dur,
900
+ "loss/g/kl": loss_kl,
901
+ }
902
+ )
903
+ scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
904
+ scalar_dict.update(
905
+ {f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)}
906
+ )
907
+ scalar_dict.update(
908
+ {f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)}
909
+ )
910
+
911
+ if net_wd is not None:
912
+ scalar_dict.update(
913
+ {
914
+ "loss/wd/total": loss_slm,
915
+ "grad_norm_wd": grad_norm_wd,
916
+ "loss/g/lm": loss_lm,
917
+ "loss/g/lm_gen": loss_lm_gen,
918
+ }
919
+ )
920
+
921
+ # The logging below seems computationally heavy and probably nobody looks at it, so it is commented out
922
+ # image_dict = {
923
+ # "slice/mel_org": utils.plot_spectrogram_to_numpy(
924
+ # y_mel[0].data.cpu().numpy()
925
+ # ),
926
+ # "slice/mel_gen": utils.plot_spectrogram_to_numpy(
927
+ # y_hat_mel[0].data.cpu().numpy()
928
+ # ),
929
+ # "all/mel": utils.plot_spectrogram_to_numpy(
930
+ # mel[0].data.cpu().numpy()
931
+ # ),
932
+ # "all/attn": utils.plot_alignment_to_numpy(
933
+ # attn[0, 0].data.cpu().numpy()
934
+ # ),
935
+ # }
936
+ utils.summarize(
937
+ writer=writer,
938
+ global_step=global_step,
939
+ # images=image_dict,
940
+ scalars=scalar_dict,
941
+ )
942
+
943
+ if (
944
+ global_step % hps.train.eval_interval == 0
945
+ and global_step != 0
946
+ and initial_step != global_step
947
+ ):
948
+ if not hps.speedup:
949
+ evaluate(hps, net_g, eval_loader, writer_eval)
950
+ assert hps.model_dir is not None
951
+ utils.checkpoints.save_checkpoint(
952
+ net_g,
953
+ optim_g,
954
+ hps.train.learning_rate,
955
+ epoch,
956
+ os.path.join(hps.model_dir, f"G_{global_step}.pth"),
957
+ )
958
+ utils.checkpoints.save_checkpoint(
959
+ net_d,
960
+ optim_d,
961
+ hps.train.learning_rate,
962
+ epoch,
963
+ os.path.join(hps.model_dir, f"D_{global_step}.pth"),
964
+ )
965
+ if net_dur_disc is not None:
966
+ utils.checkpoints.save_checkpoint(
967
+ net_dur_disc,
968
+ optim_dur_disc,
969
+ hps.train.learning_rate,
970
+ epoch,
971
+ os.path.join(hps.model_dir, f"DUR_{global_step}.pth"),
972
+ )
973
+ if net_wd is not None:
974
+ utils.checkpoints.save_checkpoint(
975
+ net_wd,
976
+ optim_wd,
977
+ hps.train.learning_rate,
978
+ epoch,
979
+ os.path.join(hps.model_dir, f"WD_{global_step}.pth"),
980
+ )
981
+ keep_ckpts = config.train_ms_config.keep_ckpts
982
+ if keep_ckpts > 0:
983
+ utils.checkpoints.clean_checkpoints(
984
+ model_dir_path=hps.model_dir,
985
+ n_ckpts_to_keep=keep_ckpts,
986
+ sort_by_time=True,
987
+ )
988
+ # Save safetensors (for inference) to `model_assets/{model_name}`
989
+ utils.safetensors.save_safetensors(
990
+ net_g,
991
+ epoch,
992
+ os.path.join(
993
+ config.out_dir,
994
+ f"{config.model_name}_e{epoch}_s{global_step}.safetensors",
995
+ ),
996
+ for_infer=True,
997
+ )
998
+ if hps.repo_id is not None:
999
+ api.upload_folder(
1000
+ repo_id=hps.repo_id,
1001
+ folder_path=config.dataset_path,
1002
+ path_in_repo=f"Data/{config.model_name}",
1003
+ delete_patterns="*.pth", # Only keep the latest checkpoint
1004
+ run_as_future=True,
1005
+ )
1006
+ api.upload_folder(
1007
+ repo_id=hps.repo_id,
1008
+ folder_path=config.out_dir,
1009
+ path_in_repo=f"model_assets/{config.model_name}",
1010
+ run_as_future=True,
1011
+ )
1012
+
1013
+ global_step += 1
1014
+ if pbar is not None:
1015
+ pbar.set_description(
1016
+ f"Epoch {epoch}({100.0 * batch_idx / len(train_loader):.0f}%)/{hps.train.epochs}"
1017
+ )
1018
+ pbar.update()
1019
+ # The upstream repo says to remove this for speed, so it was going to be removed,
1020
+ # but it may reduce memory usage, so it is kept here for now
1021
+ gc.collect()
1022
+ torch.cuda.empty_cache()
1023
+ if pbar is None and rank == 0:
1024
+ logger.info(f"====> Epoch: {epoch}, step: {global_step}")
1025
+
1026
+
1027
+ def evaluate(hps, generator, eval_loader, writer_eval):
1028
+ generator.eval()
1029
+ image_dict = {}
1030
+ audio_dict = {}
1031
+ print()
1032
+ logger.info("Evaluating ...")
1033
+ with torch.no_grad():
1034
+ for batch_idx, (
1035
+ x,
1036
+ x_lengths,
1037
+ spec,
1038
+ spec_lengths,
1039
+ y,
1040
+ y_lengths,
1041
+ speakers,
1042
+ tone,
1043
+ language,
1044
+ bert,
1045
+ ja_bert,
1046
+ en_bert,
1047
+ style_vec,
1048
+ ) in enumerate(eval_loader):
1049
+ x, x_lengths = x.cuda(), x_lengths.cuda()
1050
+ spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
1051
+ y, y_lengths = y.cuda(), y_lengths.cuda()
1052
+ speakers = speakers.cuda()
1053
+ bert = bert.cuda()
1054
+ ja_bert = ja_bert.cuda()
1055
+ en_bert = en_bert.cuda()
1056
+ tone = tone.cuda()
1057
+ language = language.cuda()
1058
+ style_vec = style_vec.cuda()
1059
+ for use_sdp in [True, False]:
1060
+ y_hat, attn, mask, *_ = generator.module.infer(
1061
+ x,
1062
+ x_lengths,
1063
+ speakers,
1064
+ tone,
1065
+ language,
1066
+ bert,
1067
+ ja_bert,
1068
+ en_bert,
1069
+ style_vec,
1070
+ y=spec,
1071
+ max_len=1000,
1072
+ sdp_ratio=0.0 if not use_sdp else 1.0,
1073
+ )
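+ # Presumably sdp_ratio=1.0 synthesizes with the stochastic duration predictor and
+ # 0.0 with the deterministic one, so both variants are written to TensorBoard.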
1074
+ y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
1075
+ # The logging below seems computationally heavy and probably nobody looks at it, so it is commented out
1076
+ # mel = spec_to_mel_torch(
1077
+ # spec,
1078
+ # hps.data.filter_length,
1079
+ # hps.data.n_mel_channels,
1080
+ # hps.data.sampling_rate,
1081
+ # hps.data.mel_fmin,
1082
+ # hps.data.mel_fmax,
1083
+ # )
1084
+ # y_hat_mel = mel_spectrogram_torch(
1085
+ # y_hat.squeeze(1).float(),
1086
+ # hps.data.filter_length,
1087
+ # hps.data.n_mel_channels,
1088
+ # hps.data.sampling_rate,
1089
+ # hps.data.hop_length,
1090
+ # hps.data.win_length,
1091
+ # hps.data.mel_fmin,
1092
+ # hps.data.mel_fmax,
1093
+ # )
1094
+ # image_dict.update(
1095
+ # {
1096
+ # f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
1097
+ # y_hat_mel[0].cpu().numpy()
1098
+ # )
1099
+ # }
1100
+ # )
1101
+ # image_dict.update(
1102
+ # {
1103
+ # f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
1104
+ # mel[0].cpu().numpy()
1105
+ # )
1106
+ # }
1107
+ # )
1108
+ audio_dict.update(
1109
+ {
1110
+ f"gen/audio_{batch_idx}_{use_sdp}": y_hat[
1111
+ 0, :, : y_hat_lengths[0]
1112
+ ]
1113
+ }
1114
+ )
1115
+ audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
1116
+
1117
+ utils.summarize(
1118
+ writer=writer_eval,
1119
+ global_step=global_step,
1120
+ images=image_dict,
1121
+ audios=audio_dict,
1122
+ audio_sampling_rate=hps.data.sampling_rate,
1123
+ )
1124
+ generator.train()
1125
+
1126
+
1127
+ if __name__ == "__main__":
1128
+ run()