Upload 30 files
- bert_gen.py +25 -18
- config.yml +3 -3
- data_utils.py +7 -24
- export_onnx.py +3 -1
- hiyoriUI.py +725 -0
- infer.py +90 -35
- losses.py +95 -0
- models.py +66 -67
- modules.py +1 -1
- onnx_infer.py +60 -0
- re_matching.py +0 -1
- resample.py +10 -6
- resample_legacy.py +71 -0
- server.py +733 -103
- test.py +36 -0
- train_ms.py +176 -63
- utils.py +5 -1
- webui.py +211 -174
- webui_preprocess.py +10 -21
bert_gen.py
CHANGED
@@ -1,17 +1,16 @@
-import argparse
-from multiprocessing import Pool, cpu_count
-
 import torch
-
-from tqdm import tqdm
-
+from multiprocessing import Pool
 import commons
 import utils
+from tqdm import tqdm
+from text import check_bert_models, cleaned_text_to_sequence, get_bert
+import argparse
+import torch.multiprocessing as mp
 from config import config
-from text import cleaned_text_to_sequence, get_bert


-def process_line(line):
+def process_line(x):
+    line, add_blank = x
     device = config.bert_gen_config.device
     if config.bert_gen_config.use_multi_device:
         rank = mp.current_process()._identity
@@ -28,18 +27,19 @@ def process_line(line):
     word2ph = [i for i in word2ph]
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

-
-
-
-
-
-
+    if add_blank:
+        phone = commons.intersperse(phone, 0)
+        tone = commons.intersperse(tone, 0)
+        language = commons.intersperse(language, 0)
+        for i in range(len(word2ph)):
+            word2ph[i] = word2ph[i] * 2
+        word2ph[0] += 1

     bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")

     try:
         bert = torch.load(bert_path)
-        assert bert.shape[
+        assert bert.shape[0] == 2048
     except Exception:
         bert = get_bert(text, word2ph, language_str, device)
         assert bert.shape[-1] == len(phone)
@@ -59,16 +59,23 @@ if __name__ == "__main__":
     args, _ = parser.parse_known_args()
     config_path = args.config
     hps = utils.get_hparams_from_file(config_path)
+    check_bert_models()
     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
         lines.extend(f.readlines())

     with open(hps.data.validation_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
+    add_blank = [hps.data.add_blank] * len(lines)
+
     if len(lines) != 0:
-        num_processes =
+        num_processes = args.num_processes
         with Pool(processes=num_processes) as pool:
-            for _ in tqdm(
-
+            for _ in tqdm(
+                pool.imap_unordered(process_line, zip(lines, add_blank)),
+                total=len(lines),
+            ):
+                # 这里是缩进的代码块,表示循环体
+                pass  # 使用pass语句作为占位符

     print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!")
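Note: the new bert_gen.py drives process_line with pool.imap_unordered over zip(lines, add_blank). A minimal, self-contained sketch of that same pattern, assuming nothing beyond the standard library and tqdm; process_item and the sample data below are placeholders, not code from this commit:

from multiprocessing import Pool
from tqdm import tqdm

def process_item(x):
    # each task carries one line plus the shared add_blank flag, mirroring process_line(x)
    line, add_blank = x
    return len(line) if add_blank else 0

if __name__ == "__main__":
    lines = ["wavs/a.wav|spk|ZH|text", "wavs/b.wav|spk|ZH|text"]  # placeholder data
    add_blank = [True] * len(lines)
    with Pool(processes=2) as pool:
        for _ in tqdm(pool.imap_unordered(process_item, zip(lines, add_blank)), total=len(lines)):
            pass  # results are written to disk by the worker in the real script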
config.yml
CHANGED
@@ -83,7 +83,7 @@ train_ms:
   base:
     use_base_model: false
     repo_id: "Stardust_minus/Bert-VITS2"
-    model_image: "Bert-VITS2_2.
+    model_image: "Bert-VITS2_2.3底模"  # openi网页的模型名
   # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
   model: "models"
   # 配置文件路径
@@ -172,6 +172,6 @@ server:
 # 请不要在github等网站公开分享你的app id 与 key
 translate:
   # 你的APPID
-  "app_key": ""
+  "app_key": "20231117001883321"
   # 你的密钥
-  "secret_key": ""
+  "secret_key": "lMQbvZHeJveDceLof2wf"
data_utils.py
CHANGED
@@ -3,7 +3,6 @@ import random
 import torch
 import torch.utils.data
 from tqdm import tqdm
-import numpy as np
 from tools.log import logger
 import commons
 from mel_processing import spectrogram_torch, mel_spectrogram_torch
@@ -44,10 +43,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         self.min_text_len = getattr(hparams, "min_text_len", 1)
         self.max_text_len = getattr(hparams, "max_text_len", 384)

-        self.empty_emo = torch.squeeze(
-            torch.load("empty_emo.npy", map_location="cpu"), dim=1
-        )
-
         random.seed(1234)
         random.shuffle(self.audiopaths_sid_text)
         self._filter()
@@ -98,14 +93,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         spec, wav = self.get_audio(audiopath)
         sid = torch.LongTensor([int(self.spk_map[sid])])

-
-            emo = torch.squeeze(
-                torch.load(audiopath.replace(".wav", ".emo.npy"), map_location="cpu"),
-                dim=1,
-            )
-        else:
-            emo = self.empty_emo
-        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)
+        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert)

     def get_audio(self, filename):
         audio, sampling_rate = load_wav_to_torch(filename)
@@ -168,15 +156,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):

         if language_str == "ZH":
             bert = bert_ori
-            ja_bert = torch.
-            en_bert = torch.
+            ja_bert = torch.randn(1024, len(phone))
+            en_bert = torch.randn(1024, len(phone))
         elif language_str == "JP":
-            bert = torch.
+            bert = torch.randn(1024, len(phone))
             ja_bert = bert_ori
-            en_bert = torch.
+            en_bert = torch.randn(1024, len(phone))
         elif language_str == "EN":
-            bert = torch.
-            ja_bert = torch.
+            bert = torch.randn(1024, len(phone))
+            ja_bert = torch.randn(1024, len(phone))
             en_bert = bert_ori
         phone = torch.LongTensor(phone)
         tone = torch.LongTensor(tone)
@@ -226,7 +214,6 @@ class TextAudioSpeakerCollate:
         bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
         ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
         en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
-        emo = torch.FloatTensor(len(batch), 512)

         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
@@ -238,7 +225,6 @@ class TextAudioSpeakerCollate:
         bert_padded.zero_()
         ja_bert_padded.zero_()
         en_bert_padded.zero_()
-        emo.zero_()

         for i in range(len(ids_sorted_decreasing)):
             row = batch[ids_sorted_decreasing[i]]
@@ -272,8 +258,6 @@ class TextAudioSpeakerCollate:
             en_bert = row[8]
             en_bert_padded[i, :, : en_bert.size(1)] = en_bert

-            emo[i, :] = row[9]
-
         return (
             text_padded,
             text_lengths,
@@ -287,7 +271,6 @@ class TextAudioSpeakerCollate:
             bert_padded,
             ja_bert_padded,
             en_bert_padded,
-            emo,
         )
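Note: after this change, get_text keeps real BERT features only in the slot matching language_str and fills the other two slots with random placeholders of the same (1024, n_phones) shape. A small sketch of that convention, assuming a hypothetical helper make_bert_slots (not code from the repo):

import torch

def make_bert_slots(bert_ori: torch.Tensor, language_str: str, n_phones: int):
    # only the slot matching language_str carries real features; the others are fillers
    filler = lambda: torch.randn(1024, n_phones)
    if language_str == "ZH":
        return bert_ori, filler(), filler()  # bert, ja_bert, en_bert
    if language_str == "JP":
        return filler(), bert_ori, filler()
    if language_str == "EN":
        return filler(), filler(), bert_ori
    raise ValueError("language_str should be ZH, JP or EN")

bert, ja_bert, en_bert = make_bert_slots(torch.randn(1024, 7), "ZH", 7)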
export_onnx.py
CHANGED
@@ -5,8 +5,10 @@ if __name__ == "__main__":
     export_path = "BertVits2.2PT"
     model_path = "model\\G_0.pth"
     config_path = "model\\config.json"
+    novq = False
+    dev = False
     if not os.path.exists("onnx"):
         os.makedirs("onnx")
     if not os.path.exists(f"onnx/{export_path}"):
         os.makedirs(f"onnx/{export_path}")
-    export_onnx(export_path, model_path, config_path)
+    export_onnx(export_path, model_path, config_path, novq, dev)
hiyoriUI.py
ADDED
@@ -0,0 +1,725 @@
"""
api服务 网页后端 多版本多模型 fastapi实现
原 server_fastapi
"""
import logging
import gc
import random
import librosa
import gradio
import numpy as np
import utils
from fastapi import FastAPI, Query, Request, File, UploadFile, Form
from fastapi.responses import Response, FileResponse
from fastapi.staticfiles import StaticFiles
from io import BytesIO
from scipy.io import wavfile
import uvicorn
import torch
import webbrowser
import psutil
import GPUtil
from typing import Dict, Optional, List, Set, Union, Tuple
import os
from tools.log import logger
from urllib.parse import unquote

from infer import infer, get_net_g, latest_version
import tools.translate as trans
from tools.sentence import split_by_language
from re_matching import cut_sent


from config import config

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class Model:
    """模型封装类"""

    def __init__(self, config_path: str, model_path: str, device: str, language: str):
        self.config_path: str = os.path.normpath(config_path)
        self.model_path: str = os.path.normpath(model_path)
        self.device: str = device
        self.language: str = language
        self.hps = utils.get_hparams_from_file(config_path)
        self.spk2id: Dict[str, int] = self.hps.data.spk2id  # spk - id 映射字典
        self.id2spk: Dict[int, str] = dict()  # id - spk 映射字典
        for speaker, speaker_id in self.hps.data.spk2id.items():
            self.id2spk[speaker_id] = speaker
        self.version: str = (
            self.hps.version if hasattr(self.hps, "version") else latest_version
        )
        self.net_g = get_net_g(
            model_path=model_path,
            version=self.version,
            device=device,
            hps=self.hps,
        )

    def to_dict(self) -> Dict[str, any]:
        return {
            "config_path": self.config_path,
            "model_path": self.model_path,
            "device": self.device,
            "language": self.language,
            "spk2id": self.spk2id,
            "id2spk": self.id2spk,
            "version": self.version,
        }


class Models:
    def __init__(self):
        self.models: Dict[int, Model] = dict()
        self.num = 0
        # spkInfo[角色名][模型id] = 角色id
        self.spk_info: Dict[str, Dict[int, int]] = dict()
        self.path2ids: Dict[str, Set[int]] = dict()  # 路径指向的model的id

    def init_model(
        self, config_path: str, model_path: str, device: str, language: str
    ) -> int:
        """
        初始化并添加一个模型

        :param config_path: 模型config.json路径
        :param model_path: 模型路径
        :param device: 模型推理使用设备
        :param language: 模型推理默认语言
        """
        # 若文件不存在则不进行加载
        if not os.path.isfile(model_path):
            if model_path != "":
                logger.warning(f"模型文件{model_path} 不存在,不进行初始化")
            return self.num
        if not os.path.isfile(config_path):
            if config_path != "":
                logger.warning(f"配置文件{config_path} 不存在,不进行初始化")
            return self.num

        # 若路径中的模型已存在,则不添加模型,若不存在,则进行初始化。
        model_path = os.path.realpath(model_path)
        if model_path not in self.path2ids.keys():
            self.path2ids[model_path] = {self.num}
            self.models[self.num] = Model(
                config_path=config_path,
                model_path=model_path,
                device=device,
                language=language,
            )
            logger.success(f"添加模型{model_path},使用配置文件{os.path.realpath(config_path)}")
        else:
            # 获取一个指向id
            m_id = next(iter(self.path2ids[model_path]))
            self.models[self.num] = self.models[m_id]
            self.path2ids[model_path].add(self.num)
            logger.success("模型已存在,添加模型引用。")
        # 添加角色信息
        for speaker, speaker_id in self.models[self.num].spk2id.items():
            if speaker not in self.spk_info.keys():
                self.spk_info[speaker] = {self.num: speaker_id}
            else:
                self.spk_info[speaker][self.num] = speaker_id
        # 修改计数
        self.num += 1
        return self.num - 1

    def del_model(self, index: int) -> Optional[int]:
        """删除对应序号的模型,若不存在则返回None"""
        if index not in self.models.keys():
            return None
        # 删除角色信息
        for speaker, speaker_id in self.models[index].spk2id.items():
            self.spk_info[speaker].pop(index)
            if len(self.spk_info[speaker]) == 0:
                # 若对应角色的所有模型都被删除,则清除该角色信息
                self.spk_info.pop(speaker)
        # 删除路径信息
        model_path = os.path.realpath(self.models[index].model_path)
        self.path2ids[model_path].remove(index)
        if len(self.path2ids[model_path]) == 0:
            self.path2ids.pop(model_path)
            logger.success(f"删除模型{model_path}, id = {index}")
        else:
            logger.success(f"删除模型引用{model_path}, id = {index}")
        # 删除模型
        self.models.pop(index)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return index

    def get_models(self):
        """获取所有模型"""
        return self.models


if __name__ == "__main__":
    app = FastAPI()
    app.logger = logger
    # 挂载静态文件
    logger.info("开始挂载网页页面")
    StaticDir: str = "./Web"
    if not os.path.isdir(StaticDir):
        logger.warning(
            "缺少网页资源,无法开启网页页面,如有需要请在 https://github.com/jiangyuxiaoxiao/Bert-VITS2-UI 或者Bert-VITS对应版本的release页面下载"
        )
    else:
        dirs = [fir.name for fir in os.scandir(StaticDir) if fir.is_dir()]
        files = [fir.name for fir in os.scandir(StaticDir) if fir.is_dir()]
        for dirName in dirs:
            app.mount(
                f"/{dirName}",
                StaticFiles(directory=f"./{StaticDir}/{dirName}"),
                name=dirName,
            )
    loaded_models = Models()
    # 加载模型
    logger.info("开始加载模型")
    models_info = config.server_config.models
    for model_info in models_info:
        loaded_models.init_model(
            config_path=model_info["config"],
            model_path=model_info["model"],
            device=model_info["device"],
            language=model_info["language"],
        )

    @app.get("/")
    async def index():
        return FileResponse("./Web/index.html")

    async def _voice(
        text: str,
        model_id: int,
        speaker_name: str,
        speaker_id: int,
        sdp_ratio: float,
        noise: float,
        noisew: float,
        length: float,
        language: str,
        auto_translate: bool,
        auto_split: bool,
        emotion: Optional[Union[int, str]] = None,
        reference_audio=None,
        style_text: Optional[str] = None,
        style_weight: float = 0.7,
    ) -> Union[Response, Dict[str, any]]:
        """TTS实现函数"""

        # 检查
        # 检查模型是否存在
        if model_id not in loaded_models.models.keys():
            logger.error(f"/voice 请求错误:模型model_id={model_id}未加载")
            return {"status": 10, "detail": f"模型model_id={model_id}未加载"}
        # 检查是否提供speaker
        if speaker_name is None and speaker_id is None:
            logger.error("/voice 请求错误:推理请求未提供speaker_name或speaker_id")
            return {"status": 11, "detail": "请提供speaker_name或speaker_id"}
        elif speaker_name is None:
            # 检查speaker_id是否存在
            if speaker_id not in loaded_models.models[model_id].id2spk.keys():
                logger.error(f"/voice 请求错误:角色speaker_id={speaker_id}不存在")
                return {"status": 12, "detail": f"角色speaker_id={speaker_id}不存在"}
            speaker_name = loaded_models.models[model_id].id2spk[speaker_id]
        # 检查speaker_name是否存在
        if speaker_name not in loaded_models.models[model_id].spk2id.keys():
            logger.error(f"/voice 请求错误:角色speaker_name={speaker_name}不存在")
            return {"status": 13, "detail": f"角色speaker_name={speaker_name}不存在"}
        # 未传入则使用默认语言
        if language is None:
            language = loaded_models.models[model_id].language
        # 翻译会破坏mix结构,auto也会变得无意义。不要在这两个模式下使用
        if auto_translate:
            if language == "auto" or language == "mix":
                logger.error(
                    f"/voice 请求错误:请勿同时使用language = {language}与auto_translate模式"
                )
                return {
                    "status": 20,
                    "detail": f"请勿同时使用language = {language}与auto_translate模式",
                }
            text = trans.translate(Sentence=text, to_Language=language.lower())
        if reference_audio is not None:
            ref_audio = BytesIO(await reference_audio.read())
            # 2.2 适配
            if loaded_models.models[model_id].version == "2.2":
                ref_audio, _ = librosa.load(ref_audio, 48000)
        else:
            ref_audio = reference_audio

        # 改动:增加使用 || 对文本进行主动切分
        # 切分优先级: || → auto/mix → auto_split
        text2 = text.replace("\n", "").lstrip()
        texts: List[str] = text2.split("||")

        # 对于mix和auto的说明:出于版本兼容性的考虑,暂时无法使用multilang的方式进行推理
        if language == "MIX":
            text_language_speakers: List[Tuple[str, str, str]] = []
            for _text in texts:
                speaker_pieces = _text.split("[")  # 按说话人分割多块
                for speaker_piece in speaker_pieces:
                    if speaker_piece == "":
                        continue
                    speaker_piece2 = speaker_piece.split("]")
                    if len(speaker_piece2) != 2:
                        return {
                            "status": 21,
                            "detail": "MIX语法错误",
                        }
                    speaker = speaker_piece2[0].strip()
                    lang_pieces = speaker_piece2[1].split("<")
                    for lang_piece in lang_pieces:
                        if lang_piece == "":
                            continue
                        lang_piece2 = lang_piece.split(">")
                        if len(lang_piece2) != 2:
                            return {
                                "status": 21,
                                "detail": "MIX语法错误",
                            }
                        lang = lang_piece2[0].strip()
                        if lang.upper() not in ["ZH", "EN", "JP"]:
                            return {
                                "status": 21,
                                "detail": "MIX语法错误",
                            }
                        t = lang_piece2[1]
                        text_language_speakers.append((t, lang.upper(), speaker))

        elif language == "AUTO":
            text_language_speakers: List[Tuple[str, str, str]] = [
                (final_text, language.upper().replace("JA", "JP"), speaker_name)
                for sub_list in [
                    split_by_language(_text, target_languages=["zh", "ja", "en"])
                    for _text in texts
                    if _text != ""
                ]
                for final_text, language in sub_list
                if final_text != ""
            ]
        else:
            text_language_speakers: List[Tuple[str, str, str]] = [
                (_text, language, speaker_name) for _text in texts if _text != ""
            ]

        if auto_split:
            text_language_speakers: List[Tuple[str, str, str]] = [
                (final_text, lang, speaker)
                for _text, lang, speaker in text_language_speakers
                for final_text in cut_sent(_text)
            ]

        audios = []
        with torch.no_grad():
            for _text, lang, speaker in text_language_speakers:
                audios.append(
                    infer(
                        text=_text,
                        sdp_ratio=sdp_ratio,
                        noise_scale=noise,
                        noise_scale_w=noisew,
                        length_scale=length,
                        sid=speaker,
                        language=lang,
                        hps=loaded_models.models[model_id].hps,
                        net_g=loaded_models.models[model_id].net_g,
                        device=loaded_models.models[model_id].device,
                        emotion=emotion,
                        reference_audio=ref_audio,
                        style_text=style_text,
                        style_weight=style_weight,
                    )
                )
                # audios.append(np.zeros(int(44100 * 0.2)))
            # audios.pop()
            audio = np.concatenate(audios)
            audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
        with BytesIO() as wavContent:
            wavfile.write(
                wavContent, loaded_models.models[model_id].hps.data.sampling_rate, audio
            )
            response = Response(content=wavContent.getvalue(), media_type="audio/wav")
            return response

    @app.post("/voice")
    async def voice(
        request: Request,  # fastapi自动注入
        text: str = Form(...),
        model_id: int = Query(..., description="模型ID"),  # 模型序号
        speaker_name: str = Query(
            None, description="说话人名"
        ),  # speaker_name与 speaker_id二者选其一
        speaker_id: int = Query(None, description="说话人id,与speaker_name二选一"),
        sdp_ratio: float = Query(0.2, description="SDP/DP混合比"),
        noise: float = Query(0.2, description="感情"),
        noisew: float = Query(0.9, description="音素长度"),
        length: float = Query(1, description="语速"),
        language: str = Query(None, description="语言"),  # 若不指定使用语言则使用默认值
        auto_translate: bool = Query(False, description="自动翻译"),
        auto_split: bool = Query(False, description="自动切分"),
        emotion: Optional[Union[int, str]] = Query(None, description="emo"),
        reference_audio: UploadFile = File(None),
        style_text: Optional[str] = Form(None, description="风格文本"),
        style_weight: float = Query(0.7, description="风格权重"),
    ):
        """语音接口,若需要上传参考音频请仅使用post请求"""
        logger.info(
            f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )} text={text}"
        )
        return await _voice(
            text=text,
            model_id=model_id,
            speaker_name=speaker_name,
            speaker_id=speaker_id,
            sdp_ratio=sdp_ratio,
            noise=noise,
            noisew=noisew,
            length=length,
            language=language,
            auto_translate=auto_translate,
            auto_split=auto_split,
            emotion=emotion,
            reference_audio=reference_audio,
            style_text=style_text,
            style_weight=style_weight,
        )

    @app.get("/voice")
    async def voice(
        request: Request,  # fastapi自动注入
        text: str = Query(..., description="输入文字"),
        model_id: int = Query(..., description="模型ID"),  # 模型序号
        speaker_name: str = Query(
            None, description="说话人名"
        ),  # speaker_name与 speaker_id二者选其一
        speaker_id: int = Query(None, description="说话人id,与speaker_name二选一"),
        sdp_ratio: float = Query(0.2, description="SDP/DP混合比"),
        noise: float = Query(0.2, description="感情"),
        noisew: float = Query(0.9, description="音素长度"),
        length: float = Query(1, description="语速"),
        language: str = Query(None, description="语言"),  # 若不指定使用语言则使用默认值
        auto_translate: bool = Query(False, description="自动翻译"),
        auto_split: bool = Query(False, description="自动切分"),
        emotion: Optional[Union[int, str]] = Query(None, description="emo"),
        style_text: Optional[str] = Query(None, description="风格文本"),
        style_weight: float = Query(0.7, description="风格权重"),
    ):
        """语音接口,不建议使用"""
        logger.info(
            f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )}"
        )
        return await _voice(
            text=text,
            model_id=model_id,
            speaker_name=speaker_name,
            speaker_id=speaker_id,
            sdp_ratio=sdp_ratio,
            noise=noise,
            noisew=noisew,
            length=length,
            language=language,
            auto_translate=auto_translate,
            auto_split=auto_split,
            emotion=emotion,
            style_text=style_text,
            style_weight=style_weight,
        )

    @app.get("/models/info")
    def get_loaded_models_info(request: Request):
        """获取已加载模型信息"""

        result: Dict[str, Dict] = dict()
        for key, model in loaded_models.models.items():
            result[str(key)] = model.to_dict()
        return result

    @app.get("/models/delete")
    def delete_model(
        request: Request, model_id: int = Query(..., description="删除模型id")
    ):
        """删除指定模型"""
        logger.info(
            f"{request.client.host}:{request.client.port}/models/delete { unquote(str(request.query_params) )}"
        )
        result = loaded_models.del_model(model_id)
        if result is None:
            logger.error(f"/models/delete 模型删除错误:模型{model_id}不存在,删除失败")
            return {"status": 14, "detail": f"模型{model_id}不存在,删除失败"}

        return {"status": 0, "detail": "删除成功"}

    @app.get("/models/add")
    def add_model(
        request: Request,
        model_path: str = Query(..., description="添加模型路径"),
        config_path: str = Query(
            None, description="添加模型配置文件路径,不填则使用./config.json或../config.json"
        ),
        device: str = Query("cuda", description="推理使用设备"),
        language: str = Query("ZH", description="模型默认语言"),
    ):
        """添加指定模型:允许重复添加相同路径模型,且不重复占用内存"""
        logger.info(
            f"{request.client.host}:{request.client.port}/models/add { unquote(str(request.query_params) )}"
        )
        if config_path is None:
            model_dir = os.path.dirname(model_path)
            if os.path.isfile(os.path.join(model_dir, "config.json")):
                config_path = os.path.join(model_dir, "config.json")
            elif os.path.isfile(os.path.join(model_dir, "../config.json")):
                config_path = os.path.join(model_dir, "../config.json")
            else:
                logger.error("/models/add 模型添加失败:未在模型所在目录以及上级目录找到config.json文件")
                return {
                    "status": 15,
                    "detail": "查询未传入配置文件路径,同时默认路径./与../中不存在配置文件config.json。",
                }
        try:
            model_id = loaded_models.init_model(
                config_path=config_path,
                model_path=model_path,
                device=device,
                language=language,
            )
        except Exception:
            logging.exception("模型加载出错")
            return {
                "status": 16,
                "detail": "模型加载出错,详细查看日志",
            }
        return {
            "status": 0,
            "detail": "模型添加成功",
            "Data": {
                "model_id": model_id,
                "model_info": loaded_models.models[model_id].to_dict(),
            },
        }

    def _get_all_models(root_dir: str = "Data", only_unloaded: bool = False):
        """从root_dir搜索获取所有可用模型"""
        result: Dict[str, List[str]] = dict()
        files = os.listdir(root_dir) + ["."]
        for file in files:
            if os.path.isdir(os.path.join(root_dir, file)):
                sub_dir = os.path.join(root_dir, file)
                # 搜索 "sub_dir" 、 "sub_dir/models" 两个路径
                result[file] = list()
                sub_files = os.listdir(sub_dir)
                model_files = []
                for sub_file in sub_files:
                    relpath = os.path.realpath(os.path.join(sub_dir, sub_file))
                    if only_unloaded and relpath in loaded_models.path2ids.keys():
                        continue
                    if sub_file.endswith(".pth") and sub_file.startswith("G_"):
                        if os.path.isfile(relpath):
                            model_files.append(sub_file)
                # 对模型文件按步数排序
                model_files = sorted(
                    model_files,
                    key=lambda pth: int(pth.lstrip("G_").rstrip(".pth"))
                    if pth.lstrip("G_").rstrip(".pth").isdigit()
                    else 10**10,
                )
                result[file] = model_files
                models_dir = os.path.join(sub_dir, "models")
                model_files = []
                if os.path.isdir(models_dir):
                    sub_files = os.listdir(models_dir)
                    for sub_file in sub_files:
                        relpath = os.path.realpath(os.path.join(models_dir, sub_file))
                        if only_unloaded and relpath in loaded_models.path2ids.keys():
                            continue
                        if sub_file.endswith(".pth") and sub_file.startswith("G_"):
                            if os.path.isfile(os.path.join(models_dir, sub_file)):
                                model_files.append(f"models/{sub_file}")
                    # 对模型文件按步数排序
                    model_files = sorted(
                        model_files,
                        key=lambda pth: int(pth.lstrip("models/G_").rstrip(".pth"))
                        if pth.lstrip("models/G_").rstrip(".pth").isdigit()
                        else 10**10,
                    )
                    result[file] += model_files
                if len(result[file]) == 0:
                    result.pop(file)

        return result

    @app.get("/models/get_unloaded")
    def get_unloaded_models_info(
        request: Request, root_dir: str = Query("Data", description="搜索根目录")
    ):
        """获取未加载模型"""
        logger.info(
            f"{request.client.host}:{request.client.port}/models/get_unloaded { unquote(str(request.query_params) )}"
        )
        return _get_all_models(root_dir, only_unloaded=True)

    @app.get("/models/get_local")
    def get_local_models_info(
        request: Request, root_dir: str = Query("Data", description="搜索根目录")
    ):
        """获取全部本地模型"""
        logger.info(
            f"{request.client.host}:{request.client.port}/models/get_local { unquote(str(request.query_params) )}"
        )
        return _get_all_models(root_dir, only_unloaded=False)

    @app.get("/status")
    def get_status():
        """获取电脑运行状态"""
        cpu_percent = psutil.cpu_percent(interval=1)
        memory_info = psutil.virtual_memory()
        memory_total = memory_info.total
        memory_available = memory_info.available
        memory_used = memory_info.used
        memory_percent = memory_info.percent
        gpuInfo = []
        devices = ["cpu"]
        for i in range(torch.cuda.device_count()):
            devices.append(f"cuda:{i}")
        gpus = GPUtil.getGPUs()
        for gpu in gpus:
            gpuInfo.append(
                {
                    "gpu_id": gpu.id,
                    "gpu_load": gpu.load,
                    "gpu_memory": {
                        "total": gpu.memoryTotal,
                        "used": gpu.memoryUsed,
                        "free": gpu.memoryFree,
                    },
                }
            )
        return {
            "devices": devices,
            "cpu_percent": cpu_percent,
            "memory_total": memory_total,
            "memory_available": memory_available,
            "memory_used": memory_used,
            "memory_percent": memory_percent,
            "gpu": gpuInfo,
        }

    @app.get("/tools/translate")
    def translate(
        request: Request,
        texts: str = Query(..., description="待翻译文本"),
        to_language: str = Query(..., description="翻译目标语言"),
    ):
        """翻译"""
        logger.info(
            f"{request.client.host}:{request.client.port}/tools/translate { unquote(str(request.query_params) )}"
        )
        return {"texts": trans.translate(Sentence=texts, to_Language=to_language)}

    all_examples: Dict[str, Dict[str, List]] = dict()  # 存放示例

    @app.get("/tools/random_example")
    def random_example(
        request: Request,
        language: str = Query(None, description="指定语言,未指定则随机返回"),
        root_dir: str = Query("Data", description="搜索根目录"),
    ):
        """
        获取一个随机音频+文本,用于对比,音频会从本地目录随机选择。
        """
        logger.info(
            f"{request.client.host}:{request.client.port}/tools/random_example { unquote(str(request.query_params) )}"
        )
        global all_examples
        # 数据初始化
        if root_dir not in all_examples.keys():
            all_examples[root_dir] = {"ZH": [], "JP": [], "EN": []}

            examples = all_examples[root_dir]

            # 从项目Data目录中搜索train/val.list
            for root, directories, _files in os.walk(root_dir):
                for file in _files:
                    if file in ["train.list", "val.list"]:
                        with open(
                            os.path.join(root, file), mode="r", encoding="utf-8"
                        ) as f:
                            lines = f.readlines()
                            for line in lines:
                                data = line.split("|")
                                if len(data) != 7:
                                    continue
                                # 音频存在 且语言为ZH/EN/JP
                                if os.path.isfile(data[0]) and data[2] in [
                                    "ZH",
                                    "JP",
                                    "EN",
                                ]:
                                    examples[data[2]].append(
                                        {
                                            "text": data[3],
                                            "audio": data[0],
                                            "speaker": data[1],
                                        }
                                    )

        examples = all_examples[root_dir]
        if language is None:
            if len(examples["ZH"]) + len(examples["JP"]) + len(examples["EN"]) == 0:
                return {"status": 17, "detail": "没有加载任何示例数据"}
            else:
                # 随机选一个
                rand_num = random.randint(
                    0,
                    len(examples["ZH"]) + len(examples["JP"]) + len(examples["EN"]) - 1,
                )
                # ZH
                if rand_num < len(examples["ZH"]):
                    return {"status": 0, "Data": examples["ZH"][rand_num]}
                # JP
                if rand_num < len(examples["ZH"]) + len(examples["JP"]):
                    return {
                        "status": 0,
                        "Data": examples["JP"][rand_num - len(examples["ZH"])],
                    }
                # EN
                return {
                    "status": 0,
                    "Data": examples["EN"][
                        rand_num - len(examples["ZH"]) - len(examples["JP"])
                    ],
                }

        else:
            if len(examples[language]) == 0:
                return {"status": 17, "detail": f"没有加载任何{language}数据"}
            return {
                "status": 0,
                "Data": examples[language][
                    random.randint(0, len(examples[language]) - 1)
                ],
            }

    @app.get("/tools/get_audio")
    def get_audio(request: Request, path: str = Query(..., description="本地音频路径")):
        logger.info(
            f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
        )
        if not os.path.isfile(path):
            logger.error(f"/tools/get_audio 获取音频错误:指定音频{path}不存在")
            return {"status": 18, "detail": "指定音频不存在"}
        if not path.lower().endswith(".wav"):
            logger.error(f"/tools/get_audio 获取音频错误:音频{path}非wav文件")
            return {"status": 19, "detail": "非wav格式文件"}
        return FileResponse(path=path)

    logger.warning("本地服务,请勿将服务端口暴露于外网")
    logger.info(f"api文档地址 http://127.0.0.1:{config.server_config.port}/docs")
    if os.path.isdir(StaticDir):
        webbrowser.open(f"http://127.0.0.1:{config.server_config.port}")
    uvicorn.run(
        app, port=config.server_config.port, host="0.0.0.0", log_level="warning"
    )
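Note: a hypothetical client for the GET /voice endpoint defined in hiyoriUI.py above. The port and speaker name are placeholders (use config.server_config.port and a speaker from the loaded model's spk2id); the error payload shown in the comment is one of the status dicts returned by the handler:

import requests

PORT = 5000  # assumption: set this to config.server_config.port from your config.yml
params = {
    "text": "你好,世界",
    "model_id": 0,
    "speaker_name": "SpeakerA",  # placeholder speaker present in the model's spk2id
    "sdp_ratio": 0.2,
    "noise": 0.2,
    "noisew": 0.9,
    "length": 1.0,
    "language": "ZH",
    "auto_split": True,
}
resp = requests.get(f"http://127.0.0.1:{PORT}/voice", params=params)
if resp.headers.get("content-type") == "audio/wav":
    with open("out.wav", "wb") as f:
        f.write(resp.content)
else:
    print(resp.json())  # e.g. {"status": 11, "detail": "请提供speaker_name或speaker_id"}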
infer.py
CHANGED
@@ -5,19 +5,22 @@
 2. 请在模型的config.json中显示声明版本号,添加一个字段"version" : "你的版本号"
 特殊版本说明:
 1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复
-2.
+2.3:当前版本
 """
 import torch
 import commons
 from text import cleaned_text_to_sequence, get_bert
-
+
+# from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
+from typing import Union
 from text.cleaner import clean_text
 import utils
-import numpy as np

 from models import SynthesizerTrn
 from text.symbols import symbols

+from oldVersion.V220.models import SynthesizerTrn as V220SynthesizerTrn
+from oldVersion.V220.text import symbols as V220symbols
 from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
 from oldVersion.V210.text import symbols as V210symbols
 from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
@@ -29,13 +32,14 @@ from oldVersion.V110.text import symbols as V110symbols
 from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
 from oldVersion.V101.text import symbols as V101symbols

-from oldVersion import V111, V110, V101, V200, V210
+from oldVersion import V111, V110, V101, V200, V210, V220

 # 当前版本信息
-latest_version = "2.
+latest_version = "2.3"

 # 版本兼容
 SynthesizerTrnMap = {
+    "2.2": V220SynthesizerTrn,
     "2.1": V210SynthesizerTrn,
     "2.0.2-fix": V200SynthesizerTrn,
     "2.0.1": V200SynthesizerTrn,
@@ -50,6 +54,7 @@ SynthesizerTrnMap = {
 }

 symbolsMap = {
+    "2.2": V220symbols,
     "2.1": V210symbols,
     "2.0.2-fix": V200symbols,
     "2.0.1": V200symbols,
@@ -98,7 +103,8 @@ def get_net_g(model_path: str, version: str, device: str, hps):
     return net_g


-def get_text(text, language_str, hps, device):
+def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
+    style_text = None if style_text == "" else style_text
     # 在此处实现当前版本的get_text
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -110,21 +116,23 @@ def get_text(text, language_str, hps, device):
     for i in range(len(word2ph)):
         word2ph[i] = word2ph[i] * 2
     word2ph[0] += 1
-    bert_ori = get_bert(
+    bert_ori = get_bert(
+        norm_text, word2ph, language_str, device, style_text, style_weight
+    )
     del word2ph
     assert bert_ori.shape[-1] == len(phone), phone

     if language_str == "ZH":
         bert = bert_ori
-        ja_bert = torch.
-        en_bert = torch.
+        ja_bert = torch.randn(1024, len(phone))
+        en_bert = torch.randn(1024, len(phone))
     elif language_str == "JP":
-        bert = torch.
+        bert = torch.randn(1024, len(phone))
         ja_bert = bert_ori
-        en_bert = torch.
+        en_bert = torch.randn(1024, len(phone))
     elif language_str == "EN":
-        bert = torch.
-        ja_bert = torch.
+        bert = torch.randn(1024, len(phone))
+        ja_bert = torch.randn(1024, len(phone))
         en_bert = bert_ori
     else:
         raise ValueError("language_str should be ZH, JP or EN")
@@ -141,7 +149,7 @@ def get_text(text, language_str, hps, device):

 def infer(
     text,
-    emotion,
+    emotion: Union[int, str],
     sdp_ratio,
     noise_scale,
     noise_scale_w,
@@ -154,8 +162,13 @@ def infer(
     reference_audio=None,
     skip_start=False,
     skip_end=False,
+    style_text=None,
+    style_weight=0.7,
 ):
     # 2.2版本参数位置变了
+    inferMap_V4 = {
+        "2.2": V220.infer,
+    }
     # 2.1 参数新增 emotion reference_audio skip_start skip_end
     inferMap_V3 = {
         "2.1": V210.infer,
@@ -180,6 +193,25 @@ def infer(
     version = hps.version if hasattr(hps, "version") else latest_version
     # 非当前版本,根据版本号选择合适的infer
     if version != latest_version:
+        if version in inferMap_V4.keys():
+            return inferMap_V4[version](
+                text,
+                emotion,
+                sdp_ratio,
+                noise_scale,
+                noise_scale_w,
+                length_scale,
+                sid,
+                language,
+                hps,
+                net_g,
+                device,
+                reference_audio,
+                skip_start,
+                skip_end,
+                style_text,
+                style_weight,
+            )
         if version in inferMap_V3.keys():
             return inferMap_V3[version](
                 text,
@@ -196,6 +228,8 @@ def infer(
                 emotion,
                 skip_start,
                 skip_end,
+                style_text,
+                style_weight,
             )
         if version in inferMap_V2.keys():
             return inferMap_V2[version](
@@ -224,14 +258,19 @@ def infer(
             )
     # 在此处实现当前版本的推理
     # emo = get_emo_(reference_audio, emotion, sid)
-    if isinstance(reference_audio, np.ndarray):
-
-    else:
-
-    emo = torch.squeeze(emo, dim=1)
+    # if isinstance(reference_audio, np.ndarray):
+    #     emo = get_clap_audio_feature(reference_audio, device)
+    # else:
+    #     emo = get_clap_text_feature(emotion, device)
+    # emo = torch.squeeze(emo, dim=1)

     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
-        text,
+        text,
+        language,
+        hps,
+        device,
+        style_text=style_text,
+        style_weight=style_weight,
     )
     if skip_start:
         phones = phones[3:]
@@ -255,7 +294,7 @@ def infer(
     ja_bert = ja_bert.to(device).unsqueeze(0)
     en_bert = en_bert.to(device).unsqueeze(0)
     x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-    emo = emo.to(device).unsqueeze(0)
+    # emo = emo.to(device).unsqueeze(0)
     del phones
     speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
     audio = (
@@ -268,7 +307,6 @@ def infer(
             bert,
             ja_bert,
             en_bert,
-            emo,
             sdp_ratio=sdp_ratio,
             noise_scale=noise_scale,
             noise_scale_w=noise_scale_w,
@@ -278,7 +316,16 @@ def infer(
         .float()
         .numpy()
     )
-    del
+    del (
+        x_tst,
+        tones,
+        lang_ids,
+        bert,
+        x_tst_lengths,
+        speakers,
+        ja_bert,
+        en_bert,
+    )  # , emo
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     return audio
@@ -302,14 +349,14 @@ def infer_multilang(
 ):
     bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
     # emo = get_emo_(reference_audio, emotion, sid)
-    if isinstance(reference_audio, np.ndarray):
-
-    else:
-
-    emo = torch.squeeze(emo, dim=1)
+    # if isinstance(reference_audio, np.ndarray):
+    #     emo = get_clap_audio_feature(reference_audio, device)
+    # else:
+    #     emo = get_clap_text_feature(emotion, device)
+    # emo = torch.squeeze(emo, dim=1)
     for idx, (txt, lang) in enumerate(zip(text, language)):
-
-
+        _skip_start = (idx != 0) or (skip_start and idx == 0)
+        _skip_end = (idx != len(language) - 1) or skip_end
         (
             temp_bert,
             temp_ja_bert,
@@ -318,14 +365,14 @@ def infer_multilang(
             temp_tones,
             temp_lang_ids,
         ) = get_text(txt, lang, hps, device)
-        if
+        if _skip_start:
             temp_bert = temp_bert[:, 3:]
             temp_ja_bert = temp_ja_bert[:, 3:]
             temp_en_bert = temp_en_bert[:, 3:]
             temp_phones = temp_phones[3:]
             temp_tones = temp_tones[3:]
             temp_lang_ids = temp_lang_ids[3:]
-        if
+        if _skip_end:
             temp_bert = temp_bert[:, :-2]
             temp_ja_bert = temp_ja_bert[:, :-2]
             temp_en_bert = temp_en_bert[:, :-2]
@@ -351,7 +398,7 @@ def infer_multilang(
     bert = bert.to(device).unsqueeze(0)
     ja_bert = ja_bert.to(device).unsqueeze(0)
     en_bert = en_bert.to(device).unsqueeze(0)
-    emo = emo.to(device).unsqueeze(0)
+    # emo = emo.to(device).unsqueeze(0)
     x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
     del phones
     speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
@@ -365,7 +412,6 @@ def infer_multilang(
             bert,
             ja_bert,
             en_bert,
-            emo,
             sdp_ratio=sdp_ratio,
             noise_scale=noise_scale,
             noise_scale_w=noise_scale_w,
@@ -375,7 +421,16 @@ def infer_multilang(
         .float()
         .numpy()
     )
-    del
+    del (
+        x_tst,
+        tones,
+        lang_ids,
+        bert,
+        x_tst_lengths,
+        speakers,
+        ja_bert,
+        en_bert,
+    )  # , emo
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     return audio
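Note: a sketch of calling the 2.3 infer() with the signature shown in the diff above; the config path, model path, and speaker name are placeholders, not values from this commit:

import utils
from infer import get_net_g, infer, latest_version

hps = utils.get_hparams_from_file("Data/mymodel/configs/config.json")  # placeholder path
device = "cuda"
net_g = get_net_g("Data/mymodel/models/G_8000.pth", latest_version, device, hps)  # placeholder path
audio = infer(
    "こんにちは",
    emotion="Happy",        # only consumed by the 2.1/2.2 compatibility paths
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="SpeakerA",         # placeholder speaker present in hps.data.spk2id
    language="JP",
    hps=hps,
    net_g=net_g,
    device=device,
    style_text=None,
    style_weight=0.7,
)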
losses.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import torch
|
|
|
|
|
2 |
|
3 |
|
4 |
def feature_loss(fmap_r, fmap_g):
|
@@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
|
56 |
kl = torch.sum(kl * z_mask)
|
57 |
l = kl / torch.sum(z_mask)
|
58 |
return l
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
+
import torchaudio
|
3 |
+
from transformers import AutoModel
|
4 |
|
5 |
|
6 |
def feature_loss(fmap_r, fmap_g):
|
|
|
58 |
kl = torch.sum(kl * z_mask)
|
59 |
l = kl / torch.sum(z_mask)
|
60 |
return l
|
61 |
+
|
62 |
+
|
63 |
+
class WavLMLoss(torch.nn.Module):
|
64 |
+
def __init__(self, model, wd, model_sr, slm_sr=16000):
|
65 |
+
super(WavLMLoss, self).__init__()
|
66 |
+
self.wavlm = AutoModel.from_pretrained(model)
|
67 |
+
self.wd = wd
|
68 |
+
self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
|
69 |
+
self.wavlm.eval()
|
70 |
+
for param in self.wavlm.parameters():
|
71 |
+
param.requires_grad = False
|
72 |
+
|
73 |
+
def forward(self, wav, y_rec):
|
74 |
+
with torch.no_grad():
|
75 |
+
wav_16 = self.resample(wav)
|
76 |
+
wav_embeddings = self.wavlm(
|
77 |
+
input_values=wav_16, output_hidden_states=True
|
78 |
+
).hidden_states
|
79 |
+
y_rec_16 = self.resample(y_rec)
|
80 |
+
y_rec_embeddings = self.wavlm(
|
81 |
+
input_values=y_rec_16.squeeze(), output_hidden_states=True
|
82 |
+
).hidden_states
|
83 |
+
|
84 |
+
floss = 0
|
85 |
+
for er, eg in zip(wav_embeddings, y_rec_embeddings):
|
86 |
+
floss += torch.mean(torch.abs(er - eg))
|
87 |
+
|
88 |
+
return floss.mean()
|
89 |
+
|
90 |
+
def generator(self, y_rec):
|
91 |
+
y_rec_16 = self.resample(y_rec)
|
92 |
+
y_rec_embeddings = self.wavlm(
|
93 |
+
input_values=y_rec_16, output_hidden_states=True
|
94 |
+
).hidden_states
|
95 |
+
y_rec_embeddings = (
|
96 |
+
torch.stack(y_rec_embeddings, dim=1)
|
97 |
+
.transpose(-1, -2)
|
98 |
+
.flatten(start_dim=1, end_dim=2)
|
99 |
+
)
|
100 |
+
y_df_hat_g = self.wd(y_rec_embeddings)
|
101 |
+
loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
|
102 |
+
|
103 |
+
return loss_gen
|
104 |
+
|
105 |
+
def discriminator(self, wav, y_rec):
|
106 |
+
with torch.no_grad():
|
107 |
+
wav_16 = self.resample(wav)
|
108 |
+
wav_embeddings = self.wavlm(
|
109 |
+
input_values=wav_16, output_hidden_states=True
|
110 |
+
).hidden_states
|
111 |
+
y_rec_16 = self.resample(y_rec)
|
112 |
+
y_rec_embeddings = self.wavlm(
|
113 |
+
input_values=y_rec_16, output_hidden_states=True
|
114 |
+
).hidden_states
|
115 |
+
|
116 |
+
y_embeddings = (
|
117 |
+
torch.stack(wav_embeddings, dim=1)
|
118 |
+
.transpose(-1, -2)
|
119 |
+
.flatten(start_dim=1, end_dim=2)
|
120 |
+
)
|
121 |
+
y_rec_embeddings = (
|
122 |
+
torch.stack(y_rec_embeddings, dim=1)
|
123 |
+
.transpose(-1, -2)
|
124 |
+
.flatten(start_dim=1, end_dim=2)
|
125 |
+
)
|
126 |
+
|
127 |
+
y_d_rs = self.wd(y_embeddings)
|
128 |
+
y_d_gs = self.wd(y_rec_embeddings)
|
129 |
+
|
130 |
+
y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
|
131 |
+
|
132 |
+
r_loss = torch.mean((1 - y_df_hat_r) ** 2)
|
133 |
+
g_loss = torch.mean((y_df_hat_g) ** 2)
|
134 |
+
|
135 |
+
loss_disc_f = r_loss + g_loss
|
136 |
+
|
137 |
+
return loss_disc_f.mean()
|
138 |
+
|
139 |
+
def discriminator_forward(self, wav):
|
140 |
+
with torch.no_grad():
|
141 |
+
wav_16 = self.resample(wav)
|
142 |
+
wav_embeddings = self.wavlm(
|
143 |
+
input_values=wav_16, output_hidden_states=True
|
144 |
+
).hidden_states
|
145 |
+
y_embeddings = (
|
146 |
+
torch.stack(wav_embeddings, dim=1)
|
147 |
+
.transpose(-1, -2)
|
148 |
+
.flatten(start_dim=1, end_dim=2)
|
149 |
+
)
|
150 |
+
|
151 |
+
y_d_rs = self.wd(y_embeddings)
|
152 |
+
|
153 |
+
return y_d_rs
|
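The new WavLMLoss compares WavLM hidden states of the reference and generated audio layer by layer, and its generator/discriminator methods run a least-squares GAN objective on top of the `wd` module fed with those stacked features. A small self-contained sketch of the layer-wise L1 reduction used in `forward` (dummy tensors stand in for the hidden_states tuples; 13 layers of width 768 assumes a WavLM-base checkpoint):

```python
import torch

def wavlm_feature_loss(real_feats, fake_feats):
    # Same reduction as WavLMLoss.forward: per-layer mean absolute
    # difference between hidden states, accumulated over all layers.
    floss = torch.zeros(())
    for er, eg in zip(real_feats, fake_feats):
        floss = floss + torch.mean(torch.abs(er - eg))
    return floss

real = [torch.randn(1, 50, 768) for _ in range(13)]  # (batch, frames, hidden) per layer
fake = [torch.randn(1, 50, 768) for _ in range(13)]
print(wavlm_feature_loss(real, fake))
```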
models.py
CHANGED
@@ -14,8 +14,6 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
|
14 |
from commons import init_weights, get_padding
|
15 |
from text import symbols, num_tones, num_languages
|
16 |
|
17 |
-
from vector_quantize_pytorch import VectorQuantize
|
18 |
-
|
19 |
|
20 |
class DurationDiscriminator(nn.Module): # vits2
|
21 |
def __init__(
|
@@ -40,33 +38,22 @@ class DurationDiscriminator(nn.Module): # vits2
|
|
40 |
self.norm_2 = modules.LayerNorm(filter_channels)
|
41 |
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
42 |
|
43 |
-
self.
|
44 |
-
2 * filter_channels, filter_channels,
|
45 |
-
)
|
46 |
-
self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
|
47 |
-
self.pre_out_conv_2 = nn.Conv1d(
|
48 |
-
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
49 |
)
|
50 |
-
self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
|
51 |
|
52 |
if gin_channels != 0:
|
53 |
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
54 |
|
55 |
-
self.output_layer = nn.Sequential(
|
|
|
|
|
56 |
|
57 |
-
def forward_probability(self, x,
|
58 |
dur = self.dur_proj(dur)
|
59 |
x = torch.cat([x, dur], dim=1)
|
60 |
-
x = self.pre_out_conv_1(x * x_mask)
|
61 |
-
x = torch.relu(x)
|
62 |
-
x = self.pre_out_norm_1(x)
|
63 |
-
x = self.drop(x)
|
64 |
-
x = self.pre_out_conv_2(x * x_mask)
|
65 |
-
x = torch.relu(x)
|
66 |
-
x = self.pre_out_norm_2(x)
|
67 |
-
x = self.drop(x)
|
68 |
-
x = x * x_mask
|
69 |
x = x.transpose(1, 2)
|
|
|
70 |
output_prob = self.output_layer(x)
|
71 |
return output_prob
|
72 |
|
@@ -86,7 +73,7 @@ class DurationDiscriminator(nn.Module): # vits2
|
|
86 |
|
87 |
output_probs = []
|
88 |
for dur in [dur_r, dur_hat]:
|
89 |
-
output_prob = self.forward_probability(x,
|
90 |
output_probs.append(output_prob)
|
91 |
|
92 |
return output_probs
|
@@ -354,7 +341,6 @@ class TextEncoder(nn.Module):
|
|
354 |
n_layers,
|
355 |
kernel_size,
|
356 |
p_dropout,
|
357 |
-
n_speakers,
|
358 |
gin_channels=0,
|
359 |
):
|
360 |
super().__init__()
|
@@ -376,31 +362,6 @@ class TextEncoder(nn.Module):
|
|
376 |
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
377 |
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
378 |
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
379 |
-
# self.emo_proj = nn.Linear(512, hidden_channels)
|
380 |
-
self.in_feature_net = nn.Sequential(
|
381 |
-
# input is assumed to an already normalized embedding
|
382 |
-
nn.Linear(512, 1028, bias=False),
|
383 |
-
nn.GELU(),
|
384 |
-
nn.LayerNorm(1028),
|
385 |
-
*[Block(1028, 512) for _ in range(1)],
|
386 |
-
nn.Linear(1028, 512, bias=False),
|
387 |
-
# normalize before passing to VQ?
|
388 |
-
# nn.GELU(),
|
389 |
-
# nn.LayerNorm(512),
|
390 |
-
)
|
391 |
-
self.emo_vq = VectorQuantize(
|
392 |
-
dim=512,
|
393 |
-
codebook_size=64,
|
394 |
-
codebook_dim=32,
|
395 |
-
commitment_weight=0.1,
|
396 |
-
decay=0.85,
|
397 |
-
heads=32,
|
398 |
-
kmeans_iters=20,
|
399 |
-
separate_codebook_per_head=True,
|
400 |
-
stochastic_sample_codes=True,
|
401 |
-
threshold_ema_dead_code=2,
|
402 |
-
)
|
403 |
-
self.out_feature_net = nn.Linear(512, hidden_channels)
|
404 |
|
405 |
self.encoder = attentions.Encoder(
|
406 |
hidden_channels,
|
@@ -413,18 +374,10 @@ class TextEncoder(nn.Module):
|
|
413 |
)
|
414 |
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
415 |
|
416 |
-
def forward(
|
417 |
-
self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
|
418 |
-
):
|
419 |
-
sid = sid.cpu()
|
420 |
bert_emb = self.bert_proj(bert).transpose(1, 2)
|
421 |
ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
|
422 |
en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
|
423 |
-
emo_emb = self.in_feature_net(emo)
|
424 |
-
emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1))
|
425 |
-
loss_commit = loss_commit.mean()
|
426 |
-
emo_emb = self.out_feature_net(emo_emb)
|
427 |
-
# emo_emb = self.emo_proj(emo.unsqueeze(1))
|
428 |
x = (
|
429 |
self.emb(x)
|
430 |
+ self.tone_emb(tone)
|
@@ -432,7 +385,6 @@ class TextEncoder(nn.Module):
|
|
432 |
+ bert_emb
|
433 |
+ ja_bert_emb
|
434 |
+ en_bert_emb
|
435 |
-
+ emo_emb
|
436 |
) * math.sqrt(
|
437 |
self.hidden_channels
|
438 |
) # [b, t, h]
|
@@ -445,7 +397,7 @@ class TextEncoder(nn.Module):
|
|
445 |
stats = self.proj(x) * x_mask
|
446 |
|
447 |
m, logs = torch.split(stats, self.out_channels, dim=1)
|
448 |
-
return x, m, logs, x_mask
|
449 |
|
450 |
|
451 |
class ResidualCouplingBlock(nn.Module):
|
@@ -748,6 +700,55 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
|
748 |
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
749 |
|
750 |
|
|
|
|
|
751 |
class ReferenceEncoder(nn.Module):
|
752 |
"""
|
753 |
inputs --- [N, Ty/r, n_mels*r] mels
|
@@ -878,7 +879,6 @@ class SynthesizerTrn(nn.Module):
|
|
878 |
n_layers,
|
879 |
kernel_size,
|
880 |
p_dropout,
|
881 |
-
self.n_speakers,
|
882 |
gin_channels=self.enc_gin_channels,
|
883 |
)
|
884 |
self.dec = Generator(
|
@@ -946,14 +946,13 @@ class SynthesizerTrn(nn.Module):
|
|
946 |
bert,
|
947 |
ja_bert,
|
948 |
en_bert,
|
949 |
-
emo=None,
|
950 |
):
|
951 |
if self.n_speakers > 0:
|
952 |
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
953 |
else:
|
954 |
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
955 |
-
x, m_p, logs_p, x_mask
|
956 |
-
x, x_lengths, tone, language, bert, ja_bert, en_bert,
|
957 |
)
|
958 |
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
959 |
z_p = self.flow(z, y_mask, g=g)
|
@@ -996,9 +995,11 @@ class SynthesizerTrn(nn.Module):
|
|
996 |
|
997 |
logw_ = torch.log(w + 1e-6) * x_mask
|
998 |
logw = self.dp(x, x_mask, g=g)
|
|
|
999 |
l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
|
1000 |
x_mask
|
1001 |
) # for averaging
|
|
|
1002 |
|
1003 |
l_length = l_length_dp + l_length_sdp
|
1004 |
|
@@ -1018,9 +1019,8 @@ class SynthesizerTrn(nn.Module):
|
|
1018 |
x_mask,
|
1019 |
y_mask,
|
1020 |
(z, z_p, m_p, logs_p, m_q, logs_q),
|
1021 |
-
(x, logw, logw_),
|
1022 |
g,
|
1023 |
-
loss_commit,
|
1024 |
)
|
1025 |
|
1026 |
def infer(
|
@@ -1033,7 +1033,6 @@ class SynthesizerTrn(nn.Module):
|
|
1033 |
bert,
|
1034 |
ja_bert,
|
1035 |
en_bert,
|
1036 |
-
emo=None,
|
1037 |
noise_scale=0.667,
|
1038 |
length_scale=1,
|
1039 |
noise_scale_w=0.8,
|
@@ -1047,8 +1046,8 @@ class SynthesizerTrn(nn.Module):
|
|
1047 |
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
1048 |
else:
|
1049 |
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
1050 |
-
x, m_p, logs_p, x_mask
|
1051 |
-
x, x_lengths, tone, language, bert, ja_bert, en_bert,
|
1052 |
)
|
1053 |
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
|
1054 |
sdp_ratio
|
|
|
14 |
from commons import init_weights, get_padding
|
15 |
from text import symbols, num_tones, num_languages
|
16 |
|
|
|
|
|
17 |
|
18 |
class DurationDiscriminator(nn.Module): # vits2
|
19 |
def __init__(
|
|
|
38 |
self.norm_2 = modules.LayerNorm(filter_channels)
|
39 |
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
40 |
|
41 |
+
self.LSTM = nn.LSTM(
|
42 |
+
2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
|
|
|
|
|
|
|
|
|
43 |
)
|
|
|
44 |
|
45 |
if gin_channels != 0:
|
46 |
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
47 |
|
48 |
+
self.output_layer = nn.Sequential(
|
49 |
+
nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
|
50 |
+
)
|
51 |
|
52 |
+
def forward_probability(self, x, dur):
|
53 |
dur = self.dur_proj(dur)
|
54 |
x = torch.cat([x, dur], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
x = x.transpose(1, 2)
|
56 |
+
x, _ = self.LSTM(x)
|
57 |
output_prob = self.output_layer(x)
|
58 |
return output_prob
|
59 |
|
|
|
73 |
|
74 |
output_probs = []
|
75 |
for dur in [dur_r, dur_hat]:
|
76 |
+
output_prob = self.forward_probability(x, dur)
|
77 |
output_probs.append(output_prob)
|
78 |
|
79 |
return output_probs
|
|
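In the reworked DurationDiscriminator the convolution/LayerNorm output head is replaced by a bidirectional LSTM followed by a Linear + Sigmoid readout. A quick standalone shape check of that path (the channel width and sequence length here are arbitrary):

```python
import torch
import torch.nn as nn

filter_channels = 256
lstm = nn.LSTM(2 * filter_channels, filter_channels, batch_first=True, bidirectional=True)
output_layer = nn.Sequential(nn.Linear(2 * filter_channels, 1), nn.Sigmoid())

x = torch.randn(2, 2 * filter_channels, 37)  # (batch, channels, time): hidden states concatenated with dur_proj(dur)
x = x.transpose(1, 2)                        # (batch, time, channels), the layout a batch_first LSTM expects
x, _ = lstm(x)                               # (batch, time, 2 * filter_channels): forward and backward states
prob = output_layer(x)                       # (batch, time, 1): per-position real/fake probability
print(prob.shape)                            # torch.Size([2, 37, 1])
```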
|
341 |
n_layers,
|
342 |
kernel_size,
|
343 |
p_dropout,
|
|
|
344 |
gin_channels=0,
|
345 |
):
|
346 |
super().__init__()
|
|
|
362 |
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
363 |
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
364 |
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
|
|
|
|
365 |
|
366 |
self.encoder = attentions.Encoder(
|
367 |
hidden_channels,
|
|
|
374 |
)
|
375 |
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
376 |
|
377 |
+
def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
|
|
|
|
|
|
|
378 |
bert_emb = self.bert_proj(bert).transpose(1, 2)
|
379 |
ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
|
380 |
en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
381 |
x = (
|
382 |
self.emb(x)
|
383 |
+ self.tone_emb(tone)
|
|
|
385 |
+ bert_emb
|
386 |
+ ja_bert_emb
|
387 |
+ en_bert_emb
|
|
|
388 |
) * math.sqrt(
|
389 |
self.hidden_channels
|
390 |
) # [b, t, h]
|
|
|
397 |
stats = self.proj(x) * x_mask
|
398 |
|
399 |
m, logs = torch.split(stats, self.out_channels, dim=1)
|
400 |
+
return x, m, logs, x_mask
|
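With the vector-quantized emotion branch gone, the TextEncoder input is simply the sum of the symbol, tone and language embeddings plus the three projected BERT streams, scaled by sqrt(hidden_channels). A reduced sketch of that fusion with one BERT stream (1024-dim features projected to hidden_channels by a 1x1 convolution, as above; the sizes are illustrative):

```python
import math
import torch
import torch.nn as nn

hidden_channels, n_phones = 192, 37
emb_sum = torch.randn(1, n_phones, hidden_channels)  # symbol + tone + language embeddings, already summed
bert_proj = nn.Conv1d(1024, hidden_channels, 1)      # same shape as self.bert_proj / ja_bert_proj / en_bert_proj
bert = torch.randn(1, 1024, n_phones)                # (batch, 1024, phones) features from the language BERT

bert_emb = bert_proj(bert).transpose(1, 2)             # (batch, phones, hidden_channels)
x = (emb_sum + bert_emb) * math.sqrt(hidden_channels)  # fused encoder input, [b, t, h]
print(x.shape)
```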
401 |
|
402 |
|
403 |
class ResidualCouplingBlock(nn.Module):
|
|
|
700 |
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
701 |
|
702 |
|
703 |
+
class WavLMDiscriminator(nn.Module):
|
704 |
+
"""docstring for Discriminator."""
|
705 |
+
|
706 |
+
def __init__(
|
707 |
+
self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
|
708 |
+
):
|
709 |
+
super(WavLMDiscriminator, self).__init__()
|
710 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
711 |
+
self.pre = norm_f(
|
712 |
+
Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
|
713 |
+
)
|
714 |
+
|
715 |
+
self.convs = nn.ModuleList(
|
716 |
+
[
|
717 |
+
norm_f(
|
718 |
+
nn.Conv1d(
|
719 |
+
initial_channel, initial_channel * 2, kernel_size=5, padding=2
|
720 |
+
)
|
721 |
+
),
|
722 |
+
norm_f(
|
723 |
+
nn.Conv1d(
|
724 |
+
initial_channel * 2,
|
725 |
+
initial_channel * 4,
|
726 |
+
kernel_size=5,
|
727 |
+
padding=2,
|
728 |
+
)
|
729 |
+
),
|
730 |
+
norm_f(
|
731 |
+
nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
|
732 |
+
),
|
733 |
+
]
|
734 |
+
)
|
735 |
+
|
736 |
+
self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
|
737 |
+
|
738 |
+
def forward(self, x):
|
739 |
+
x = self.pre(x)
|
740 |
+
|
741 |
+
fmap = []
|
742 |
+
for l in self.convs:
|
743 |
+
x = l(x)
|
744 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
745 |
+
fmap.append(x)
|
746 |
+
x = self.conv_post(x)
|
747 |
+
x = torch.flatten(x, 1, -1)
|
748 |
+
|
749 |
+
return x
|
750 |
+
|
751 |
+
|
752 |
class ReferenceEncoder(nn.Module):
|
753 |
"""
|
754 |
inputs --- [N, Ty/r, n_mels*r] mels
|
|
|
879 |
n_layers,
|
880 |
kernel_size,
|
881 |
p_dropout,
|
|
|
882 |
gin_channels=self.enc_gin_channels,
|
883 |
)
|
884 |
self.dec = Generator(
|
|
|
946 |
bert,
|
947 |
ja_bert,
|
948 |
en_bert,
|
|
|
949 |
):
|
950 |
if self.n_speakers > 0:
|
951 |
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
952 |
else:
|
953 |
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
954 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
955 |
+
x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
|
956 |
)
|
957 |
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
958 |
z_p = self.flow(z, y_mask, g=g)
|
|
|
995 |
|
996 |
logw_ = torch.log(w + 1e-6) * x_mask
|
997 |
logw = self.dp(x, x_mask, g=g)
|
998 |
+
logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
|
999 |
l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
|
1000 |
x_mask
|
1001 |
) # for averaging
|
1002 |
+
l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
|
1003 |
|
1004 |
l_length = l_length_dp + l_length_sdp
|
1005 |
|
|
|
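The training forward now adds a second regression target for durations: besides the deterministic predictor's logw, a reverse-sampled logw_sdp is pulled toward the aligned log-durations logw_, and the extra MSE term is accumulated into l_length_sdp on top of the stochastic predictor's own loss. A toy illustration of the two regression terms with dummy tensors:

```python
import torch

x_mask = torch.ones(1, 1, 10)      # valid-position mask over the text axis
logw_ = torch.randn(1, 1, 10)      # log ground-truth durations from the alignment
logw = torch.randn(1, 1, 10)       # deterministic duration predictor output
logw_sdp = torch.randn(1, 1, 10)   # stochastic duration predictor, reverse-sampled

l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
l_length_sdp_extra = torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
print(l_length_dp + l_length_sdp_extra)  # the regression part of the combined duration loss
```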
1019 |
x_mask,
|
1020 |
y_mask,
|
1021 |
(z, z_p, m_p, logs_p, m_q, logs_q),
|
1022 |
+
(x, logw, logw_, logw_sdp),
|
1023 |
g,
|
|
|
1024 |
)
|
1025 |
|
1026 |
def infer(
|
|
|
1033 |
bert,
|
1034 |
ja_bert,
|
1035 |
en_bert,
|
|
|
1036 |
noise_scale=0.667,
|
1037 |
length_scale=1,
|
1038 |
noise_scale_w=0.8,
|
|
|
1046 |
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
1047 |
else:
|
1048 |
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
1049 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
1050 |
+
x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
|
1051 |
)
|
1052 |
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
|
1053 |
sdp_ratio
|
modules.py
CHANGED
@@ -83,7 +83,7 @@ class ConvReluNorm(nn.Module):
|
|
83 |
|
84 |
class DDSConv(nn.Module):
|
85 |
"""
|
86 |
-
|
87 |
"""
|
88 |
|
89 |
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
|
|
83 |
|
84 |
class DDSConv(nn.Module):
|
85 |
"""
|
86 |
+
Dilated and Depth-Separable Convolution
|
87 |
"""
|
88 |
|
89 |
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
onnx_infer.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
1 |
+
from onnx_modules.V220_OnnxInference import OnnxInferenceSession
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
Session = OnnxInferenceSession(
|
5 |
+
{
|
6 |
+
"enc": "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
|
7 |
+
"emb_g": "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
|
8 |
+
"dp": "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
|
9 |
+
"sdp": "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
|
10 |
+
"flow": "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
|
11 |
+
"dec": "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx",
|
12 |
+
},
|
13 |
+
Providers=["CPUExecutionProvider"],
|
14 |
+
)
|
15 |
+
|
16 |
+
# 这里的输入和原版是一样的,只需要在原版预处理结果出来之后加上.numpy()即可
|
17 |
+
x = np.array(
|
18 |
+
[
|
19 |
+
0,
|
20 |
+
97,
|
21 |
+
0,
|
22 |
+
8,
|
23 |
+
0,
|
24 |
+
78,
|
25 |
+
0,
|
26 |
+
8,
|
27 |
+
0,
|
28 |
+
76,
|
29 |
+
0,
|
30 |
+
37,
|
31 |
+
0,
|
32 |
+
40,
|
33 |
+
0,
|
34 |
+
97,
|
35 |
+
0,
|
36 |
+
8,
|
37 |
+
0,
|
38 |
+
23,
|
39 |
+
0,
|
40 |
+
8,
|
41 |
+
0,
|
42 |
+
74,
|
43 |
+
0,
|
44 |
+
26,
|
45 |
+
0,
|
46 |
+
104,
|
47 |
+
0,
|
48 |
+
]
|
49 |
+
)
|
50 |
+
tone = np.zeros_like(x)
|
51 |
+
language = np.zeros_like(x)
|
52 |
+
sid = np.array([0])
|
53 |
+
bert = np.random.randn(x.shape[0], 1024)
|
54 |
+
ja_bert = np.random.randn(x.shape[0], 1024)
|
55 |
+
en_bert = np.random.randn(x.shape[0], 1024)
|
56 |
+
emo = np.random.randn(512, 1)
|
57 |
+
|
58 |
+
audio = Session(x, tone, language, bert, ja_bert, en_bert, emo, sid)
|
59 |
+
|
60 |
+
print(audio)
|
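A possible follow-up to the script above (not part of the upload): write the returned waveform to disk instead of printing it. This assumes the session returns a float waveform array and that the exported checkpoint runs at 44.1 kHz; both are assumptions, not read from the model.

```python
import numpy as np
import soundfile as sf

# Hypothetical post-processing for onnx_infer.py: drop singleton batch/channel
# dimensions and persist the result. The 44100 rate is assumed, not queried.
sf.write("onnx_infer_out.wav", np.asarray(audio).squeeze(), 44100)
```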
re_matching.py
CHANGED
@@ -44,7 +44,6 @@ def text_matching(text: str) -> list:
|
|
44 |
result = []
|
45 |
for speaker, dialogue in matches:
|
46 |
result.append(extract_language_and_text_updated(speaker, dialogue))
|
47 |
-
print(result)
|
48 |
return result
|
49 |
|
50 |
|
|
|
44 |
result = []
|
45 |
for speaker, dialogue in matches:
|
46 |
result.append(extract_language_and_text_updated(speaker, dialogue))
|
|
|
47 |
return result
|
48 |
|
49 |
|
resample.py
CHANGED
@@ -10,11 +10,11 @@ from config import config
|
|
10 |
|
11 |
|
12 |
def process(item):
|
13 |
-
wav_name, args = item
|
14 |
-
wav_path = os.path.join(args.in_dir, wav_name)
|
15 |
if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
|
16 |
wav, sr = librosa.load(wav_path, sr=args.sr)
|
17 |
-
soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr)
|
18 |
|
19 |
|
20 |
if __name__ == "__main__":
|
@@ -54,11 +54,15 @@ if __name__ == "__main__":
|
|
54 |
tasks = []
|
55 |
|
56 |
for dirpath, _, filenames in os.walk(args.in_dir):
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
59 |
for filename in filenames:
|
60 |
if filename.lower().endswith(".wav"):
|
61 |
-
|
|
|
62 |
|
63 |
for _ in tqdm(
|
64 |
pool.imap_unordered(process, tasks),
|
|
|
10 |
|
11 |
|
12 |
def process(item):
|
13 |
+
spkdir, wav_name, args = item
|
14 |
+
wav_path = os.path.join(args.in_dir, spkdir, wav_name)
|
15 |
if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
|
16 |
wav, sr = librosa.load(wav_path, sr=args.sr)
|
17 |
+
soundfile.write(os.path.join(args.out_dir, spkdir, wav_name), wav, sr)
|
18 |
|
19 |
|
20 |
if __name__ == "__main__":
|
|
|
54 |
tasks = []
|
55 |
|
56 |
for dirpath, _, filenames in os.walk(args.in_dir):
|
57 |
+
# 子级目录
|
58 |
+
spk_dir = os.path.relpath(dirpath, args.in_dir)
|
59 |
+
spk_dir_out = os.path.join(args.out_dir, spk_dir)
|
60 |
+
if not os.path.isdir(spk_dir_out):
|
61 |
+
os.makedirs(spk_dir_out, exist_ok=True)
|
62 |
for filename in filenames:
|
63 |
if filename.lower().endswith(".wav"):
|
64 |
+
twople = (spk_dir, filename, args)
|
65 |
+
tasks.append(twople)
|
66 |
|
67 |
for _ in tqdm(
|
68 |
pool.imap_unordered(process, tasks),
|
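The reworked resample.py preserves per-speaker subdirectories: each task now carries the directory of the file relative to in_dir, and the output path mirrors it. A tiny illustration of that path bookkeeping (the paths are made up):

```python
import os

in_dir, out_dir = "raw", "wavs"
wav_path = os.path.join(in_dir, "speakerA", "001.wav")         # hypothetical source file

spk_dir = os.path.relpath(os.path.dirname(wav_path), in_dir)   # "speakerA"
target = os.path.join(out_dir, spk_dir, os.path.basename(wav_path))
print(spk_dir, target)                                         # speakerA wavs/speakerA/001.wav (POSIX separators)
```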
resample_legacy.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import librosa
|
4 |
+
from multiprocessing import Pool, cpu_count
|
5 |
+
|
6 |
+
import soundfile
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from config import config
|
10 |
+
|
11 |
+
|
12 |
+
def process(item):
|
13 |
+
wav_name, args = item
|
14 |
+
wav_path = os.path.join(args.in_dir, wav_name)
|
15 |
+
if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
|
16 |
+
wav, sr = librosa.load(wav_path, sr=args.sr)
|
17 |
+
soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr)
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument(
|
23 |
+
"--sr",
|
24 |
+
type=int,
|
25 |
+
default=config.resample_config.sampling_rate,
|
26 |
+
help="sampling rate",
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--in_dir",
|
30 |
+
type=str,
|
31 |
+
default=config.resample_config.in_dir,
|
32 |
+
help="path to source dir",
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"--out_dir",
|
36 |
+
type=str,
|
37 |
+
default=config.resample_config.out_dir,
|
38 |
+
help="path to target dir",
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--processes",
|
42 |
+
type=int,
|
43 |
+
default=0,
|
44 |
+
help="cpu_processes",
|
45 |
+
)
|
46 |
+
args, _ = parser.parse_known_args()
|
47 |
+
# autodl 无卡模式会识别出46个cpu
|
48 |
+
if args.processes == 0:
|
49 |
+
processes = cpu_count() - 2 if cpu_count() > 4 else 1
|
50 |
+
else:
|
51 |
+
processes = args.processes
|
52 |
+
pool = Pool(processes=processes)
|
53 |
+
|
54 |
+
tasks = []
|
55 |
+
|
56 |
+
for dirpath, _, filenames in os.walk(args.in_dir):
|
57 |
+
if not os.path.isdir(args.out_dir):
|
58 |
+
os.makedirs(args.out_dir, exist_ok=True)
|
59 |
+
for filename in filenames:
|
60 |
+
if filename.lower().endswith(".wav"):
|
61 |
+
tasks.append((filename, args))
|
62 |
+
|
63 |
+
for _ in tqdm(
|
64 |
+
pool.imap_unordered(process, tasks),
|
65 |
+
):
|
66 |
+
pass
|
67 |
+
|
68 |
+
pool.close()
|
69 |
+
pool.join()
|
70 |
+
|
71 |
+
print("音频重采样完毕!")
|
server.py
CHANGED
@@ -3,10 +3,8 @@ import os
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
import logging
|
6 |
-
import re_matching
|
7 |
import uuid
|
8 |
-
|
9 |
-
from flask_cors import CORS
|
10 |
|
11 |
logging.getLogger("numba").setLevel(logging.WARNING)
|
12 |
logging.getLogger("markdown_it").setLevel(logging.WARNING)
|
@@ -18,6 +16,8 @@ logging.basicConfig(
|
|
18 |
)
|
19 |
|
20 |
logger = logging.getLogger(__name__)
|
|
|
|
|
21 |
import librosa
|
22 |
import numpy as np
|
23 |
import torch
|
@@ -26,24 +26,44 @@ from torch.utils.data import Dataset
|
|
26 |
from torch.utils.data import DataLoader, Dataset
|
27 |
from tqdm import tqdm
|
28 |
|
|
|
|
|
29 |
import utils
|
30 |
from config import config
|
31 |
-
|
32 |
import torch
|
33 |
import commons
|
34 |
from text import cleaned_text_to_sequence, get_bert
|
35 |
-
from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
|
36 |
-
|
37 |
from text.cleaner import clean_text
|
38 |
import utils
|
39 |
|
40 |
from models import SynthesizerTrn
|
41 |
from text.symbols import symbols
|
42 |
import sys
|
|
|
43 |
|
44 |
-
|
|
|
45 |
|
|
|
|
|
46 |
net_g = None
|
|
|
47 |
device = (
|
48 |
"cuda:0"
|
49 |
if torch.cuda.is_available()
|
@@ -54,7 +74,375 @@ device = (
|
|
54 |
)
|
55 |
)
|
56 |
|
57 |
-
#device =
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def get_net_g(model_path: str, device: str, hps):
|
60 |
net_g = SynthesizerTrn(
|
@@ -68,11 +456,11 @@ def get_net_g(model_path: str, device: str, hps):
|
|
68 |
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
|
69 |
return net_g
|
70 |
|
71 |
-
|
72 |
-
|
73 |
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
74 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
75 |
-
|
76 |
if hps.data.add_blank:
|
77 |
phone = commons.intersperse(phone, 0)
|
78 |
tone = commons.intersperse(tone, 0)
|
@@ -80,18 +468,24 @@ def get_text(text, language_str, hps, device):
|
|
80 |
for i in range(len(word2ph)):
|
81 |
word2ph[i] = word2ph[i] * 2
|
82 |
word2ph[0] += 1
|
83 |
-
bert_ori = get_bert(
|
|
|
|
|
84 |
del word2ph
|
85 |
assert bert_ori.shape[-1] == len(phone), phone
|
86 |
|
87 |
if language_str == "ZH":
|
88 |
bert = bert_ori
|
89 |
-
ja_bert = torch.
|
90 |
-
en_bert = torch.
|
91 |
elif language_str == "JP":
|
92 |
-
bert = torch.
|
93 |
ja_bert = bert_ori
|
94 |
-
en_bert = torch.
|
|
|
|
|
|
|
|
|
95 |
else:
|
96 |
raise ValueError("language_str should be ZH, JP or EN")
|
97 |
|
@@ -111,19 +505,47 @@ def infer(
|
|
111 |
noise_scale_w,
|
112 |
length_scale,
|
113 |
sid,
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
116 |
):
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
124 |
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
125 |
-
text,
|
|
|
|
|
|
|
|
|
|
|
126 |
)
|
|
|
|
|
|
|
|
127 |
with torch.no_grad():
|
128 |
x_tst = phones.to(device).unsqueeze(0)
|
129 |
tones = tones.to(device).unsqueeze(0)
|
@@ -132,7 +554,7 @@ def infer(
|
|
132 |
ja_bert = ja_bert.to(device).unsqueeze(0)
|
133 |
en_bert = en_bert.to(device).unsqueeze(0)
|
134 |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
135 |
-
emo = emo.to(device).unsqueeze(0)
|
136 |
del phones
|
137 |
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
138 |
audio = (
|
@@ -145,7 +567,6 @@ def infer(
|
|
145 |
bert,
|
146 |
ja_bert,
|
147 |
en_bert,
|
148 |
-
emo,
|
149 |
sdp_ratio=sdp_ratio,
|
150 |
noise_scale=noise_scale,
|
151 |
noise_scale_w=noise_scale_w,
|
@@ -155,79 +576,292 @@ def infer(
|
|
155 |
.float()
|
156 |
.numpy()
|
157 |
)
|
158 |
-
del
|
|
|
|
|
|
159 |
if torch.cuda.is_available():
|
160 |
torch.cuda.empty_cache()
|
161 |
-
|
162 |
-
|
163 |
-
return unique_filename
|
164 |
-
|
165 |
-
def is_japanese(string):
|
166 |
-
for ch in string:
|
167 |
-
if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
|
168 |
-
return True
|
169 |
-
return False
|
170 |
|
171 |
def loadmodel(model):
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
def tts():
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
|
|
|
|
|
231 |
with open(unique_filename ,'rb') as bit:
|
232 |
wav_bytes = bit.read()
|
233 |
os.remove(unique_filename)
|
@@ -238,17 +872,13 @@ def tts():
|
|
238 |
|
239 |
|
240 |
if __name__ == "__main__":
|
241 |
-
|
242 |
-
|
243 |
-
for dirpath, dirnames, filenames in os.walk("Data/BangDreamV22/models/"):
|
244 |
-
for filename in filenames:
|
245 |
-
modelPaths.append(os.path.join(dirpath, filename))
|
246 |
-
hps = utils.get_hparams_from_file('Data/BangDreamV22/configs/config.json')
|
247 |
net_g = get_net_g(
|
248 |
model_path=modelPaths[-1], device=device, hps=hps
|
249 |
)
|
250 |
speaker_ids = hps.data.spk2id
|
251 |
speakers = list(speaker_ids.keys())
|
252 |
-
|
253 |
-
|
254 |
-
|
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
import logging
|
|
|
6 |
import uuid
|
7 |
+
import re_matching
|
|
|
8 |
|
9 |
logging.getLogger("numba").setLevel(logging.WARNING)
|
10 |
logging.getLogger("markdown_it").setLevel(logging.WARNING)
|
|
|
16 |
)
|
17 |
|
18 |
logger = logging.getLogger(__name__)
|
19 |
+
import shutil
|
20 |
+
from scipy.io.wavfile import write
|
21 |
import librosa
|
22 |
import numpy as np
|
23 |
import torch
|
|
|
26 |
from torch.utils.data import DataLoader, Dataset
|
27 |
from tqdm import tqdm
|
28 |
|
29 |
+
import gradio as gr
|
30 |
+
|
31 |
import utils
|
32 |
from config import config
|
33 |
+
|
34 |
import torch
|
35 |
import commons
|
36 |
from text import cleaned_text_to_sequence, get_bert
|
|
|
|
|
37 |
from text.cleaner import clean_text
|
38 |
import utils
|
39 |
|
40 |
from models import SynthesizerTrn
|
41 |
from text.symbols import symbols
|
42 |
import sys
|
43 |
+
import re
|
44 |
|
45 |
+
import random
|
46 |
+
import hashlib
|
47 |
|
48 |
+
from fugashi import Tagger
|
49 |
+
import jaconv
|
50 |
+
import unidic
|
51 |
+
import subprocess
|
52 |
+
|
53 |
+
import requests
|
54 |
+
|
55 |
+
from ebooklib import epub
|
56 |
+
import PyPDF2
|
57 |
+
from PyPDF2 import PdfReader
|
58 |
+
from bs4 import BeautifulSoup
|
59 |
+
import jieba
|
60 |
+
import romajitable
|
61 |
+
|
62 |
+
from flask import Flask, request, jsonify, render_template_string, send_file
|
63 |
+
from flask_cors import CORS
|
64 |
+
from scipy.io.wavfile import write
|
65 |
net_g = None
|
66 |
+
|
67 |
device = (
|
68 |
"cuda:0"
|
69 |
if torch.cuda.is_available()
|
|
|
74 |
)
|
75 |
)
|
76 |
|
77 |
+
#device = "cpu"
|
78 |
+
BandList = {
|
79 |
+
"PoppinParty":["香澄","有咲","たえ","りみ","沙綾"],
|
80 |
+
"Afterglow":["蘭","モカ","ひまり","巴","つぐみ"],
|
81 |
+
"HelloHappyWorld":["こころ","美咲","薫","花音","はぐみ"],
|
82 |
+
"PastelPalettes":["彩","日菜","千聖","イヴ","麻弥"],
|
83 |
+
"Roselia":["友希那","紗夜","リサ","燐子","あこ"],
|
84 |
+
"RaiseASuilen":["レイヤ","ロック","ますき","チュチュ","パレオ"],
|
85 |
+
"Morfonica":["ましろ","瑠唯","つくし","七深","透子"],
|
86 |
+
"MyGo":["燈","愛音","そよ","立希","楽奈"],
|
87 |
+
"AveMujica":["祥子","睦","海鈴","にゃむ","初華"],
|
88 |
+
"圣翔音乐学园":["華戀","光","香子","雙葉","真晝","純那","克洛迪娜","真矢","奈奈"],
|
89 |
+
"凛明馆女子学校":["珠緒","壘","文","悠悠子","一愛"],
|
90 |
+
"弗隆提亚艺术学校":["艾露","艾露露","菈樂菲","司","靜羽"],
|
91 |
+
"西克菲尔特音乐学院":["晶","未知留","八千代","栞","美帆"]
|
92 |
+
}
|
93 |
+
|
94 |
+
webBase = 'https://mahiruoshi-bangdream-bert-vits2.hf.space/'
|
95 |
+
|
96 |
+
port = 8080
|
97 |
+
|
98 |
+
languages = [ "Auto", "ZH", "JP"]
|
99 |
+
modelPaths = []
|
100 |
+
modes = ['pyopenjtalk-V2.3-Katakana','fugashi-V2.3-Katakana','pyopenjtalk-V2.3-Katakana-Katakana','fugashi-V2.3-Katakana-Katakana','onnx-V2.3']
|
101 |
+
sentence_modes = ['sentence','paragraph']
|
102 |
+
for dirpath, dirnames, filenames in os.walk('Data/BangDream/models/'):
|
103 |
+
for filename in filenames:
|
104 |
+
modelPaths.append(os.path.join(dirpath, filename))
|
105 |
+
hps = utils.get_hparams_from_file('Data/BangDream/config.json')
|
106 |
+
|
107 |
+
def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""):
|
108 |
+
"""
|
109 |
+
:param Sentence: 待翻译语句
|
110 |
+
:param from_Language: 待翻译语句语言
|
111 |
+
:param to_Language: 目标语言
|
112 |
+
:return: 翻译后语句 出错时返回None
|
113 |
+
|
114 |
+
常见语言代码:中文 zh 英语 en 日语 jp
|
115 |
+
"""
|
116 |
+
appid = "20231117001883321"
|
117 |
+
key = "lMQbvZHeJveDceLof2wf"
|
118 |
+
if appid == "" or key == "":
|
119 |
+
return "请开发者在config.yml中配置app_key与secret_key"
|
120 |
+
url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
|
121 |
+
texts = Sentence.splitlines()
|
122 |
+
outTexts = []
|
123 |
+
for t in texts:
|
124 |
+
if t != "":
|
125 |
+
# 签名计算 参考文档 https://api.fanyi.baidu.com/product/113
|
126 |
+
salt = str(random.randint(1, 100000))
|
127 |
+
signString = appid + t + salt + key
|
128 |
+
hs = hashlib.md5()
|
129 |
+
hs.update(signString.encode("utf-8"))
|
130 |
+
signString = hs.hexdigest()
|
131 |
+
if from_Language == "":
|
132 |
+
from_Language = "auto"
|
133 |
+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
134 |
+
payload = {
|
135 |
+
"q": t,
|
136 |
+
"from": from_Language,
|
137 |
+
"to": to_Language,
|
138 |
+
"appid": appid,
|
139 |
+
"salt": salt,
|
140 |
+
"sign": signString,
|
141 |
+
}
|
142 |
+
# 发送请求
|
143 |
+
try:
|
144 |
+
response = requests.post(
|
145 |
+
url=url, data=payload, headers=headers, timeout=3
|
146 |
+
)
|
147 |
+
response = response.json()
|
148 |
+
if "trans_result" in response.keys():
|
149 |
+
result = response["trans_result"][0]
|
150 |
+
if "dst" in result.keys():
|
151 |
+
dst = result["dst"]
|
152 |
+
outTexts.append(dst)
|
153 |
+
except Exception:
|
154 |
+
return Sentence
|
155 |
+
else:
|
156 |
+
outTexts.append(t)
|
157 |
+
return "\n".join(outTexts)
|
158 |
+
|
159 |
+
#文本清洗工具
|
160 |
+
def is_japanese(string):
|
161 |
+
for ch in string:
|
162 |
+
if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
|
163 |
+
return True
|
164 |
+
return False
|
165 |
+
|
166 |
+
def is_chinese(string):
|
167 |
+
for ch in string:
|
168 |
+
if '\u4e00' <= ch <= '\u9fff':
|
169 |
+
return True
|
170 |
+
return False
|
171 |
+
|
172 |
+
def is_single_language(sentence):
|
173 |
+
# 检查句子是否为单一语言
|
174 |
+
contains_chinese = re.search(r'[\u4e00-\u9fff]', sentence) is not None
|
175 |
+
contains_japanese = re.search(r'[\u3040-\u30ff\u31f0-\u31ff]', sentence) is not None
|
176 |
+
contains_english = re.search(r'[a-zA-Z]', sentence) is not None
|
177 |
+
language_count = sum([contains_chinese, contains_japanese, contains_english])
|
178 |
+
return language_count == 1
|
179 |
+
|
180 |
+
def merge_scattered_parts(sentences):
|
181 |
+
"""合并零散的部分到相邻的句子中,并确保单一语言性"""
|
182 |
+
merged_sentences = []
|
183 |
+
buffer_sentence = ""
|
184 |
+
|
185 |
+
for sentence in sentences:
|
186 |
+
# 检查是否是单一语言或者太短(可能是标点或单个词)
|
187 |
+
if is_single_language(sentence) and len(sentence) > 1:
|
188 |
+
# 如果缓冲区有内容,先将缓冲区的内容添加到列表
|
189 |
+
if buffer_sentence:
|
190 |
+
merged_sentences.append(buffer_sentence)
|
191 |
+
buffer_sentence = ""
|
192 |
+
merged_sentences.append(sentence)
|
193 |
+
else:
|
194 |
+
# 如果是零散的部分,将其添加到缓冲区
|
195 |
+
buffer_sentence += sentence
|
196 |
+
|
197 |
+
# 确保最后的缓冲区内容被添加
|
198 |
+
if buffer_sentence:
|
199 |
+
merged_sentences.append(buffer_sentence)
|
200 |
+
|
201 |
+
return merged_sentences
|
202 |
+
|
203 |
+
def is_only_punctuation(s):
|
204 |
+
"""检查字符串是否只包含标点符号"""
|
205 |
+
# 此处列出中文、日文、英文常见标点符号
|
206 |
+
punctuation_pattern = re.compile(r'^[\s。*;,:“”()、!?《》\u3000\.,;:"\'?!()]+$')
|
207 |
+
return punctuation_pattern.match(s) is not None
|
208 |
+
|
209 |
+
def split_mixed_language(sentence):
|
210 |
+
# 分割混合语言句子
|
211 |
+
# 逐字符检查,分割不同语言部分
|
212 |
+
sub_sentences = []
|
213 |
+
current_language = None
|
214 |
+
current_part = ""
|
215 |
+
|
216 |
+
for char in sentence:
|
217 |
+
if re.match(r'[\u4e00-\u9fff]', char): # Chinese character
|
218 |
+
if current_language != 'chinese':
|
219 |
+
if current_part:
|
220 |
+
sub_sentences.append(current_part)
|
221 |
+
current_part = char
|
222 |
+
current_language = 'chinese'
|
223 |
+
else:
|
224 |
+
current_part += char
|
225 |
+
elif re.match(r'[\u3040-\u30ff\u31f0-\u31ff]', char): # Japanese character
|
226 |
+
if current_language != 'japanese':
|
227 |
+
if current_part:
|
228 |
+
sub_sentences.append(current_part)
|
229 |
+
current_part = char
|
230 |
+
current_language = 'japanese'
|
231 |
+
else:
|
232 |
+
current_part += char
|
233 |
+
elif re.match(r'[a-zA-Z]', char): # English character
|
234 |
+
if current_language != 'english':
|
235 |
+
if current_part:
|
236 |
+
sub_sentences.append(current_part)
|
237 |
+
current_part = char
|
238 |
+
current_language = 'english'
|
239 |
+
else:
|
240 |
+
current_part += char
|
241 |
+
else:
|
242 |
+
current_part += char # For punctuation and other characters
|
243 |
+
|
244 |
+
if current_part:
|
245 |
+
sub_sentences.append(current_part)
|
246 |
+
|
247 |
+
return sub_sentences
|
248 |
+
|
249 |
+
def replace_quotes(text):
|
250 |
+
# 替换中文、日文引号为英文引号
|
251 |
+
text = re.sub(r'[“”‘’『』「」()()]', '"', text)
|
252 |
+
return text
|
253 |
+
|
254 |
+
def remove_numeric_annotations(text):
|
255 |
+
# 定义用于匹配数字注释的正则表达式
|
256 |
+
# 包括 “”、【】和〔〕包裹的数字
|
257 |
+
pattern = r'“\d+”|【\d+】|〔\d+〕'
|
258 |
+
# 使用正则表达式替换掉这些注释
|
259 |
+
cleaned_text = re.sub(pattern, '', text)
|
260 |
+
return cleaned_text
|
261 |
+
|
262 |
+
def merge_adjacent_japanese(sentences):
|
263 |
+
"""合并相邻且都只包含日语的句子"""
|
264 |
+
merged_sentences = []
|
265 |
+
i = 0
|
266 |
+
while i < len(sentences):
|
267 |
+
current_sentence = sentences[i]
|
268 |
+
if i + 1 < len(sentences) and is_japanese(current_sentence) and is_japanese(sentences[i + 1]):
|
269 |
+
# 当前句子和下一句都是日语,合并它们
|
270 |
+
while i + 1 < len(sentences) and is_japanese(sentences[i + 1]):
|
271 |
+
current_sentence += sentences[i + 1]
|
272 |
+
i += 1
|
273 |
+
merged_sentences.append(current_sentence)
|
274 |
+
i += 1
|
275 |
+
return merged_sentences
|
276 |
+
|
277 |
+
def extrac(text):
|
278 |
+
text = replace_quotes(remove_numeric_annotations(text)) # 替换引号
|
279 |
+
text = re.sub("<[^>]*>", "", text) # 移除 HTML 标签
|
280 |
+
# 使用换行符和标点符号进行初步分割
|
281 |
+
preliminary_sentences = re.split(r'([\n。;!?\.\?!])', text)
|
282 |
+
final_sentences = []
|
283 |
+
|
284 |
+
preliminary_sentences = re.split(r'([\n。;!?\.\?!])', text)
|
285 |
+
|
286 |
+
for piece in preliminary_sentences:
|
287 |
+
if is_single_language(piece):
|
288 |
+
final_sentences.append(piece)
|
289 |
+
else:
|
290 |
+
sub_sentences = split_mixed_language(piece)
|
291 |
+
final_sentences.extend(sub_sentences)
|
292 |
+
|
293 |
+
# 处理长句子,使用jieba进行分词
|
294 |
+
split_sentences = []
|
295 |
+
for sentence in final_sentences:
|
296 |
+
split_sentences.extend(split_long_sentences(sentence))
|
297 |
+
|
298 |
+
# 合并相邻的日语句子
|
299 |
+
merged_japanese_sentences = merge_adjacent_japanese(split_sentences)
|
300 |
+
|
301 |
+
# 剔除只包含标点符号的元素
|
302 |
+
clean_sentences = [s for s in merged_japanese_sentences if not is_only_punctuation(s)]
|
303 |
+
|
304 |
+
# 移除空字符串并去除多余引号
|
305 |
+
return [s.replace('"','').strip() for s in clean_sentences if s]
|
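Usage sketch for the splitter above: given a mixed Chinese/Japanese paragraph it returns single-language pieces with punctuation-only fragments dropped, ready for per-sentence synthesis. The exact pieces depend on the regex rules and jieba segmentation, so the output below is not fixed:

```python
# Called from within server.py's namespace; importing the module also loads
# the model config, so treat this as a sketch of the call shape only.
pieces = extrac("今日はいい天気ですね。今天天气不错,我们一起出去散步吧。")
for piece in pieces:
    print(piece)
```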
306 |
+
|
307 |
+
|
308 |
+
|
309 |
+
# 移除空字符串
|
310 |
+
|
311 |
+
def is_mixed_language(sentence):
|
312 |
+
contains_chinese = re.search(r'[\u4e00-\u9fff]', sentence) is not None
|
313 |
+
contains_japanese = re.search(r'[\u3040-\u30ff\u31f0-\u31ff]', sentence) is not None
|
314 |
+
contains_english = re.search(r'[a-zA-Z]', sentence) is not None
|
315 |
+
languages_count = sum([contains_chinese, contains_japanese, contains_english])
|
316 |
+
return languages_count > 1
|
317 |
+
|
318 |
+
def split_mixed_language(sentence):
|
319 |
+
# 分割混合语言句子
|
320 |
+
sub_sentences = re.split(r'(?<=[。!?\.\?!])(?=")|(?<=")(?=[\u4e00-\u9fff\u3040-\u30ff\u31f0-\u31ff]|[a-zA-Z])', sentence)
|
321 |
+
return [s.strip() for s in sub_sentences if s.strip()]
|
322 |
+
|
323 |
+
def seconds_to_ass_time(seconds):
|
324 |
+
"""将秒数转换为ASS时间格式"""
|
325 |
+
hours = int(seconds / 3600)
|
326 |
+
minutes = int((seconds % 3600) / 60)
|
327 |
+
seconds = int(seconds) % 60
|
328 |
+
milliseconds = int((seconds - int(seconds)) * 1000)
|
329 |
+
return "{:01d}:{:02d}:{:02d}.{:02d}".format(hours, minutes, seconds, int(milliseconds / 10))
|
330 |
+
|
331 |
+
def extract_text_from_epub(file_path):
|
332 |
+
book = epub.read_epub(file_path)
|
333 |
+
content = []
|
334 |
+
for item in book.items:
|
335 |
+
if isinstance(item, epub.EpubHtml):
|
336 |
+
soup = BeautifulSoup(item.content, 'html.parser')
|
337 |
+
content.append(soup.get_text())
|
338 |
+
return '\n'.join(content)
|
339 |
+
|
340 |
+
def extract_text_from_pdf(file_path):
|
341 |
+
with open(file_path, 'rb') as file:
|
342 |
+
reader = PdfReader(file)
|
343 |
+
content = [page.extract_text() for page in reader.pages]
|
344 |
+
return '\n'.join(content)
|
345 |
+
|
346 |
+
def remove_annotations(text):
|
347 |
+
# 移除方括号、尖括号和中文方括号中的内容
|
348 |
+
text = re.sub(r'\[.*?\]', '', text)
|
349 |
+
text = re.sub(r'\<.*?\>', '', text)
|
350 |
+
text = re.sub(r'​``【oaicite:1】``​', '', text)
|
351 |
+
return text
|
352 |
+
|
353 |
+
def extract_text_from_file(inputFile):
|
354 |
+
file_extension = os.path.splitext(inputFile)[1].lower()
|
355 |
+
if file_extension == ".epub":
|
356 |
+
return extract_text_from_epub(inputFile)
|
357 |
+
elif file_extension == ".pdf":
|
358 |
+
return extract_text_from_pdf(inputFile)
|
359 |
+
elif file_extension == ".txt":
|
360 |
+
with open(inputFile, 'r', encoding='utf-8') as f:
|
361 |
+
return f.read()
|
362 |
+
else:
|
363 |
+
raise ValueError(f"Unsupported file format: {file_extension}")
|
364 |
+
|
365 |
+
def split_by_punctuation(sentence):
|
366 |
+
"""按照中文次级标点符号分割句子"""
|
367 |
+
# 常见的中文次级分隔符号:逗号、分号等
|
368 |
+
parts = re.split(r'([,,;;])', sentence)
|
369 |
+
# 将标点符号与前面的词语合并,避免单独标点符号成为一个部分
|
370 |
+
merged_parts = []
|
371 |
+
for part in parts:
|
372 |
+
if part and not part in ',,;;':
|
373 |
+
merged_parts.append(part)
|
374 |
+
elif merged_parts:
|
375 |
+
merged_parts[-1] += part
|
376 |
+
return merged_parts
|
377 |
+
|
378 |
+
def split_long_sentences(sentence, max_length=30):
|
379 |
+
"""如果中文句子太长,先按标点分割,必要时使用jieba进行分词并分割"""
|
380 |
+
if len(sentence) > max_length and is_chinese(sentence):
|
381 |
+
# 首先尝试按照次级标点符号分割
|
382 |
+
preliminary_parts = split_by_punctuation(sentence)
|
383 |
+
new_sentences = []
|
384 |
+
|
385 |
+
for part in preliminary_parts:
|
386 |
+
# 如果部分仍然太长,使用jieba进行分词
|
387 |
+
if len(part) > max_length:
|
388 |
+
words = jieba.lcut(part)
|
389 |
+
current_sentence = ""
|
390 |
+
for word in words:
|
391 |
+
if len(current_sentence) + len(word) > max_length:
|
392 |
+
new_sentences.append(current_sentence)
|
393 |
+
current_sentence = word
|
394 |
+
else:
|
395 |
+
current_sentence += word
|
396 |
+
if current_sentence:
|
397 |
+
new_sentences.append(current_sentence)
|
398 |
+
else:
|
399 |
+
new_sentences.append(part)
|
400 |
+
|
401 |
+
return new_sentences
|
402 |
+
return [sentence] # 如果句子不长或不是中文,直接返回
|
403 |
+
|
404 |
+
def extract_and_convert(text):
|
405 |
+
|
406 |
+
# 使用正则表达式找出所有英文单词
|
407 |
+
english_parts = re.findall(r'\b[A-Za-z]+\b', text) # \b为单词边界标识
|
408 |
+
|
409 |
+
# 对每个英文单词进行片假名转换
|
410 |
+
kana_parts = ['\n{}\n'.format(romajitable.to_kana(word).katakana) for word in english_parts]
|
411 |
+
|
412 |
+
# 替换原文本中的英文部分
|
413 |
+
for eng, kana in zip(english_parts, kana_parts):
|
414 |
+
text = text.replace(eng, kana, 1) # 限制每次只替换一个实例
|
415 |
+
|
416 |
+
return text
|
417 |
+
# 推理工具
|
418 |
+
def download_unidic():
|
419 |
+
try:
|
420 |
+
Tagger()
|
421 |
+
print("Tagger launch successfully.")
|
422 |
+
except Exception as e:
|
423 |
+
print("UNIDIC dictionary not found, downloading...")
|
424 |
+
subprocess.run([sys.executable, "-m", "unidic", "download"])
|
425 |
+
print("Download completed.")
|
426 |
+
|
427 |
+
def kanji_to_hiragana(text):
|
428 |
+
global tagger
|
429 |
+
output = ""
|
430 |
+
|
431 |
+
# 更新正则表达式以更准确地区分文本和标点符号
|
432 |
+
segments = re.findall(r'[一-龥ぁ-んァ-ン\w]+|[^\一-龥ぁ-んァ-ン\w\s]', text, re.UNICODE)
|
433 |
+
|
434 |
+
for segment in segments:
|
435 |
+
if re.match(r'[一-龥ぁ-んァ-ン\w]+', segment):
|
436 |
+
# 如果是单词或汉字,转换为平假名
|
437 |
+
for word in tagger(segment):
|
438 |
+
kana = word.feature.kana or word.surface
|
439 |
+
hiragana = jaconv.kata2hira(kana) # 将片假名转换为平假名
|
440 |
+
output += hiragana
|
441 |
+
else:
|
442 |
+
# 如果是标点符号,保持不变
|
443 |
+
output += segment
|
444 |
+
|
445 |
+
return output
|
446 |
|
447 |
def get_net_g(model_path: str, device: str, hps):
|
448 |
net_g = SynthesizerTrn(
|
|
|
456 |
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
|
457 |
return net_g
|
458 |
|
459 |
+
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
|
460 |
+
style_text = None if style_text == "" else style_text
|
461 |
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
462 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
463 |
+
|
464 |
if hps.data.add_blank:
|
465 |
phone = commons.intersperse(phone, 0)
|
466 |
tone = commons.intersperse(tone, 0)
|
|
|
468 |
for i in range(len(word2ph)):
|
469 |
word2ph[i] = word2ph[i] * 2
|
470 |
word2ph[0] += 1
|
471 |
+
bert_ori = get_bert(
|
472 |
+
norm_text, word2ph, language_str, device, style_text, style_weight
|
473 |
+
)
|
474 |
del word2ph
|
475 |
assert bert_ori.shape[-1] == len(phone), phone
|
476 |
|
477 |
if language_str == "ZH":
|
478 |
bert = bert_ori
|
479 |
+
ja_bert = torch.randn(1024, len(phone))
|
480 |
+
en_bert = torch.randn(1024, len(phone))
|
481 |
elif language_str == "JP":
|
482 |
+
bert = torch.randn(1024, len(phone))
|
483 |
ja_bert = bert_ori
|
484 |
+
en_bert = torch.randn(1024, len(phone))
|
485 |
+
elif language_str == "EN":
|
486 |
+
bert = torch.randn(1024, len(phone))
|
487 |
+
ja_bert = torch.randn(1024, len(phone))
|
488 |
+
en_bert = bert_ori
|
489 |
else:
|
490 |
raise ValueError("language_str should be ZH, JP or EN")
|
491 |
|
|
|
505 |
noise_scale_w,
|
506 |
length_scale,
|
507 |
sid,
|
508 |
+
style_text=None,
|
509 |
+
style_weight=0.7,
|
510 |
+
language = "Auto",
|
511 |
+
mode = 'pyopenjtalk-V2.3-Katakana',
|
512 |
+
skip_start=False,
|
513 |
+
skip_end=False,
|
514 |
):
|
515 |
+
if style_text == None:
|
516 |
+
style_text = ""
|
517 |
+
style_weight=0,
|
518 |
+
if mode == 'fugashi-V2.3-Katakana':
|
519 |
+
text = kanji_to_hiragana(text) if is_japanese(text) else text
|
520 |
+
if language == "JP":
|
521 |
+
text = translate(text,"jp")
|
522 |
+
if language == "ZH":
|
523 |
+
text = translate(text,"zh")
|
524 |
+
if language == "Auto":
|
525 |
+
language= 'JP' if is_japanese(text) else 'ZH'
|
526 |
+
#print(f'{text}:{sdp_ratio}:{noise_scale}:{noise_scale_w}:{length_scale}:{length_scale}:{sid}:{language}:{mode}:{skip_start}:{skip_end}')
|
527 |
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
528 |
+
text,
|
529 |
+
language,
|
530 |
+
hps,
|
531 |
+
device,
|
532 |
+
style_text=style_text,
|
533 |
+
style_weight=style_weight,
|
534 |
)
|
535 |
+
if skip_start:
|
536 |
+
phones = phones[3:]
|
537 |
+
tones = tones[3:]
|
538 |
+
lang_ids = lang_ids[3:]
|
539 |
+
bert = bert[:, 3:]
|
540 |
+
ja_bert = ja_bert[:, 3:]
|
541 |
+
en_bert = en_bert[:, 3:]
|
542 |
+
if skip_end:
|
543 |
+
phones = phones[:-2]
|
544 |
+
tones = tones[:-2]
|
545 |
+
lang_ids = lang_ids[:-2]
|
546 |
+
bert = bert[:, :-2]
|
547 |
+
ja_bert = ja_bert[:, :-2]
|
548 |
+
en_bert = en_bert[:, :-2]
|
549 |
with torch.no_grad():
|
550 |
x_tst = phones.to(device).unsqueeze(0)
|
551 |
tones = tones.to(device).unsqueeze(0)
|
|
|
554 |
ja_bert = ja_bert.to(device).unsqueeze(0)
|
555 |
en_bert = en_bert.to(device).unsqueeze(0)
|
556 |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
557 |
+
# emo = emo.to(device).unsqueeze(0)
|
558 |
del phones
|
559 |
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
560 |
audio = (
|
|
|
567 |
bert,
|
568 |
ja_bert,
|
569 |
en_bert,
|
|
|
570 |
sdp_ratio=sdp_ratio,
|
571 |
noise_scale=noise_scale,
|
572 |
noise_scale_w=noise_scale_w,
|
|
|
576 |
.float()
|
577 |
.numpy()
|
578 |
)
|
579 |
+
del (
|
580 |
+
x_tst,
|
581 |
+
tones,
|
582 |
+
lang_ids,
|
583 |
+
bert,
|
584 |
+
x_tst_lengths,
|
585 |
+
speakers,
|
586 |
+
ja_bert,
|
587 |
+
en_bert,
|
588 |
+
) # , emo
|
589 |
if torch.cuda.is_available():
|
590 |
torch.cuda.empty_cache()
|
591 |
+
print("Success.")
|
592 |
+
return audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
|
594 |
def loadmodel(model):
|
595 |
+
_ = net_g.eval()
|
596 |
+
_ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
|
597 |
+
return "success"
|
598 |
+
|
599 |
+
def generate_audio_and_srt_for_group(
|
600 |
+
group,
|
601 |
+
outputPath,
|
602 |
+
group_index,
|
603 |
+
sampling_rate,
|
604 |
+
speaker,
|
605 |
+
sdp_ratio,
|
606 |
+
noise_scale,
|
607 |
+
noise_scale_w,
|
608 |
+
length_scale,
|
609 |
+
speakerList,
|
610 |
+
silenceTime,
|
611 |
+
language,
|
612 |
+
mode,
|
613 |
+
skip_start,
|
614 |
+
skip_end,
|
615 |
+
style_text,
|
616 |
+
style_weight,
|
617 |
+
):
|
618 |
+
audio_fin = []
|
619 |
+
ass_entries = []
|
620 |
+
start_time = 0
|
621 |
+
#speaker = random.choice(cara_list)
|
622 |
+
ass_header = """[Script Info]
|
623 |
+
; 我没意见
|
624 |
+
Title: Audiobook
|
625 |
+
ScriptType: v4.00+
|
626 |
+
WrapStyle: 0
|
627 |
+
PlayResX: 640
|
628 |
+
PlayResY: 360
|
629 |
+
ScaledBorderAndShadow: yes
|
630 |
+
[V4+ Styles]
|
631 |
+
Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
|
632 |
+
Style: Default,Arial,20,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,1,1,2,10,10,10,1
|
633 |
+
[Events]
|
634 |
+
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
|
635 |
+
"""
|
636 |
+
|
637 |
+
for sentence in group:
|
638 |
+
|
639 |
+
if len(sentence) > 1:
|
640 |
+
FakeSpeaker = sentence.split("|")[0]
|
641 |
+
print(FakeSpeaker)
|
642 |
+
SpeakersList = re.split('\n', speakerList)
|
643 |
+
if FakeSpeaker in list(hps.data.spk2id.keys()):
|
644 |
+
speaker = FakeSpeaker
|
645 |
+
for i in SpeakersList:
|
646 |
+
if FakeSpeaker == i.split("|")[1]:
|
647 |
+
speaker = i.split("|")[0]
|
648 |
+
if sentence != '\n':
|
649 |
+
text = (remove_annotations(sentence.split("|")[-1]).replace(" ","")+"。").replace(",。","。")
|
650 |
+
if mode == 'pyopenjtalk-V2.3-Katakana' or mode == 'fugashi-V2.3-Katakana':
|
651 |
+
#print(f'{text}:{sdp_ratio}:{noise_scale}:{noise_scale_w}:{length_scale}:{length_scale}:{speaker}:{language}:{mode}:{skip_start}:{skip_end}')
|
652 |
+
audio = infer(
|
653 |
+
text,
|
654 |
+
sdp_ratio,
|
655 |
+
noise_scale,
|
656 |
+
noise_scale_w,
|
657 |
+
length_scale,
|
658 |
+
speaker,
|
659 |
+
style_text,
|
660 |
+
style_weight,
|
661 |
+
language,
|
662 |
+
mode,
|
663 |
+
skip_start,
|
664 |
+
skip_end,
|
665 |
+
)
|
666 |
+
silence_frames = int(silenceTime * 44010) if is_chinese(sentence) else int(silenceTime * 44010)
|
667 |
+
silence_data = np.zeros((silence_frames,), dtype=audio.dtype)
|
668 |
+
audio_fin.append(audio)
|
669 |
+
audio_fin.append(silence_data)
|
670 |
+
duration = len(audio) / sampling_rate
|
671 |
+
print(duration)
|
672 |
+
end_time = start_time + duration + silenceTime
|
673 |
+
ass_entries.append("Dialogue: 0,{},{},".format(seconds_to_ass_time(start_time), seconds_to_ass_time(end_time)) + "Default,,0,0,0,,{}".format(sentence.replace("|",":")))
|
674 |
+
start_time = end_time
|
675 |
+
|
676 |
+
wav_filename = os.path.join(outputPath, f'audiobook_part_{group_index}.wav')
|
677 |
+
ass_filename = os.path.join(outputPath, f'audiobook_part_{group_index}.ass')
|
678 |
+
write(wav_filename, sampling_rate, gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_fin)))
|
679 |
|
680 |
+
with open(ass_filename, 'w', encoding='utf-8') as f:
|
681 |
+
f.write(ass_header + '\n'.join(ass_entries))
|
682 |
+
return (hps.data.sampling_rate, gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_fin)))
|
683 |
+
|
684 |
+
def generate_audio(
|
685 |
+
inputFile,
|
686 |
+
groupsize,
|
687 |
+
filepath,
|
688 |
+
silenceTime,
|
689 |
+
speakerList,
|
690 |
+
text,
|
691 |
+
sdp_ratio,
|
692 |
+
noise_scale,
|
693 |
+
noise_scale_w,
|
694 |
+
length_scale,
|
695 |
+
sid,
|
696 |
+
style_text=None,
|
697 |
+
style_weight=0.7,
|
698 |
+
language = "Auto",
|
699 |
+
mode = 'pyopenjtalk-V2.3-Katakana',
|
700 |
+
sentence_mode = 'sentence',
|
701 |
+
skip_start=False,
|
702 |
+
skip_end=False,
|
703 |
+
):
|
704 |
+
if mode == 'pyopenjtalk-V2.3-Katakana' or mode == 'fugashi-V2.3-Katakana':
|
705 |
+
if sentence_mode == 'sentence':
|
706 |
+
audio = infer(
|
707 |
+
text,
|
708 |
+
sdp_ratio,
|
709 |
+
noise_scale,
|
710 |
+
noise_scale_w,
|
711 |
+
length_scale,
|
712 |
+
sid,
|
713 |
+
style_text,
|
714 |
+
style_weight,
|
715 |
+
language,
|
716 |
+
mode,
|
717 |
+
skip_start,
|
718 |
+
skip_end,
|
719 |
+
)
|
720 |
+
return (hps.data.sampling_rate,gr.processing_utils.convert_to_16_bit_wav(audio))
|
721 |
+
if sentence_mode == 'paragraph':
|
722 |
+
GROUP_SIZE = groupsize
|
723 |
+
directory_path = filepath if torch.cuda.is_available() else "books"
|
724 |
+
if os.path.exists(directory_path):
|
725 |
+
shutil.rmtree(directory_path)
|
726 |
+
os.makedirs(directory_path)
|
727 |
+
if inputFile:
|
728 |
+
text = extract_text_from_file(inputFile.name)
|
729 |
+
if language == 'Auto':
|
730 |
+
sentences = extrac(extract_and_convert(text))
|
731 |
+
else:
|
732 |
+
sentences = extrac(text)
|
733 |
+
for i in range(0, len(sentences), GROUP_SIZE):
|
734 |
+
group = sentences[i:i+GROUP_SIZE]
|
735 |
+
if speakerList == "":
|
736 |
+
speakerList = "无"
|
737 |
+
result = generate_audio_and_srt_for_group(
|
738 |
+
group,
|
739 |
+
directory_path,
|
740 |
+
i//GROUP_SIZE + 1,
|
741 |
+
44100,
|
742 |
+
sid,
|
743 |
+
sdp_ratio,
|
744 |
+
noise_scale,
|
745 |
+
noise_scale_w,
|
746 |
+
length_scale,
|
747 |
+
speakerList,
|
748 |
+
silenceTime,
|
749 |
+
language,
|
750 |
+
mode,
|
751 |
+
skip_start,
|
752 |
+
skip_end,
|
753 |
+
style_text,
|
754 |
+
style_weight,
|
755 |
+
)
|
756 |
+
if not torch.cuda.is_available():
|
757 |
+
return result
|
758 |
+
return result
|
759 |
+
|
760 |
+
Flaskapp = Flask(__name__)
|
761 |
+
CORS(Flaskapp)
|
762 |
+
@Flaskapp.route('/', methods=['GET', 'POST'])
|
763 |
|
764 |
def tts():
|
765 |
+
if request.method == 'POST':
|
766 |
+
input = request.json
|
767 |
+
inputFile = None
|
768 |
+
filepath = input['filepath']
|
769 |
+
groupSize = input['groupSize']
|
770 |
+
text = input['text']
|
771 |
+
sdp_ratio = input['sdp_ratio']
|
772 |
+
noise_scale = input['noise_scale']
|
773 |
+
noise_scale_w = input['noise_scale_w']
|
774 |
+
length_scale = input['length_scale']
|
775 |
+
sid = input['speaker']
|
776 |
+
style_text = input['style_text']
|
777 |
+
style_weight = input['style_weight']
|
778 |
+
language = input['language']
|
779 |
+
mode = input['mode']
|
780 |
+
sentence_mode = input['sentence_mode']
|
781 |
+
skip_start = input['skip_start']
|
782 |
+
skip_end = input['skip_end']
|
783 |
+
speakerList = input['speakerList']
|
784 |
+
silenceTime = input['silenceTime']
|
785 |
+
samplerate, audio = generate_audio(
|
786 |
+
inputFile,
|
787 |
+
groupSize,
|
788 |
+
filepath,
|
789 |
+
silenceTime,
|
790 |
+
speakerList,
|
791 |
+
text,
|
792 |
+
sdp_ratio,
|
793 |
+
noise_scale,
|
794 |
+
noise_scale_w,
|
795 |
+
length_scale,
|
796 |
+
sid,
|
797 |
+
style_text,
|
798 |
+
style_weight,
|
799 |
+
language,
|
800 |
+
mode,
|
801 |
+
sentence_mode,
|
802 |
+
skip_start,
|
803 |
+
skip_end,
|
804 |
+
)
|
805 |
+
unique_filename = f"temp{uuid.uuid4()}.wav"
|
806 |
+
write(unique_filename, samplerate, audio)
|
807 |
+
with open(unique_filename ,'rb') as bit:
|
808 |
+
wav_bytes = bit.read()
|
809 |
+
os.remove(unique_filename)
|
810 |
+
headers = {
|
811 |
+
'Content-Type': 'audio/wav',
|
812 |
+
'Text': unique_filename .encode('utf-8')}
|
813 |
+
return wav_bytes, 200, headers
|
814 |
+
groupSize = request.args.get('groupSize', default = 50, type = int)
|
815 |
+
text = request.args.get('text', default = '', type = str)
|
816 |
+
sdp_ratio = request.args.get('sdp_ratio', default = 0.5, type = float)
|
817 |
+
noise_scale = request.args.get('noise_scale', default = 0.6, type = float)
|
818 |
+
noise_scale_w = request.args.get('noise_scale_w', default = 0.667, type = float)
|
819 |
+
length_scale = request.args.get('length_scale', default = 1, type = float)
|
820 |
+
sid = request.args.get('speaker', default = '八千代', type = str)
|
821 |
+
style_text = request.args.get('style_text', default = '', type = str)
|
822 |
+
style_weight = request.args.get('style_weight', default = 0.7, type = float)
|
823 |
+
language = request.args.get('language', default = 'Auto', type = str)
|
824 |
+
mode = request.args.get('mode', default = 'pyopenjtalk-V2.3-Katakana', type = str)
|
825 |
+
sentence_mode = request.args.get('sentence_mode', default = 'sentence', type = str)
|
826 |
+
skip_start = request.args.get('skip_start', default = False, type = bool)
|
827 |
+
skip_end = request.args.get('skip_end', default = False, type = bool)
|
828 |
+
speakerList = request.args.get('speakerList', default = '', type = str)
|
829 |
+
silenceTime = request.args.get('silenceTime', default = 0.1, type = float)
|
830 |
+
inputFile = None
|
831 |
+
if not sid or not text:
|
832 |
+
return render_template_string(f"""
|
833 |
+
<!DOCTYPE html>
|
834 |
+
<html>
|
835 |
+
<head>
|
836 |
+
<title>TTS API Documentation</title>
|
837 |
+
</head>
|
838 |
+
<body>
|
839 |
+
<iframe src={webBase} style="width:100%; height:100vh; border:none;"></iframe>
|
840 |
+
</body>
|
841 |
+
</html>
|
842 |
+
""")
|
843 |
+
samplerate, audio = generate_audio(
|
844 |
+
inputFile,
|
845 |
+
groupSize,
|
846 |
+
None,
|
847 |
+
silenceTime,
|
848 |
+
speakerList,
|
849 |
+
text,
|
850 |
+
sdp_ratio,
|
851 |
+
noise_scale,
|
852 |
+
noise_scale_w,
|
853 |
+
length_scale,
|
854 |
+
sid,
|
855 |
+
style_text,
|
856 |
+
style_weight,
|
857 |
+
language,
|
858 |
+
mode,
|
859 |
+
sentence_mode,
|
860 |
+
skip_start,
|
861 |
+
skip_end,
|
862 |
+
)
|
863 |
+
unique_filename = f"temp{uuid.uuid4()}.wav"
|
864 |
+
write(unique_filename, samplerate, audio)
|
865 |
with open(unique_filename ,'rb') as bit:
|
866 |
wav_bytes = bit.read()
|
867 |
os.remove(unique_filename)
|
|
|
872 |
|
873 |
|
874 |
if __name__ == "__main__":
|
875 |
+
download_unidic()
|
876 |
+
tagger = Tagger()
|
|
|
|
|
|
|
|
|
877 |
net_g = get_net_g(
|
878 |
model_path=modelPaths[-1], device=device, hps=hps
|
879 |
)
|
880 |
speaker_ids = hps.data.spk2id
|
881 |
speakers = list(speaker_ids.keys())
|
882 |
+
|
883 |
+
print("推理页面已开启!")
|
884 |
+
Flaskapp.run(host="0.0.0.0", port=8080,debug=True)
|
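Note on the new Flask route above: it accepts POST requests with a JSON body (field names exactly as read from request.json in the diff) and GET requests with query parameters, and it returns raw WAV bytes with a Content-Type of audio/wav. A minimal client sketch, assuming the server was started with the Flaskapp.run(host="0.0.0.0", port=8080) call shown above; the payload values are examples only:

# Sketch of a client for the route above (not part of this commit).
import requests

payload = {
    "filepath": "",              # only used by the 'paragraph' sentence_mode
    "groupSize": 50,
    "text": "こんにちは、世界。",
    "sdp_ratio": 0.5,
    "noise_scale": 0.6,
    "noise_scale_w": 0.667,
    "length_scale": 1.0,
    "speaker": "八千代",
    "style_text": "",
    "style_weight": 0.7,
    "language": "Auto",
    "mode": "pyopenjtalk-V2.3-Katakana",
    "sentence_mode": "sentence",
    "skip_start": False,
    "skip_end": False,
    "speakerList": "",
    "silenceTime": 0.1,
}
resp = requests.post("http://127.0.0.1:8080/", json=payload, timeout=600)
resp.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(resp.content)        # raw WAV bytes returned by the route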
test.py
ADDED
@@ -0,0 +1,36 @@
|
1 |
+
import re
|
2 |
+
from fugashi import Tagger
|
3 |
+
import jaconv
|
4 |
+
|
5 |
+
def kanji_to_hiragana(text):
|
6 |
+
tagger = Tagger()
|
7 |
+
output = ""
|
8 |
+
|
9 |
+
# 更新正则表达式以更准确地区分文本和标点符号
|
10 |
+
segments = re.findall(r'[一-龥ぁ-んァ-ン\w]+|[^\一-龥ぁ-んァ-ン\w\s]', text, re.UNICODE)
|
11 |
+
|
12 |
+
for segment in segments:
|
13 |
+
if re.match(r'[一-龥ぁ-んァ-ン\w]+', segment):
|
14 |
+
# 如果是单词或汉字,转换为平假名
|
15 |
+
for word in tagger(segment):
|
16 |
+
kana = word.feature.kana or word.surface
|
17 |
+
hiragana = jaconv.kata2hira(kana) # 将片假名转换为平假名
|
18 |
+
output += hiragana
|
19 |
+
else:
|
20 |
+
# 如果是标点符号,保持不变
|
21 |
+
output += segment
|
22 |
+
|
23 |
+
return output
|
24 |
+
|
25 |
+
|
26 |
+
text = "私は学生です。"
|
27 |
+
tagger = Tagger()
|
28 |
+
|
29 |
+
for word in tagger(text):
|
30 |
+
print(word.surface, word.feature.pos1)
|
31 |
+
|
32 |
+
|
33 |
+
# 示例文本
|
34 |
+
text = "業火とはね、どんな人でも彼女が築いた悪業は、いつの日か、彼女を少しも残さず焼き払うことになる……"
|
35 |
+
converted_text = kanji_to_hiragana(text)
|
36 |
+
print(converted_text)
|
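The test script above rebuilds a fugashi Tagger inside kanji_to_hiragana on every call, which is the expensive part (dictionary loading). A small sketch of the same reading pipeline with one shared Tagger; it relies only on the fugashi and jaconv calls already used in test.py and drops the punctuation-splitting regex for brevity:

# Sketch: reuse a single Tagger instance instead of creating one per call.
from fugashi import Tagger
import jaconv

_TAGGER = Tagger()  # loading the dictionary once is the slow step

def to_hiragana(text: str, tagger: Tagger = _TAGGER) -> str:
    out = []
    for word in tagger(text):
        kana = word.feature.kana or word.surface  # katakana reading, or the surface form as fallback
        out.append(jaconv.kata2hira(kana))        # same katakana -> hiragana conversion as test.py
    return "".join(out)

print(to_hiragana("私は学生です。"))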
train_ms.py
CHANGED
@@ -13,7 +13,6 @@ import logging
|
|
13 |
from config import config
|
14 |
import argparse
|
15 |
import datetime
|
16 |
-
import gc
|
17 |
|
18 |
logging.getLogger("numba").setLevel(logging.WARNING)
|
19 |
import commons
|
@@ -27,14 +26,21 @@ from models import (
|
|
27 |
SynthesizerTrn,
|
28 |
MultiPeriodDiscriminator,
|
29 |
DurationDiscriminator,
|
|
|
30 |
)
|
31 |
-
from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
|
32 |
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
33 |
from text.symbols import symbols
|
34 |
|
35 |
torch.backends.cuda.matmul.allow_tf32 = True
|
36 |
torch.backends.cudnn.allow_tf32 = (
|
37 |
-
True # If
|
38 |
)
|
39 |
torch.set_float32_matmul_precision("medium")
|
40 |
torch.backends.cuda.sdp_kernel("flash")
|
@@ -42,7 +48,6 @@ torch.backends.cuda.enable_flash_sdp(True)
|
|
42 |
torch.backends.cuda.enable_mem_efficient_sdp(
|
43 |
True
|
44 |
) # Not available if torch version is lower than 2.0
|
45 |
-
torch.backends.cuda.enable_math_sdp(True)
|
46 |
global_step = 0
|
47 |
|
48 |
|
@@ -97,7 +102,7 @@ def run():
|
|
97 |
args = parser.parse_args()
|
98 |
model_dir = os.path.join(args.model, config.train_ms_config.model)
|
99 |
if not os.path.exists(model_dir):
|
100 |
-
os.makedirs(model_dir)
|
101 |
hps = utils.get_hparams_from_file(args.config)
|
102 |
hps.model_dir = model_dir
|
103 |
# 比较路径是否相同
|
@@ -173,6 +178,8 @@ def run():
|
|
173 |
0.1,
|
174 |
gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
|
175 |
).cuda(local_rank)
|
|
|
|
|
176 |
if (
|
177 |
"use_spk_conditioned_encoder" in hps.model.keys()
|
178 |
and hps.model.use_spk_conditioned_encoder is True
|
@@ -210,6 +217,9 @@ def run():
|
|
210 |
param.requires_grad = False
|
211 |
|
212 |
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
|
|
|
|
|
|
|
213 |
optim_g = torch.optim.AdamW(
|
214 |
filter(lambda p: p.requires_grad, net_g.parameters()),
|
215 |
hps.train.learning_rate,
|
@@ -222,6 +232,12 @@ def run():
|
|
222 |
betas=hps.train.betas,
|
223 |
eps=hps.train.eps,
|
224 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
if net_dur_disc is not None:
|
226 |
optim_dur_disc = torch.optim.AdamW(
|
227 |
net_dur_disc.parameters(),
|
@@ -233,12 +249,11 @@ def run():
|
|
233 |
optim_dur_disc = None
|
234 |
net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
|
235 |
net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
|
236 |
-
|
237 |
if net_dur_disc is not None:
|
238 |
net_dur_disc = DDP(
|
239 |
net_dur_disc,
|
240 |
device_ids=[local_rank],
|
241 |
-
find_unused_parameters=True,
|
242 |
bucket_cap_mb=512,
|
243 |
)
|
244 |
|
@@ -250,9 +265,10 @@ def run():
|
|
250 |
token=config.openi_token,
|
251 |
mirror=config.mirror,
|
252 |
)
|
253 |
-
|
254 |
-
|
255 |
-
|
|
|
256 |
_, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
|
257 |
utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
|
258 |
net_dur_disc,
|
@@ -261,28 +277,32 @@ def run():
|
|
261 |
if "skip_optimizer" in hps.train
|
262 |
else True,
|
263 |
)
|
264 |
-
_, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
|
265 |
-
utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
|
266 |
-
net_g,
|
267 |
-
optim_g,
|
268 |
-
skip_optimizer=hps.train.skip_optimizer
|
269 |
-
if "skip_optimizer" in hps.train
|
270 |
-
else True,
|
271 |
-
)
|
272 |
-
_, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
|
273 |
-
utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
|
274 |
-
net_d,
|
275 |
-
optim_d,
|
276 |
-
skip_optimizer=hps.train.skip_optimizer
|
277 |
-
if "skip_optimizer" in hps.train
|
278 |
-
else True,
|
279 |
-
)
|
280 |
-
if not optim_g.param_groups[0].get("initial_lr"):
|
281 |
-
optim_g.param_groups[0]["initial_lr"] = g_resume_lr
|
282 |
-
if not optim_d.param_groups[0].get("initial_lr"):
|
283 |
-
optim_d.param_groups[0]["initial_lr"] = d_resume_lr
|
284 |
if not optim_dur_disc.param_groups[0].get("initial_lr"):
|
285 |
optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
|
|
|
|
|
|
|
|
286 |
|
287 |
epoch_str = max(epoch_str, 1)
|
288 |
# global_step = (epoch_str - 1) * len(train_loader)
|
@@ -297,21 +317,43 @@ def run():
|
|
297 |
epoch_str = 1
|
298 |
global_step = 0
|
299 |
|
|
|
|
|
|
|
300 |
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
301 |
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
302 |
)
|
303 |
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
304 |
optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
305 |
)
|
|
|
|
|
|
|
306 |
if net_dur_disc is not None:
|
307 |
-
if not optim_dur_disc.param_groups[0].get("initial_lr"):
|
308 |
-
optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
|
309 |
scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
|
310 |
optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
311 |
)
|
312 |
else:
|
313 |
scheduler_dur_disc = None
|
314 |
-
scaler = GradScaler(enabled=hps.train.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
|
316 |
for epoch in range(epoch_str, hps.train.epochs + 1):
|
317 |
if rank == 0:
|
@@ -320,9 +362,9 @@ def run():
|
|
320 |
local_rank,
|
321 |
epoch,
|
322 |
hps,
|
323 |
-
[net_g, net_d, net_dur_disc],
|
324 |
-
[optim_g, optim_d, optim_dur_disc],
|
325 |
-
[scheduler_g, scheduler_d, scheduler_dur_disc],
|
326 |
scaler,
|
327 |
[train_loader, eval_loader],
|
328 |
logger,
|
@@ -334,9 +376,9 @@ def run():
|
|
334 |
local_rank,
|
335 |
epoch,
|
336 |
hps,
|
337 |
-
[net_g, net_d, net_dur_disc],
|
338 |
-
[optim_g, optim_d, optim_dur_disc],
|
339 |
-
[scheduler_g, scheduler_d, scheduler_dur_disc],
|
340 |
scaler,
|
341 |
[train_loader, None],
|
342 |
None,
|
@@ -344,6 +386,7 @@ def run():
|
|
344 |
)
|
345 |
scheduler_g.step()
|
346 |
scheduler_d.step()
|
|
|
347 |
if net_dur_disc is not None:
|
348 |
scheduler_dur_disc.step()
|
349 |
|
@@ -361,9 +404,9 @@ def train_and_evaluate(
|
|
361 |
logger,
|
362 |
writers,
|
363 |
):
|
364 |
-
net_g, net_d, net_dur_disc = nets
|
365 |
-
optim_g, optim_d, optim_dur_disc = optims
|
366 |
-
scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
|
367 |
train_loader, eval_loader = loaders
|
368 |
if writers is not None:
|
369 |
writer, writer_eval = writers
|
@@ -373,6 +416,7 @@ def train_and_evaluate(
|
|
373 |
|
374 |
net_g.train()
|
375 |
net_d.train()
|
|
|
376 |
if net_dur_disc is not None:
|
377 |
net_dur_disc.train()
|
378 |
for batch_idx, (
|
@@ -388,7 +432,6 @@ def train_and_evaluate(
|
|
388 |
bert,
|
389 |
ja_bert,
|
390 |
en_bert,
|
391 |
-
emo,
|
392 |
) in enumerate(tqdm(train_loader)):
|
393 |
if net_g.module.use_noise_scaled_mas:
|
394 |
current_mas_noise_scale = (
|
@@ -411,9 +454,8 @@ def train_and_evaluate(
|
|
411 |
bert = bert.cuda(local_rank, non_blocking=True)
|
412 |
ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
|
413 |
en_bert = en_bert.cuda(local_rank, non_blocking=True)
|
414 |
-
emo = emo.cuda(local_rank, non_blocking=True)
|
415 |
|
416 |
-
with autocast(enabled=hps.train.
|
417 |
(
|
418 |
y_hat,
|
419 |
l_length,
|
@@ -422,9 +464,8 @@ def train_and_evaluate(
|
|
422 |
x_mask,
|
423 |
z_mask,
|
424 |
(z, z_p, m_p, logs_p, m_q, logs_q),
|
425 |
-
(hidden_x, logw, logw_),
|
426 |
g,
|
427 |
-
loss_commit,
|
428 |
) = net_g(
|
429 |
x,
|
430 |
x_lengths,
|
@@ -436,7 +477,6 @@ def train_and_evaluate(
|
|
436 |
bert,
|
437 |
ja_bert,
|
438 |
en_bert,
|
439 |
-
emo,
|
440 |
)
|
441 |
mel = spec_to_mel_torch(
|
442 |
spec,
|
@@ -450,7 +490,7 @@ def train_and_evaluate(
|
|
450 |
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
451 |
)
|
452 |
y_hat_mel = mel_spectrogram_torch(
|
453 |
-
y_hat.squeeze(1),
|
454 |
hps.data.filter_length,
|
455 |
hps.data.n_mel_channels,
|
456 |
hps.data.sampling_rate,
|
@@ -466,7 +506,7 @@ def train_and_evaluate(
|
|
466 |
|
467 |
# Discriminator
|
468 |
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
469 |
-
with autocast(enabled=
|
470 |
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
471 |
y_d_hat_r, y_d_hat_g
|
472 |
)
|
@@ -475,11 +515,20 @@ def train_and_evaluate(
|
|
475 |
y_dur_hat_r, y_dur_hat_g = net_dur_disc(
|
476 |
hidden_x.detach(),
|
477 |
x_mask.detach(),
|
|
|
478 |
logw.detach(),
|
|
|
|
|
|
|
|
|
|
|
479 |
logw_.detach(),
|
|
|
480 |
g.detach(),
|
481 |
)
|
482 |
-
|
|
|
|
|
483 |
# TODO: I think need to mean using the mask, but for now, just mean all
|
484 |
(
|
485 |
loss_dur_disc,
|
@@ -490,31 +539,60 @@ def train_and_evaluate(
|
|
490 |
optim_dur_disc.zero_grad()
|
491 |
scaler.scale(loss_dur_disc_all).backward()
|
492 |
scaler.unscale_(optim_dur_disc)
|
493 |
-
|
|
|
|
|
|
|
|
|
|
|
494 |
scaler.step(optim_dur_disc)
|
495 |
|
496 |
optim_d.zero_grad()
|
497 |
scaler.scale(loss_disc_all).backward()
|
498 |
scaler.unscale_(optim_d)
|
|
|
|
|
499 |
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
500 |
scaler.step(optim_d)
|
501 |
|
502 |
-
with autocast(enabled=hps.train.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
# Generator
|
504 |
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
|
505 |
if net_dur_disc is not None:
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
with autocast(enabled=
|
510 |
loss_dur = torch.sum(l_length.float())
|
511 |
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
512 |
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
513 |
|
514 |
loss_fm = feature_loss(fmap_r, fmap_g)
|
515 |
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
|
|
|
|
|
|
|
|
516 |
loss_gen_all = (
|
517 |
-
loss_gen
|
|
|
|
|
|
|
|
|
|
|
|
|
518 |
)
|
519 |
if net_dur_disc is not None:
|
520 |
loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
|
@@ -522,6 +600,8 @@ def train_and_evaluate(
|
|
522 |
optim_g.zero_grad()
|
523 |
scaler.scale(loss_gen_all).backward()
|
524 |
scaler.unscale_(optim_g)
|
|
|
|
|
525 |
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
526 |
scaler.step(optim_g)
|
527 |
scaler.update()
|
@@ -540,9 +620,12 @@ def train_and_evaluate(
|
|
540 |
scalar_dict = {
|
541 |
"loss/g/total": loss_gen_all,
|
542 |
"loss/d/total": loss_disc_all,
|
|
|
543 |
"learning_rate": lr,
|
544 |
"grad_norm_d": grad_norm_d,
|
545 |
"grad_norm_g": grad_norm_g,
|
|
|
|
|
546 |
}
|
547 |
scalar_dict.update(
|
548 |
{
|
@@ -550,6 +633,8 @@ def train_and_evaluate(
|
|
550 |
"loss/g/mel": loss_mel,
|
551 |
"loss/g/dur": loss_dur,
|
552 |
"loss/g/kl": loss_kl,
|
|
|
|
|
553 |
}
|
554 |
)
|
555 |
scalar_dict.update(
|
@@ -562,6 +647,30 @@ def train_and_evaluate(
|
|
562 |
{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
|
563 |
)
|
564 |
|
|
|
|
|
|
|
|
565 |
image_dict = {
|
566 |
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
567 |
y_mel[0].data.cpu().numpy()
|
@@ -599,6 +708,13 @@ def train_and_evaluate(
|
|
599 |
epoch,
|
600 |
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
601 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
602 |
if net_dur_disc is not None:
|
603 |
utils.save_checkpoint(
|
604 |
net_dur_disc,
|
@@ -617,8 +733,8 @@ def train_and_evaluate(
|
|
617 |
|
618 |
global_step += 1
|
619 |
|
620 |
-
gc.collect()
|
621 |
-
torch.cuda.empty_cache()
|
622 |
if rank == 0:
|
623 |
logger.info("====> Epoch: {}".format(epoch))
|
624 |
|
@@ -642,7 +758,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|
642 |
bert,
|
643 |
ja_bert,
|
644 |
en_bert,
|
645 |
-
emo,
|
646 |
) in enumerate(eval_loader):
|
647 |
x, x_lengths = x.cuda(), x_lengths.cuda()
|
648 |
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
@@ -653,7 +768,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|
653 |
en_bert = en_bert.cuda()
|
654 |
tone = tone.cuda()
|
655 |
language = language.cuda()
|
656 |
-
emo = emo.cuda()
|
657 |
for use_sdp in [True, False]:
|
658 |
y_hat, attn, mask, *_ = generator.module.infer(
|
659 |
x,
|
@@ -664,7 +778,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|
664 |
bert,
|
665 |
ja_bert,
|
666 |
en_bert,
|
667 |
-
emo,
|
668 |
y=spec,
|
669 |
max_len=1000,
|
670 |
sdp_ratio=0.0 if not use_sdp else 1.0,
|
|
|
13 |
from config import config
|
14 |
import argparse
|
15 |
import datetime
|
|
|
16 |
|
17 |
logging.getLogger("numba").setLevel(logging.WARNING)
|
18 |
import commons
|
|
|
26 |
SynthesizerTrn,
|
27 |
MultiPeriodDiscriminator,
|
28 |
DurationDiscriminator,
|
29 |
+
WavLMDiscriminator,
|
30 |
+
)
|
31 |
+
from losses import (
|
32 |
+
generator_loss,
|
33 |
+
discriminator_loss,
|
34 |
+
feature_loss,
|
35 |
+
kl_loss,
|
36 |
+
WavLMLoss,
|
37 |
)
|
|
|
38 |
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
39 |
from text.symbols import symbols
|
40 |
|
41 |
torch.backends.cuda.matmul.allow_tf32 = True
|
42 |
torch.backends.cudnn.allow_tf32 = (
|
43 |
+
True # If encountered training problem,please try to disable TF32.
|
44 |
)
|
45 |
torch.set_float32_matmul_precision("medium")
|
46 |
torch.backends.cuda.sdp_kernel("flash")
|
|
|
48 |
torch.backends.cuda.enable_mem_efficient_sdp(
|
49 |
True
|
50 |
) # Not available if torch version is lower than 2.0
|
|
|
51 |
global_step = 0
|
52 |
|
53 |
|
|
|
102 |
args = parser.parse_args()
|
103 |
model_dir = os.path.join(args.model, config.train_ms_config.model)
|
104 |
if not os.path.exists(model_dir):
|
105 |
+
os.makedirs(model_dir, exist_ok=True)
|
106 |
hps = utils.get_hparams_from_file(args.config)
|
107 |
hps.model_dir = model_dir
|
108 |
# 比较路径是否相同
|
|
|
178 |
0.1,
|
179 |
gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
|
180 |
).cuda(local_rank)
|
181 |
+
else:
|
182 |
+
net_dur_disc = None
|
183 |
if (
|
184 |
"use_spk_conditioned_encoder" in hps.model.keys()
|
185 |
and hps.model.use_spk_conditioned_encoder is True
|
|
|
217 |
param.requires_grad = False
|
218 |
|
219 |
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
|
220 |
+
net_wd = WavLMDiscriminator(
|
221 |
+
hps.model.slm.hidden, hps.model.slm.nlayers, hps.model.slm.initial_channel
|
222 |
+
).cuda(local_rank)
|
223 |
optim_g = torch.optim.AdamW(
|
224 |
filter(lambda p: p.requires_grad, net_g.parameters()),
|
225 |
hps.train.learning_rate,
|
|
|
232 |
betas=hps.train.betas,
|
233 |
eps=hps.train.eps,
|
234 |
)
|
235 |
+
optim_wd = torch.optim.AdamW(
|
236 |
+
net_wd.parameters(),
|
237 |
+
hps.train.learning_rate,
|
238 |
+
betas=hps.train.betas,
|
239 |
+
eps=hps.train.eps,
|
240 |
+
)
|
241 |
if net_dur_disc is not None:
|
242 |
optim_dur_disc = torch.optim.AdamW(
|
243 |
net_dur_disc.parameters(),
|
|
|
249 |
optim_dur_disc = None
|
250 |
net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
|
251 |
net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
|
252 |
+
net_wd = DDP(net_wd, device_ids=[local_rank], bucket_cap_mb=512)
|
253 |
if net_dur_disc is not None:
|
254 |
net_dur_disc = DDP(
|
255 |
net_dur_disc,
|
256 |
device_ids=[local_rank],
|
|
|
257 |
bucket_cap_mb=512,
|
258 |
)
|
259 |
|
|
|
265 |
token=config.openi_token,
|
266 |
mirror=config.mirror,
|
267 |
)
|
268 |
+
dur_resume_lr = hps.train.learning_rate
|
269 |
+
wd_resume_lr = hps.train.learning_rate
|
270 |
+
if net_dur_disc is not None:
|
271 |
+
try:
|
272 |
_, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
|
273 |
utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
|
274 |
net_dur_disc,
|
|
|
277 |
if "skip_optimizer" in hps.train
|
278 |
else True,
|
279 |
)
|
|
|
|
|
|
|
|
280 |
if not optim_dur_disc.param_groups[0].get("initial_lr"):
|
281 |
optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
|
282 |
+
except:
|
283 |
+
print("Initialize dur_disc")
|
284 |
+
|
285 |
+
try:
|
286 |
+
_, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
|
287 |
+
utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
|
288 |
+
net_g,
|
289 |
+
optim_g,
|
290 |
+
skip_optimizer=hps.train.skip_optimizer
|
291 |
+
if "skip_optimizer" in hps.train
|
292 |
+
else True,
|
293 |
+
)
|
294 |
+
_, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
|
295 |
+
utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
|
296 |
+
net_d,
|
297 |
+
optim_d,
|
298 |
+
skip_optimizer=hps.train.skip_optimizer
|
299 |
+
if "skip_optimizer" in hps.train
|
300 |
+
else True,
|
301 |
+
)
|
302 |
+
if not optim_g.param_groups[0].get("initial_lr"):
|
303 |
+
optim_g.param_groups[0]["initial_lr"] = g_resume_lr
|
304 |
+
if not optim_d.param_groups[0].get("initial_lr"):
|
305 |
+
optim_d.param_groups[0]["initial_lr"] = d_resume_lr
|
306 |
|
307 |
epoch_str = max(epoch_str, 1)
|
308 |
# global_step = (epoch_str - 1) * len(train_loader)
|
|
|
317 |
epoch_str = 1
|
318 |
global_step = 0
|
319 |
|
320 |
+
try:
|
321 |
+
_, optim_wd, wd_resume_lr, epoch_str = utils.load_checkpoint(
|
322 |
+
utils.latest_checkpoint_path(hps.model_dir, "WD_*.pth"),
|
323 |
+
net_wd,
|
324 |
+
optim_wd,
|
325 |
+
skip_optimizer=hps.train.skip_optimizer
|
326 |
+
if "skip_optimizer" in hps.train
|
327 |
+
else True,
|
328 |
+
)
|
329 |
+
if not optim_wd.param_groups[0].get("initial_lr"):
|
330 |
+
optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr
|
331 |
+
except Exception as e:
|
332 |
+
print(e)
|
333 |
+
|
334 |
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
335 |
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
336 |
)
|
337 |
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
338 |
optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
339 |
)
|
340 |
+
scheduler_wd = torch.optim.lr_scheduler.ExponentialLR(
|
341 |
+
optim_wd, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
342 |
+
)
|
343 |
if net_dur_disc is not None:
|
|
|
|
|
344 |
scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
|
345 |
optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
346 |
)
|
347 |
else:
|
348 |
scheduler_dur_disc = None
|
349 |
+
scaler = GradScaler(enabled=hps.train.bf16_run)
|
350 |
+
|
351 |
+
wl = WavLMLoss(
|
352 |
+
hps.model.slm.model,
|
353 |
+
net_wd,
|
354 |
+
hps.data.sampling_rate,
|
355 |
+
hps.model.slm.sr,
|
356 |
+
).to(local_rank)
|
357 |
|
358 |
for epoch in range(epoch_str, hps.train.epochs + 1):
|
359 |
if rank == 0:
|
|
|
362 |
local_rank,
|
363 |
epoch,
|
364 |
hps,
|
365 |
+
[net_g, net_d, net_dur_disc, net_wd, wl],
|
366 |
+
[optim_g, optim_d, optim_dur_disc, optim_wd],
|
367 |
+
[scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
|
368 |
scaler,
|
369 |
[train_loader, eval_loader],
|
370 |
logger,
|
|
|
376 |
local_rank,
|
377 |
epoch,
|
378 |
hps,
|
379 |
+
[net_g, net_d, net_dur_disc, net_wd, wl],
|
380 |
+
[optim_g, optim_d, optim_dur_disc, optim_wd],
|
381 |
+
[scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
|
382 |
scaler,
|
383 |
[train_loader, None],
|
384 |
None,
|
|
|
386 |
)
|
387 |
scheduler_g.step()
|
388 |
scheduler_d.step()
|
389 |
+
scheduler_wd.step()
|
390 |
if net_dur_disc is not None:
|
391 |
scheduler_dur_disc.step()
|
392 |
|
|
|
404 |
logger,
|
405 |
writers,
|
406 |
):
|
407 |
+
net_g, net_d, net_dur_disc, net_wd, wl = nets
|
408 |
+
optim_g, optim_d, optim_dur_disc, optim_wd = optims
|
409 |
+
scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd = schedulers
|
410 |
train_loader, eval_loader = loaders
|
411 |
if writers is not None:
|
412 |
writer, writer_eval = writers
|
|
|
416 |
|
417 |
net_g.train()
|
418 |
net_d.train()
|
419 |
+
net_wd.train()
|
420 |
if net_dur_disc is not None:
|
421 |
net_dur_disc.train()
|
422 |
for batch_idx, (
|
|
|
432 |
bert,
|
433 |
ja_bert,
|
434 |
en_bert,
|
|
|
435 |
) in enumerate(tqdm(train_loader)):
|
436 |
if net_g.module.use_noise_scaled_mas:
|
437 |
current_mas_noise_scale = (
|
|
|
454 |
bert = bert.cuda(local_rank, non_blocking=True)
|
455 |
ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
|
456 |
en_bert = en_bert.cuda(local_rank, non_blocking=True)
|
|
|
457 |
|
458 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
459 |
(
|
460 |
y_hat,
|
461 |
l_length,
|
|
|
464 |
x_mask,
|
465 |
z_mask,
|
466 |
(z, z_p, m_p, logs_p, m_q, logs_q),
|
467 |
+
(hidden_x, logw, logw_, logw_sdp),
|
468 |
g,
|
|
|
469 |
) = net_g(
|
470 |
x,
|
471 |
x_lengths,
|
|
|
477 |
bert,
|
478 |
ja_bert,
|
479 |
en_bert,
|
|
|
480 |
)
|
481 |
mel = spec_to_mel_torch(
|
482 |
spec,
|
|
|
490 |
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
491 |
)
|
492 |
y_hat_mel = mel_spectrogram_torch(
|
493 |
+
y_hat.squeeze(1).float(),
|
494 |
hps.data.filter_length,
|
495 |
hps.data.n_mel_channels,
|
496 |
hps.data.sampling_rate,
|
|
|
506 |
|
507 |
# Discriminator
|
508 |
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
509 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
510 |
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
511 |
y_d_hat_r, y_d_hat_g
|
512 |
)
|
|
|
515 |
y_dur_hat_r, y_dur_hat_g = net_dur_disc(
|
516 |
hidden_x.detach(),
|
517 |
x_mask.detach(),
|
518 |
+
logw_.detach(),
|
519 |
logw.detach(),
|
520 |
+
g.detach(),
|
521 |
+
)
|
522 |
+
y_dur_hat_r_sdp, y_dur_hat_g_sdp = net_dur_disc(
|
523 |
+
hidden_x.detach(),
|
524 |
+
x_mask.detach(),
|
525 |
logw_.detach(),
|
526 |
+
logw_sdp.detach(),
|
527 |
g.detach(),
|
528 |
)
|
529 |
+
y_dur_hat_r = y_dur_hat_r + y_dur_hat_r_sdp
|
530 |
+
y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
|
531 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
532 |
# TODO: I think need to mean using the mask, but for now, just mean all
|
533 |
(
|
534 |
loss_dur_disc,
|
|
|
539 |
optim_dur_disc.zero_grad()
|
540 |
scaler.scale(loss_dur_disc_all).backward()
|
541 |
scaler.unscale_(optim_dur_disc)
|
542 |
+
# torch.nn.utils.clip_grad_norm_(
|
543 |
+
# parameters=net_dur_disc.parameters(), max_norm=100
|
544 |
+
# )
|
545 |
+
grad_norm_dur = commons.clip_grad_value_(
|
546 |
+
net_dur_disc.parameters(), None
|
547 |
+
)
|
548 |
scaler.step(optim_dur_disc)
|
549 |
|
550 |
optim_d.zero_grad()
|
551 |
scaler.scale(loss_disc_all).backward()
|
552 |
scaler.unscale_(optim_d)
|
553 |
+
if getattr(hps.train, "bf16_run", False):
|
554 |
+
torch.nn.utils.clip_grad_norm_(parameters=net_d.parameters(), max_norm=200)
|
555 |
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
556 |
scaler.step(optim_d)
|
557 |
|
558 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
559 |
+
loss_slm = wl.discriminator(
|
560 |
+
y.detach().squeeze(), y_hat.detach().squeeze()
|
561 |
+
).mean()
|
562 |
+
|
563 |
+
optim_wd.zero_grad()
|
564 |
+
scaler.scale(loss_slm).backward()
|
565 |
+
scaler.unscale_(optim_wd)
|
566 |
+
# torch.nn.utils.clip_grad_norm_(parameters=net_wd.parameters(), max_norm=200)
|
567 |
+
grad_norm_wd = commons.clip_grad_value_(net_wd.parameters(), None)
|
568 |
+
scaler.step(optim_wd)
|
569 |
+
|
570 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
571 |
# Generator
|
572 |
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
|
573 |
if net_dur_disc is not None:
|
574 |
+
_, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw, g)
|
575 |
+
_, y_dur_hat_g_sdp = net_dur_disc(hidden_x, x_mask, logw_, logw_sdp, g)
|
576 |
+
y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
|
577 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
578 |
loss_dur = torch.sum(l_length.float())
|
579 |
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
580 |
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
581 |
|
582 |
loss_fm = feature_loss(fmap_r, fmap_g)
|
583 |
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
584 |
+
|
585 |
+
loss_lm = wl(y.detach().squeeze(), y_hat.squeeze()).mean()
|
586 |
+
loss_lm_gen = wl.generator(y_hat.squeeze())
|
587 |
+
|
588 |
loss_gen_all = (
|
589 |
+
loss_gen
|
590 |
+
+ loss_fm
|
591 |
+
+ loss_mel
|
592 |
+
+ loss_dur
|
593 |
+
+ loss_kl
|
594 |
+
+ loss_lm
|
595 |
+
+ loss_lm_gen
|
596 |
)
|
597 |
if net_dur_disc is not None:
|
598 |
loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
|
|
|
600 |
optim_g.zero_grad()
|
601 |
scaler.scale(loss_gen_all).backward()
|
602 |
scaler.unscale_(optim_g)
|
603 |
+
if getattr(hps.train, "bf16_run", False):
|
604 |
+
torch.nn.utils.clip_grad_norm_(parameters=net_g.parameters(), max_norm=500)
|
605 |
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
606 |
scaler.step(optim_g)
|
607 |
scaler.update()
|
|
|
620 |
scalar_dict = {
|
621 |
"loss/g/total": loss_gen_all,
|
622 |
"loss/d/total": loss_disc_all,
|
623 |
+
"loss/wd/total": loss_slm,
|
624 |
"learning_rate": lr,
|
625 |
"grad_norm_d": grad_norm_d,
|
626 |
"grad_norm_g": grad_norm_g,
|
627 |
+
"grad_norm_dur": grad_norm_dur,
|
628 |
+
"grad_norm_wd": grad_norm_wd,
|
629 |
}
|
630 |
scalar_dict.update(
|
631 |
{
|
|
|
633 |
"loss/g/mel": loss_mel,
|
634 |
"loss/g/dur": loss_dur,
|
635 |
"loss/g/kl": loss_kl,
|
636 |
+
"loss/g/lm": loss_lm,
|
637 |
+
"loss/g/lm_gen": loss_lm_gen,
|
638 |
}
|
639 |
)
|
640 |
scalar_dict.update(
|
|
|
647 |
{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
|
648 |
)
|
649 |
|
650 |
+
if net_dur_disc is not None:
|
651 |
+
scalar_dict.update({"loss/dur_disc/total": loss_dur_disc_all})
|
652 |
+
|
653 |
+
scalar_dict.update(
|
654 |
+
{
|
655 |
+
"loss/dur_disc_g/{}".format(i): v
|
656 |
+
for i, v in enumerate(losses_dur_disc_g)
|
657 |
+
}
|
658 |
+
)
|
659 |
+
scalar_dict.update(
|
660 |
+
{
|
661 |
+
"loss/dur_disc_r/{}".format(i): v
|
662 |
+
for i, v in enumerate(losses_dur_disc_r)
|
663 |
+
}
|
664 |
+
)
|
665 |
+
|
666 |
+
scalar_dict.update({"loss/g/dur_gen": loss_dur_gen})
|
667 |
+
scalar_dict.update(
|
668 |
+
{
|
669 |
+
"loss/g/dur_gen_{}".format(i): v
|
670 |
+
for i, v in enumerate(losses_dur_gen)
|
671 |
+
}
|
672 |
+
)
|
673 |
+
|
674 |
image_dict = {
|
675 |
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
676 |
y_mel[0].data.cpu().numpy()
|
|
|
708 |
epoch,
|
709 |
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
710 |
)
|
711 |
+
utils.save_checkpoint(
|
712 |
+
net_wd,
|
713 |
+
optim_wd,
|
714 |
+
hps.train.learning_rate,
|
715 |
+
epoch,
|
716 |
+
os.path.join(hps.model_dir, "WD_{}.pth".format(global_step)),
|
717 |
+
)
|
718 |
if net_dur_disc is not None:
|
719 |
utils.save_checkpoint(
|
720 |
net_dur_disc,
|
|
|
733 |
|
734 |
global_step += 1
|
735 |
|
736 |
+
# gc.collect()
|
737 |
+
# torch.cuda.empty_cache()
|
738 |
if rank == 0:
|
739 |
logger.info("====> Epoch: {}".format(epoch))
|
740 |
|
|
|
758 |
bert,
|
759 |
ja_bert,
|
760 |
en_bert,
|
|
|
761 |
) in enumerate(eval_loader):
|
762 |
x, x_lengths = x.cuda(), x_lengths.cuda()
|
763 |
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
|
|
768 |
en_bert = en_bert.cuda()
|
769 |
tone = tone.cuda()
|
770 |
language = language.cuda()
|
|
|
771 |
for use_sdp in [True, False]:
|
772 |
y_hat, attn, mask, *_ = generator.module.infer(
|
773 |
x,
|
|
|
778 |
bert,
|
779 |
ja_bert,
|
780 |
en_bert,
|
|
|
781 |
y=spec,
|
782 |
max_len=1000,
|
783 |
sdp_ratio=0.0 if not use_sdp else 1.0,
|
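The WavLM additions to train_ms.py read hps.model.slm.model, hps.model.slm.sr, hps.model.slm.hidden, hps.model.slm.nlayers, hps.model.slm.initial_channel and hps.train.bf16_run. A sketch of the hparams fragment those accesses imply; the key names come from the diff, while the values below are placeholders, not values taken from this commit:

# Hypothetical hparams fragment (sketch only) matching the attribute accesses above.
hps_fragment = {
    "train": {
        "bf16_run": False,                     # drives GradScaler(enabled=...) and the autocast blocks
    },
    "model": {
        "slm": {
            "model": "./slm/wavlm-base-plus",  # WavLM checkpoint handed to WavLMLoss
            "sr": 16000,                       # sample rate WavLM expects
            "hidden": 768,                     # WavLMDiscriminator(hidden, nlayers, initial_channel)
            "nlayers": 13,
            "initial_channel": 64,
        },
    },
}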
utils.py
CHANGED
@@ -301,7 +301,11 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
|
|
301 |
|
302 |
to_del = [
|
303 |
os.path.join(path_to_models, fn)
|
304 |
-
for fn in (
|
|
|
|
|
|
|
|
|
305 |
]
|
306 |
|
307 |
def del_info(fn):
|
|
|
301 |
|
302 |
to_del = [
|
303 |
os.path.join(path_to_models, fn)
|
304 |
+
for fn in (
|
305 |
+
x_sorted("G")[:-n_ckpts_to_keep]
|
306 |
+
+ x_sorted("D")[:-n_ckpts_to_keep]
|
307 |
+
+ x_sorted("WD")[:-n_ckpts_to_keep]
|
308 |
+
)
|
309 |
]
|
310 |
|
311 |
def del_info(fn):
|
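Since WD_*.pth checkpoints are now written next to G_*.pth and D_*.pth, clean_checkpoints prunes all three prefixes. A usage sketch; the directory is an example, and the third keyword name is cut off in the hunk header above, so sort_by_time is an assumption:

# Example call: keep the two newest G/D/WD checkpoints, delete older ones.
import utils

utils.clean_checkpoints(
    path_to_models="logs/44k/",  # example model directory
    n_ckpts_to_keep=2,
    sort_by_time=True,           # assumed full name of the truncated parameter
)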
webui.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
# flake8: noqa: E402
|
|
|
2 |
import os
|
3 |
import logging
|
4 |
import re_matching
|
@@ -32,6 +33,14 @@ if device == "mps":
|
|
32 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
33 |
|
34 |
|
|
|
|
|
|
|
|
35 |
def generate_audio(
|
36 |
slices,
|
37 |
sdp_ratio,
|
@@ -42,15 +51,20 @@ def generate_audio(
|
|
42 |
language,
|
43 |
reference_audio,
|
44 |
emotion,
|
|
|
|
|
45 |
skip_start=False,
|
46 |
skip_end=False,
|
47 |
):
|
48 |
audio_list = []
|
49 |
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
|
|
|
|
|
|
50 |
with torch.no_grad():
|
51 |
for idx, piece in enumerate(slices):
|
52 |
-
skip_start =
|
53 |
-
skip_end =
|
54 |
audio = infer(
|
55 |
piece,
|
56 |
reference_audio=reference_audio,
|
@@ -66,10 +80,11 @@ def generate_audio(
|
|
66 |
device=device,
|
67 |
skip_start=skip_start,
|
68 |
skip_end=skip_end,
|
|
|
|
|
69 |
)
|
70 |
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
71 |
audio_list.append(audio16bit)
|
72 |
-
# audio_list.append(silence) # 将静音添加到列表中
|
73 |
return audio_list
|
74 |
|
75 |
|
@@ -88,10 +103,13 @@ def generate_audio_multilang(
|
|
88 |
):
|
89 |
audio_list = []
|
90 |
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
|
|
|
|
|
|
91 |
with torch.no_grad():
|
92 |
for idx, piece in enumerate(slices):
|
93 |
-
skip_start =
|
94 |
-
skip_end =
|
95 |
audio = infer_multilang(
|
96 |
piece,
|
97 |
reference_audio=reference_audio,
|
@@ -110,7 +128,6 @@ def generate_audio_multilang(
|
|
110 |
)
|
111 |
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
112 |
audio_list.append(audio16bit)
|
113 |
-
# audio_list.append(silence) # 将静音添加到列表中
|
114 |
return audio_list
|
115 |
|
116 |
|
@@ -127,63 +144,50 @@ def tts_split(
|
|
127 |
interval_between_sent,
|
128 |
reference_audio,
|
129 |
emotion,
|
|
|
|
|
130 |
):
|
131 |
-
if language == "mix":
|
132 |
-
return ("invalid", None)
|
133 |
while text.find("\n\n") != -1:
|
134 |
text = text.replace("\n\n", "\n")
|
|
|
135 |
para_list = re_matching.cut_para(text)
|
|
|
136 |
audio_list = []
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
skip_end = idx != len(para_list) - 1
|
141 |
-
audio = infer(
|
142 |
p,
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
device=device,
|
154 |
-
skip_start=skip_start,
|
155 |
-
skip_end=skip_end,
|
156 |
)
|
157 |
-
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
158 |
-
audio_list.append(audio16bit)
|
159 |
silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
|
160 |
audio_list.append(silence)
|
161 |
-
|
162 |
-
for idx, p in enumerate(para_list):
|
163 |
-
skip_start = idx != 0
|
164 |
-
skip_end = idx != len(para_list) - 1
|
165 |
audio_list_sent = []
|
166 |
sent_list = re_matching.cut_sent(p)
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
audio = infer(
|
171 |
s,
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
device=device,
|
183 |
-
skip_start=skip_start,
|
184 |
-
skip_end=skip_end,
|
185 |
)
|
186 |
-
audio_list_sent.append(audio)
|
187 |
silence = np.zeros((int)(44100 * interval_between_sent))
|
188 |
audio_list_sent.append(silence)
|
189 |
if (interval_between_para - interval_between_sent) > 0:
|
@@ -196,10 +200,49 @@ def tts_split(
|
|
196 |
) # 对完整句子做音量归一
|
197 |
audio_list.append(audio16bit)
|
198 |
audio_concat = np.concatenate(audio_list)
|
199 |
-
return ("Success", (
|
200 |
|
201 |
|
202 |
-
def
|
|
|
|
|
|
|
203 |
text: str,
|
204 |
speaker,
|
205 |
sdp_ratio,
|
@@ -209,15 +252,9 @@ def tts_fn(
|
|
209 |
language,
|
210 |
reference_audio,
|
211 |
emotion,
|
212 |
-
|
|
|
213 |
):
|
214 |
-
if prompt_mode == "Audio prompt":
|
215 |
-
if reference_audio == None:
|
216 |
-
return ("Invalid audio prompt", None)
|
217 |
-
else:
|
218 |
-
reference_audio = load_audio(reference_audio)[1]
|
219 |
-
else:
|
220 |
-
reference_audio = None
|
221 |
audio_list = []
|
222 |
if language == "mix":
|
223 |
bool_valid, str_valid = re_matching.validate_text(text)
|
@@ -226,120 +263,40 @@ def tts_fn(
|
|
226 |
hps.data.sampling_rate,
|
227 |
np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
|
228 |
)
|
229 |
-
result = []
|
230 |
for slice in re_matching.text_matching(text):
|
231 |
-
_speaker = slice
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
temp_lang += temp_
|
247 |
-
else:
|
248 |
-
if len(temp_contant) == 0:
|
249 |
-
temp_contant.append([])
|
250 |
-
temp_lang.append([])
|
251 |
-
temp_contant[-1].append(content)
|
252 |
-
temp_lang[-1].append(lang)
|
253 |
-
for i, j in zip(temp_lang, temp_contant):
|
254 |
-
result.append([*zip(i, j), _speaker])
|
255 |
-
for i, one in enumerate(result):
|
256 |
-
skip_start = i != 0
|
257 |
-
skip_end = i != len(result) - 1
|
258 |
-
_speaker = one.pop()
|
259 |
-
idx = 0
|
260 |
-
while idx < len(one):
|
261 |
-
text_to_generate = []
|
262 |
-
lang_to_generate = []
|
263 |
-
while True:
|
264 |
-
lang, content = one[idx]
|
265 |
-
temp_text = [content]
|
266 |
-
if len(text_to_generate) > 0:
|
267 |
-
text_to_generate[-1] += [temp_text.pop(0)]
|
268 |
-
lang_to_generate[-1] += [lang]
|
269 |
-
if len(temp_text) > 0:
|
270 |
-
text_to_generate += [[i] for i in temp_text]
|
271 |
-
lang_to_generate += [[lang]] * len(temp_text)
|
272 |
-
if idx + 1 < len(one):
|
273 |
-
idx += 1
|
274 |
-
else:
|
275 |
-
break
|
276 |
-
skip_start = (idx != 0) and skip_start
|
277 |
-
skip_end = (idx != len(one) - 1) and skip_end
|
278 |
-
print(text_to_generate, lang_to_generate)
|
279 |
-
audio_list.extend(
|
280 |
-
generate_audio_multilang(
|
281 |
-
text_to_generate,
|
282 |
-
sdp_ratio,
|
283 |
-
noise_scale,
|
284 |
-
noise_scale_w,
|
285 |
-
length_scale,
|
286 |
-
_speaker,
|
287 |
-
lang_to_generate,
|
288 |
-
reference_audio,
|
289 |
-
emotion,
|
290 |
-
skip_start,
|
291 |
-
skip_end,
|
292 |
-
)
|
293 |
)
|
294 |
-
|
295 |
elif language.lower() == "auto":
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
)
|
304 |
-
|
305 |
-
while idx < len(sentences_list):
|
306 |
-
text_to_generate = []
|
307 |
-
lang_to_generate = []
|
308 |
-
while True:
|
309 |
-
content, lang = sentences_list[idx]
|
310 |
-
temp_text = [content]
|
311 |
-
lang = lang.upper()
|
312 |
-
if lang == "JA":
|
313 |
-
lang = "JP"
|
314 |
-
if len(text_to_generate) > 0:
|
315 |
-
text_to_generate[-1] += [temp_text.pop(0)]
|
316 |
-
lang_to_generate[-1] += [lang]
|
317 |
-
if len(temp_text) > 0:
|
318 |
-
text_to_generate += [[i] for i in temp_text]
|
319 |
-
lang_to_generate += [[lang]] * len(temp_text)
|
320 |
-
if idx + 1 < len(sentences_list):
|
321 |
-
idx += 1
|
322 |
-
else:
|
323 |
-
break
|
324 |
-
skip_start = (idx != 0) and skip_start
|
325 |
-
skip_end = (idx != len(sentences_list) - 1) and skip_end
|
326 |
-
print(text_to_generate, lang_to_generate)
|
327 |
-
audio_list.extend(
|
328 |
-
generate_audio_multilang(
|
329 |
-
text_to_generate,
|
330 |
-
sdp_ratio,
|
331 |
-
noise_scale,
|
332 |
-
noise_scale_w,
|
333 |
-
length_scale,
|
334 |
-
speaker,
|
335 |
-
lang_to_generate,
|
336 |
-
reference_audio,
|
337 |
-
emotion,
|
338 |
-
skip_start,
|
339 |
-
skip_end,
|
340 |
-
)
|
341 |
-
)
|
342 |
-
idx += 1
|
343 |
else:
|
344 |
audio_list.extend(
|
345 |
generate_audio(
|
@@ -352,13 +309,65 @@ def tts_fn(
|
|
352 |
language,
|
353 |
reference_audio,
|
354 |
emotion,
|
|
|
|
|
355 |
)
|
356 |
)
|
|
|
|
|
|
|
357 |
|
358 |
audio_concat = np.concatenate(audio_list)
|
359 |
return "Success", (hps.data.sampling_rate, audio_concat)
|
360 |
|
361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
362 |
def load_audio(path):
|
363 |
audio, sr = librosa.load(path, 48000)
|
364 |
# audio = librosa.resample(audio, 44100, 48000)
|
@@ -408,34 +417,37 @@ if __name__ == "__main__":
|
|
408 |
)
|
409 |
trans = gr.Button("中翻日", variant="primary")
|
410 |
slicer = gr.Button("快速切分", variant="primary")
|
|
|
411 |
speaker = gr.Dropdown(
|
412 |
choices=speakers, value=speakers[0], label="Speaker"
|
413 |
)
|
414 |
_ = gr.Markdown(
|
415 |
-
value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n"
|
|
|
416 |
)
|
417 |
prompt_mode = gr.Radio(
|
418 |
["Text prompt", "Audio prompt"],
|
419 |
label="Prompt Mode",
|
420 |
value="Text prompt",
|
|
|
421 |
)
|
422 |
text_prompt = gr.Textbox(
|
423 |
label="Text prompt",
|
424 |
placeholder="用文字描述生成风格。如:Happy",
|
425 |
value="Happy",
|
426 |
-
visible=
|
427 |
)
|
428 |
audio_prompt = gr.Audio(
|
429 |
label="Audio prompt", type="filepath", visible=False
|
430 |
)
|
431 |
sdp_ratio = gr.Slider(
|
432 |
-
minimum=0, maximum=1, value=0.
|
433 |
)
|
434 |
noise_scale = gr.Slider(
|
435 |
minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
|
436 |
)
|
437 |
noise_scale_w = gr.Slider(
|
438 |
-
minimum=0.1, maximum=2, value=0.
|
439 |
)
|
440 |
length_scale = gr.Slider(
|
441 |
minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
|
@@ -445,6 +457,21 @@ if __name__ == "__main__":
|
|
445 |
)
|
446 |
btn = gr.Button("生成音频!", variant="primary")
|
447 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
with gr.Row():
|
449 |
with gr.Column():
|
450 |
interval_between_sent = gr.Slider(
|
@@ -487,6 +514,8 @@ if __name__ == "__main__":
|
|
487 |
audio_prompt,
|
488 |
text_prompt,
|
489 |
prompt_mode,
|
|
|
|
|
490 |
],
|
491 |
outputs=[text_output, audio_output],
|
492 |
)
|
@@ -511,6 +540,8 @@ if __name__ == "__main__":
|
|
511 |
interval_between_sent,
|
512 |
audio_prompt,
|
513 |
text_prompt,
|
|
|
|
|
514 |
],
|
515 |
outputs=[text_output, audio_output],
|
516 |
)
|
@@ -527,6 +558,12 @@ if __name__ == "__main__":
|
|
527 |
outputs=[audio_prompt],
|
528 |
)
|
529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
530 |
print("推理页面已开启!")
|
531 |
webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
|
532 |
app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
|
|
|
1 |
# flake8: noqa: E402
|
2 |
+
import gc
|
3 |
import os
|
4 |
import logging
|
5 |
import re_matching
|
|
|
33 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
34 |
|
35 |
|
36 |
+
def free_up_memory():
|
37 |
+
# Prior inference run might have large variables not cleaned up due to exception during the run.
|
38 |
+
# Free up as much memory as possible to allow this run to be successful.
|
39 |
+
gc.collect()
|
40 |
+
if torch.cuda.is_available():
|
41 |
+
torch.cuda.empty_cache()
|
42 |
+
|
43 |
+
|
44 |
def generate_audio(
|
45 |
slices,
|
46 |
sdp_ratio,
|
|
|
51 |
language,
|
52 |
reference_audio,
|
53 |
emotion,
|
54 |
+
style_text,
|
55 |
+
style_weight,
|
56 |
skip_start=False,
|
57 |
skip_end=False,
|
58 |
):
|
59 |
audio_list = []
|
60 |
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
61 |
+
|
62 |
+
free_up_memory()
|
63 |
+
|
64 |
with torch.no_grad():
|
65 |
for idx, piece in enumerate(slices):
|
66 |
+
skip_start = idx != 0
|
67 |
+
skip_end = idx != len(slices) - 1
|
68 |
audio = infer(
|
69 |
piece,
|
70 |
reference_audio=reference_audio,
|
|
|
80 |
device=device,
|
81 |
skip_start=skip_start,
|
82 |
skip_end=skip_end,
|
83 |
+
style_text=style_text,
|
84 |
+
style_weight=style_weight,
|
85 |
)
|
86 |
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
87 |
audio_list.append(audio16bit)
|
|
|
88 |
return audio_list
|
89 |
|
90 |
|
|
|
103 |
):
|
104 |
audio_list = []
|
105 |
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
106 |
+
|
107 |
+
free_up_memory()
|
108 |
+
|
109 |
with torch.no_grad():
|
110 |
for idx, piece in enumerate(slices):
|
111 |
+
skip_start = idx != 0
|
112 |
+
skip_end = idx != len(slices) - 1
|
113 |
audio = infer_multilang(
|
114 |
piece,
|
115 |
reference_audio=reference_audio,
|
|
|
128 |
)
|
129 |
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
130 |
audio_list.append(audio16bit)
|
|
|
131 |
return audio_list
|
132 |
|
133 |
|
|
|
144 |
interval_between_sent,
|
145 |
reference_audio,
|
146 |
emotion,
|
147 |
+
style_text,
|
148 |
+
style_weight,
|
149 |
):
|
|
|
|
|
150 |
while text.find("\n\n") != -1:
|
151 |
text = text.replace("\n\n", "\n")
|
152 |
+
text = text.replace("|", "")
|
153 |
para_list = re_matching.cut_para(text)
|
154 |
+
para_list = [p for p in para_list if p != ""]
|
155 |
audio_list = []
|
156 |
+
for p in para_list:
|
157 |
+
if not cut_by_sent:
|
158 |
+
audio_list += process_text(
|
|
|
|
|
159 |
p,
|
160 |
+
speaker,
|
161 |
+
sdp_ratio,
|
162 |
+
noise_scale,
|
163 |
+
noise_scale_w,
|
164 |
+
length_scale,
|
165 |
+
language,
|
166 |
+
reference_audio,
|
167 |
+
emotion,
|
168 |
+
style_text,
|
169 |
+
style_weight,
|
|
|
|
|
|
|
170 |
)
|
|
|
|
|
171 |
silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
|
172 |
audio_list.append(silence)
|
173 |
+
else:
|
|
|
|
|
|
|
174 |
audio_list_sent = []
|
175 |
sent_list = re_matching.cut_sent(p)
|
176 |
+
sent_list = [s for s in sent_list if s != ""]
|
177 |
+
for s in sent_list:
|
178 |
+
audio_list_sent += process_text(
|
|
|
179 |
s,
|
180 |
+
speaker,
|
181 |
+
sdp_ratio,
|
182 |
+
noise_scale,
|
183 |
+
noise_scale_w,
|
184 |
+
length_scale,
|
185 |
+
language,
|
186 |
+
reference_audio,
|
187 |
+
emotion,
|
188 |
+
style_text,
|
189 |
+
style_weight,
|
|
|
|
|
|
|
190 |
)
|
|
|
191 |
silence = np.zeros((int)(44100 * interval_between_sent))
|
192 |
audio_list_sent.append(silence)
|
193 |
if (interval_between_para - interval_between_sent) > 0:
|
|
|
200 |
) # 对完整句子做音量归一
|
201 |
audio_list.append(audio16bit)
|
202 |
audio_concat = np.concatenate(audio_list)
|
203 |
+
return ("Success", (hps.data.sampling_rate, audio_concat))
|
204 |
|
205 |
|
206 |
+
def process_mix(slice):
|
207 |
+
_speaker = slice.pop()
|
208 |
+
_text, _lang = [], []
|
209 |
+
for lang, content in slice:
|
210 |
+
content = content.split("|")
|
211 |
+
content = [part for part in content if part != ""]
|
212 |
+
if len(content) == 0:
|
213 |
+
continue
|
214 |
+
if len(_text) == 0:
|
215 |
+
_text = [[part] for part in content]
|
216 |
+
_lang = [[lang] for part in content]
|
217 |
+
else:
|
218 |
+
_text[-1].append(content[0])
|
219 |
+
_lang[-1].append(lang)
|
220 |
+
if len(content) > 1:
|
221 |
+
_text += [[part] for part in content[1:]]
|
222 |
+
_lang += [[lang] for part in content[1:]]
|
223 |
+
return _text, _lang, _speaker
|
224 |
+
|
225 |
+
|
226 |
+
def process_auto(text):
|
227 |
+
_text, _lang = [], []
|
228 |
+
for slice in text.split("|"):
|
229 |
+
if slice == "":
|
230 |
+
continue
|
231 |
+
temp_text, temp_lang = [], []
|
232 |
+
sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
|
233 |
+
for sentence, lang in sentences_list:
|
234 |
+
if sentence == "":
|
235 |
+
continue
|
236 |
+
temp_text.append(sentence)
|
237 |
+
if lang == "ja":
|
238 |
+
lang = "jp"
|
239 |
+
temp_lang.append(lang.upper())
|
240 |
+
_text.append(temp_text)
|
241 |
+
_lang.append(temp_lang)
|
242 |
+
return _text, _lang
|
243 |
+
|
244 |
+
|
245 |
+
def process_text(
|
246 |
text: str,
|
247 |
speaker,
|
248 |
sdp_ratio,
|
|
|
252 |
language,
|
253 |
reference_audio,
|
254 |
emotion,
|
255 |
+
style_text=None,
|
256 |
+
style_weight=0,
|
257 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
audio_list = []
|
259 |
if language == "mix":
|
260 |
bool_valid, str_valid = re_matching.validate_text(text)
|
|
|
263 |
hps.data.sampling_rate,
|
264 |
np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
|
265 |
)
|
|
|
266 |
for slice in re_matching.text_matching(text):
|
267 |
+
_text, _lang, _speaker = process_mix(slice)
|
268 |
+
if _speaker is None:
|
269 |
+
continue
|
270 |
+
print(f"Text: {_text}\nLang: {_lang}")
|
271 |
+
audio_list.extend(
|
272 |
+
generate_audio_multilang(
|
273 |
+
_text,
|
274 |
+
sdp_ratio,
|
275 |
+
noise_scale,
|
276 |
+
noise_scale_w,
|
277 |
+
length_scale,
|
278 |
+
_speaker,
|
279 |
+
_lang,
|
280 |
+
reference_audio,
|
281 |
+
emotion,
|
|
|
|
|
|
282 |
)
|
283 |
+
)
|
284 |
elif language.lower() == "auto":
|
285 |
+
_text, _lang = process_auto(text)
|
286 |
+
print(f"Text: {_text}\nLang: {_lang}")
|
287 |
+
audio_list.extend(
|
288 |
+
generate_audio_multilang(
|
289 |
+
_text,
|
290 |
+
sdp_ratio,
|
291 |
+
noise_scale,
|
292 |
+
noise_scale_w,
|
293 |
+
length_scale,
|
294 |
+
speaker,
|
295 |
+
_lang,
|
296 |
+
reference_audio,
|
297 |
+
emotion,
|
298 |
)
|
299 |
+
)
|
|
|
|
|
|
|
|
|
|
300 |
else:
|
301 |
audio_list.extend(
|
302 |
generate_audio(
|
|
|
309 |
language,
|
310 |
reference_audio,
|
311 |
emotion,
|
312 |
+
style_text,
|
313 |
+
style_weight,
|
314 |
)
|
315 |
)
|
316 |
+
return audio_list
|
317 |
+
|
318 |
+
|
319 |
+
def tts_fn(
|
320 |
+
text: str,
|
321 |
+
speaker,
|
322 |
+
sdp_ratio,
|
323 |
+
noise_scale,
|
324 |
+
noise_scale_w,
|
325 |
+
length_scale,
|
326 |
+
language,
|
327 |
+
reference_audio,
|
328 |
+
emotion,
|
329 |
+
prompt_mode,
|
330 |
+
style_text=None,
|
331 |
+
style_weight=0,
|
332 |
+
):
|
333 |
+
if style_text == "":
|
334 |
+
style_text = None
|
335 |
+
if prompt_mode == "Audio prompt":
|
336 |
+
if reference_audio == None:
|
337 |
+
return ("Invalid audio prompt", None)
|
338 |
+
else:
|
339 |
+
reference_audio = load_audio(reference_audio)[1]
|
340 |
+
else:
|
341 |
+
reference_audio = None
|
342 |
+
|
343 |
+
audio_list = process_text(
|
344 |
+
text,
|
345 |
+
speaker,
|
346 |
+
sdp_ratio,
|
347 |
+
noise_scale,
|
348 |
+
noise_scale_w,
|
349 |
+
length_scale,
|
350 |
+
language,
|
351 |
+
reference_audio,
|
352 |
+
emotion,
|
353 |
+
style_text,
|
354 |
+
style_weight,
|
355 |
+
)
|
356 |
|
357 |
audio_concat = np.concatenate(audio_list)
|
358 |
return "Success", (hps.data.sampling_rate, audio_concat)
|
359 |
|
360 |
|
361 |
+
def format_utils(text, speaker):
|
362 |
+
_text, _lang = process_auto(text)
|
363 |
+
res = f"[{speaker}]"
|
364 |
+
for lang_s, content_s in zip(_lang, _text):
|
365 |
+
for lang, content in zip(lang_s, content_s):
|
366 |
+
res += f"<{lang.lower()}>{content}"
|
367 |
+
res += "|"
|
368 |
+
return "mix", res[:-1]
|
369 |
+
|
370 |
+
|
371 |
def load_audio(path):
|
372 |
audio, sr = librosa.load(path, 48000)
|
373 |
# audio = librosa.resample(audio, 44100, 48000)
|
|
|
417 |
)
|
418 |
trans = gr.Button("中翻日", variant="primary")
|
419 |
slicer = gr.Button("快速切分", variant="primary")
|
420 |
+
formatter = gr.Button("检测语言,并整理为 MIX 格式", variant="primary")
|
421 |
speaker = gr.Dropdown(
|
422 |
choices=speakers, value=speakers[0], label="Speaker"
|
423 |
)
|
424 |
_ = gr.Markdown(
|
425 |
+
value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n",
|
426 |
+
visible=False,
|
427 |
)
|
428 |
prompt_mode = gr.Radio(
|
429 |
["Text prompt", "Audio prompt"],
|
430 |
label="Prompt Mode",
|
431 |
value="Text prompt",
|
432 |
+
visible=False,
|
433 |
)
|
434 |
text_prompt = gr.Textbox(
|
435 |
label="Text prompt",
|
436 |
placeholder="用文字描述生成风格。如:Happy",
|
437 |
value="Happy",
|
438 |
+
visible=False,
|
439 |
)
|
440 |
audio_prompt = gr.Audio(
|
441 |
label="Audio prompt", type="filepath", visible=False
|
442 |
)
|
443 |
sdp_ratio = gr.Slider(
|
444 |
+
minimum=0, maximum=1, value=0.5, step=0.1, label="SDP Ratio"
|
445 |
)
|
446 |
noise_scale = gr.Slider(
|
447 |
minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
|
448 |
)
|
449 |
noise_scale_w = gr.Slider(
|
450 |
+
minimum=0.1, maximum=2, value=0.9, step=0.1, label="Noise_W"
|
451 |
)
|
452 |
length_scale = gr.Slider(
|
453 |
minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
|
|
|
457 |
)
|
458 |
btn = gr.Button("生成音频!", variant="primary")
|
459 |
with gr.Column():
|
460 |
+
with gr.Accordion("融合文本语义", open=False):
|
461 |
+
gr.Markdown(
|
462 |
+
value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
|
463 |
+
"**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n"
|
464 |
+
"效果较不明确,留空即为不使用该功能"
|
465 |
+
)
|
466 |
+
style_text = gr.Textbox(label="辅助文本")
|
467 |
+
style_weight = gr.Slider(
|
468 |
+
minimum=0,
|
469 |
+
maximum=1,
|
470 |
+
value=0.7,
|
471 |
+
step=0.1,
|
472 |
+
label="Weight",
|
473 |
+
info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
|
474 |
+
)
|
475 |
with gr.Row():
|
476 |
with gr.Column():
|
477 |
interval_between_sent = gr.Slider(
|
|
|
514 |
audio_prompt,
|
515 |
text_prompt,
|
516 |
prompt_mode,
|
517 |
+
style_text,
|
518 |
+
style_weight,
|
519 |
],
|
520 |
outputs=[text_output, audio_output],
|
521 |
)
|
|
|
540 |
interval_between_sent,
|
541 |
audio_prompt,
|
542 |
text_prompt,
|
543 |
+
style_text,
|
544 |
+
style_weight,
|
545 |
],
|
546 |
outputs=[text_output, audio_output],
|
547 |
)
|
|
|
558 |
outputs=[audio_prompt],
|
559 |
)
|
560 |
|
561 |
+
formatter.click(
|
562 |
+
format_utils,
|
563 |
+
inputs=[text, speaker],
|
564 |
+
outputs=[language, text],
|
565 |
+
)
|
566 |
+
|
567 |
print("推理页面已开启!")
|
568 |
webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
|
569 |
app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
|
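The new formatter button calls format_utils, which runs process_auto over the input and rewrites it into the explicit MIX syntax consumed by process_mix. A rough illustration of the output shape; the actual segmentation comes from the split_by_language helper, so the zh/en split below is an assumption:

# Illustration only; segmentation depends on split_by_language.
lang, mix_text = format_utils("你好。Nice to meet you.", "Speaker0")
# Expected shape, assuming the text splits into one zh and one en segment:
#   lang     == "mix"
#   mix_text == "[Speaker0]<zh>你好。<en>Nice to meet you."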
webui_preprocess.py
CHANGED
@@ -19,9 +19,9 @@ def generate_config(data_dir, batch_size):
|
|
19 |
assert data_dir != "", "数据集名称不能为空"
|
20 |
start_path, _, train_path, val_path, config_path = get_path(data_dir)
|
21 |
if os.path.isfile(config_path):
|
22 |
-
config = json.load(open(config_path))
|
23 |
else:
|
24 |
-
config = json.load(open("configs/config.json"))
|
25 |
config["data"]["training_files"] = train_path
|
26 |
config["data"]["validation_files"] = val_path
|
27 |
config["train"]["batch_size"] = batch_size
|
@@ -44,7 +44,7 @@ def resample(data_dir):
|
|
44 |
in_dir = os.path.join(start_path, "raw")
|
45 |
out_dir = os.path.join(start_path, "wavs")
|
46 |
subprocess.run(
|
47 |
-
f"python
|
48 |
f"--sr 44100 "
|
49 |
f"--in_dir {in_dir} "
|
50 |
f"--out_dir {out_dir} ",
|
@@ -60,7 +60,9 @@ def preprocess_text(data_dir):
|
|
60 |
with open(lbl_path, "w", encoding="utf-8") as f:
|
61 |
for line in lines:
|
62 |
path, spk, language, text = line.strip().split("|")
|
63 |
-
path = os.path.join(start_path, "wavs", os.path.basename(path))
|
|
|
|
|
64 |
f.writelines(f"{path}|{spk}|{language}|{text}\n")
|
65 |
subprocess.run(
|
66 |
f"python preprocess_text.py "
|
@@ -83,16 +85,6 @@ def bert_gen(data_dir):
|
|
83 |
return "BERT 特征文件生成完成"
|
84 |
|
85 |
|
86 |
-
def clap_gen(data_dir):
|
87 |
-
assert data_dir != "", "数据集名称不能为空"
|
88 |
-
_, _, _, _, config_path = get_path(data_dir)
|
89 |
-
subprocess.run(
|
90 |
-
f"python clap_gen.py " f"--config {config_path}",
|
91 |
-
shell=True,
|
92 |
-
)
|
93 |
-
return "CLAP 特征文件生成完成"
|
94 |
-
|
95 |
-
|
96 |
if __name__ == "__main__":
|
97 |
with gr.Blocks() as app:
|
98 |
with gr.Row():
|
@@ -100,13 +92,13 @@ if __name__ == "__main__":
|
|
100 |
_ = gr.Markdown(
|
101 |
value="# Bert-VITS2 数据预处理\n"
|
102 |
"## 预先准备:\n"
|
103 |
-
"下载 BERT 和
|
104 |
"- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
|
105 |
"- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
|
106 |
"- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
|
107 |
-
"- [
|
108 |
"\n"
|
109 |
-
"将 BERT 模型放置到 `bert` 文件夹下,
|
110 |
"\n"
|
111 |
"数据准备:\n"
|
112 |
"将数据放置在 data 文件夹下,按照如下结构组织:\n"
|
@@ -156,12 +148,10 @@ if __name__ == "__main__":
|
|
156 |
preprocess_text_btn = gr.Button(value="执行", variant="primary")
|
157 |
_ = gr.Markdown(value="## 第四步:生成 BERT 特征文件")
|
158 |
bert_gen_btn = gr.Button(value="执行", variant="primary")
|
159 |
-
_ = gr.Markdown(value="## 第五步:生成 CLAP 特征文件")
|
160 |
-
clap_gen_btn = gr.Button(value="执行", variant="primary")
|
161 |
_ = gr.Markdown(
|
162 |
value="## 训练模型及部署:\n"
|
163 |
"修改根目录下的 `config.yml` 中 `dataset_path` 一项为 `data/{你的数据集名称}`\n"
|
164 |
-
"- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
|
165 |
"- 部署:修改根目录下的 `config.yml` 中 `webui` 下 `model` 一项为 `models/{权重文件名}.pth` (如 G_10000.pth),然后执行 `python webui.py`"
|
166 |
)
|
167 |
|
@@ -171,7 +161,6 @@ if __name__ == "__main__":
|
|
171 |
resample_btn.click(resample, inputs=[data_dir], outputs=[info])
|
172 |
preprocess_text_btn.click(preprocess_text, inputs=[data_dir], outputs=[info])
|
173 |
bert_gen_btn.click(bert_gen, inputs=[data_dir], outputs=[info])
|
174 |
-
clap_gen_btn.click(clap_gen, inputs=[data_dir], outputs=[info])
|
175 |
|
176 |
webbrowser.open("http://127.0.0.1:7860")
|
177 |
app.launch(share=False, server_port=7860)
|
|
|
19 |
assert data_dir != "", "数据集名称不能为空"
|
20 |
start_path, _, train_path, val_path, config_path = get_path(data_dir)
|
21 |
if os.path.isfile(config_path):
|
22 |
+
config = json.load(open(config_path, "r", encoding="utf-8"))
|
23 |
else:
|
24 |
+
config = json.load(open("configs/config.json", "r", encoding="utf-8"))
|
25 |
config["data"]["training_files"] = train_path
|
26 |
config["data"]["validation_files"] = val_path
|
27 |
config["train"]["batch_size"] = batch_size
|
|
|
44 |
in_dir = os.path.join(start_path, "raw")
|
45 |
out_dir = os.path.join(start_path, "wavs")
|
46 |
subprocess.run(
|
47 |
+
f"python resample_legacy.py "
|
48 |
f"--sr 44100 "
|
49 |
f"--in_dir {in_dir} "
|
50 |
f"--out_dir {out_dir} ",
|
|
|
60 |
with open(lbl_path, "w", encoding="utf-8") as f:
|
61 |
for line in lines:
|
62 |
path, spk, language, text = line.strip().split("|")
|
63 |
+
path = os.path.join(start_path, "wavs", os.path.basename(path)).replace(
|
64 |
+
"\\", "/"
|
65 |
+
)
|
66 |
f.writelines(f"{path}|{spk}|{language}|{text}\n")
|
67 |
subprocess.run(
|
68 |
f"python preprocess_text.py "
|
|
|
85 |
return "BERT 特征文件生成完成"
|
86 |
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
if __name__ == "__main__":
|
89 |
with gr.Blocks() as app:
|
90 |
with gr.Row():
|
|
|
92 |
_ = gr.Markdown(
|
93 |
value="# Bert-VITS2 数据预处理\n"
|
94 |
"## 预先准备:\n"
|
95 |
+
"下载 BERT 和 WavLM 模型:\n"
|
96 |
"- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
|
97 |
"- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
|
98 |
"- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
|
99 |
+
"- [WavLM](https://huggingface.co/microsoft/wavlm-base-plus)\n"
|
100 |
"\n"
|
101 |
+
"将 BERT 模型放置到 `bert` 文件夹下,WavLM 模型放置到 `slm` 文件夹下,覆盖同名文件夹。\n"
|
102 |
"\n"
|
103 |
"数据准备:\n"
|
104 |
"将数据放置在 data 文件夹下,按照如下结构组织:\n"
|
|
|
148 |
preprocess_text_btn = gr.Button(value="执行", variant="primary")
|
149 |
_ = gr.Markdown(value="## 第四步:生成 BERT 特征文件")
|
150 |
bert_gen_btn = gr.Button(value="执行", variant="primary")
|
|
|
|
|
151 |
_ = gr.Markdown(
|
152 |
value="## 训练模型及部署:\n"
|
153 |
"修改根目录下的 `config.yml` 中 `dataset_path` 一项为 `data/{你的数据集名称}`\n"
|
154 |
+
"- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth`、`WD_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
|
155 |
"- 部署:修改根目录下的 `config.yml` 中 `webui` 下 `model` 一项为 `models/{权重文件名}.pth` (如 G_10000.pth),然后执行 `python webui.py`"
|
156 |
)
|
157 |
|
|
|
161 |
resample_btn.click(resample, inputs=[data_dir], outputs=[info])
|
162 |
preprocess_text_btn.click(preprocess_text, inputs=[data_dir], outputs=[info])
|
163 |
bert_gen_btn.click(bert_gen, inputs=[data_dir], outputs=[info])
|
|
|
164 |
|
165 |
webbrowser.open("http://127.0.0.1:7860")
|
166 |
app.launch(share=False, server_port=7860)
|
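With the CLAP step removed, webui_preprocess.py now exposes four steps: config generation, resampling, text preprocessing and BERT feature generation. They can also be run headlessly by calling the same functions, as in this sketch (the dataset name is an example):

# Headless sketch of the remaining preprocessing steps (example dataset name).
from webui_preprocess import generate_config, resample, preprocess_text, bert_gen

data_dir = "my_dataset"              # corresponds to data/my_dataset in the layout described above
print(generate_config(data_dir, batch_size=4))
print(resample(data_dir))            # wraps resample_legacy.py --sr 44100
print(preprocess_text(data_dir))
print(bert_gen(data_dir))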