update
Browse files

- examples/evaluation/step_1_run_evaluation.py  +9 -3
- examples/evaluation/step_2_show_metrics.py  +65 -40
- examples/evaluation/step_3_show_vad.py  +12 -14
- examples/fsmn_vad_by_webrtcvad/step_4_train_model.py  +1 -1
- examples/silero_vad_by_webrtcvad/run.sh  +1 -1
- examples/silero_vad_by_webrtcvad/step_4_train_model.py  +5 -3
- examples/silero_vad_by_webrtcvad/step_5_export_model.py  +1 -1
- log.py  +45 -8
- main.py  +15 -7
- toolbox/pydub/volume.py  +39 -0
- toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py  +3 -2
- toolbox/torchaudio/models/vad/native_silero_vad/__init__.py  +6 -0
- toolbox/torchaudio/models/vad/native_silero_vad/inference_native_silero_vad_onnx.py  +198 -0
- toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad_onnx.py  +1 -4
examples/evaluation/step_1_run_evaluation.py CHANGED
@@ -26,7 +26,14 @@ def get_args():
     )
     parser.add_argument(
         "--output_file",
-        default=r"
+        default=r"native_silero_vad.jsonl",
+        type=str
+    )
+    parser.add_argument(
+        "--vad_engine",
+        # default="fsmn-vad-by-webrtcvad-nx2-dns3",
+        # default="silero-vad-by-webrtcvad-nx2-dns3",
+        default="native_silero_vad",
         type=str
     )
     parser.add_argument("--expected_sample_rate", default=8000, type=int)
@@ -110,8 +117,7 @@ def main():
             min_silence_length=6,
             max_speech_length=100000,
             min_speech_length=15,
-
-            engine="silero-vad-by-webrtcvad-nx2-dns3",
+            engine=args.vad_engine,
             api_name="/when_click_vad_button"
         )
         js = json.loads(message)
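step_1_run_evaluation.py drives the demo service through its Gradio endpoint, so the engine choice now travels with the request instead of being hard-coded. A minimal sketch of the call, assuming a recent gradio_client and a locally served app (the URL and the sample file are assumptions; audio_file_t, the VAD parameters, and the endpoint name appear in this commit):

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
message = client.predict(
    audio_file_t=handle_file("sample.wav"),
    min_silence_length=6,
    max_speech_length=100000,
    min_speech_length=15,
    engine="native_silero_vad",   # previously hard-coded in the script
    api_name="/when_click_vad_button",
)
print(message)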
examples/evaluation/step_2_show_metrics.py CHANGED
@@ -3,6 +3,7 @@
 import argparse
 import json
 import os
+from pathlib import Path
 import sys
 
 pwd = os.path.abspath(os.path.dirname(__file__))
@@ -16,53 +17,77 @@ def get_args():
 
     parser.add_argument(
         "--eval_file",
-        default=r"
+        # default=r"native_silero_vad.jsonl",
         type=str
     )
     args = parser.parse_args()
     return args
 
 
+evaluation_files = [
+    "native_silero_vad.jsonl",
+    "fsmn-vad.jsonl",
+    "silero-vad.jsonl"
+]
+
+
 def main():
-    args = get_args()
-    ...
+    # args = get_args()
+
+    for eval_file in evaluation_files:
+        eval_file = Path(eval_file)
+        total = 0
+        total_duration = 0
+        total_accuracy = 0
+        total_precision = 0
+        total_recall = 0
+        total_f1 = 0
+
+        average_accuracy = 0
+        average_precision = 0
+        average_recall = 0
+        average_f1 = 0
+
+        # progress_bar = tqdm(desc=eval_file.name)
+        with open(eval_file.as_posix(), "r", encoding="utf-8") as f:
+            for row in f:
+                row = json.loads(row)
+                duration = row["duration"]
+                accuracy = row["accuracy"]
+                precision = row["precision"]
+                recall = row["recall"]
+                f1 = row["f1"]
+
+                total += 1
+                total_duration += duration
+                total_accuracy += accuracy * duration
+                total_precision += precision * duration
+                total_recall += recall * duration
+                total_f1 += f1 * duration
+
+                average_accuracy = total_accuracy / total_duration
+                average_precision = total_precision / total_duration
+                average_recall = total_recall / total_duration
+                average_f1 = total_f1 / total_duration
+
+                # progress_bar.update(1)
+                # progress_bar.set_postfix({
+                #     "total": total,
+                #     "accuracy": average_accuracy,
+                #     "precision": average_precision,
+                #     "recall": average_recall,
+                #     "f1": average_f1,
+                #     "total_duration": f"{round(total_duration / 60, 4)}min",
+                # })
+        summary = (f"{eval_file.name}, "
+                   f"total: {total}, "
+                   f"accuracy: {average_accuracy}, "
+                   f"precision: {average_precision}, "
+                   f"recall: {average_recall}, "
+                   f"f1: {average_f1}, "
+                   f"total_duration: {round(total_duration / 60, 4)}min, "
+                   )
+        print(summary)
     return
 
 
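The averages in step_2_show_metrics.py are duration-weighted: each row's metric is scaled by its clip duration, accumulated, and divided by the total duration, so long recordings count proportionally more than short ones. A self-contained sketch of the same computation on two hypothetical rows:

rows = [
    {"duration": 2.0, "f1": 0.90},
    {"duration": 8.0, "f1": 0.60},
]

total_duration = sum(row["duration"] for row in rows)
# duration-weighted mean, as in main() above
average_f1 = sum(row["f1"] * row["duration"] for row in rows) / total_duration
print(average_f1)  # 0.66, versus an unweighted mean of 0.75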
examples/evaluation/step_3_show_vad.py CHANGED
@@ -51,10 +51,17 @@ def show_image(signal: np.ndarray,
     plt.show()
 
 
+evaluation_files = [
+    # "native_silero_vad.jsonl",
+    "fsmn-vad.jsonl",
+    "silero-vad.jsonl"
+]
+
+
 def main():
-    args = get_args()
+    # args = get_args()
 
-    with open(
+    with open(evaluation_files[0], "r", encoding="utf-8") as f:
         for row in f:
             row = json.loads(row)
             filename = row["filename"]
@@ -77,25 +84,16 @@ def main():
             begin = int(begin * sample_rate)
             end = int(end * sample_rate)
             ground_truth_probs[begin:end] = 1
+
         prediction_probs = np.zeros(shape=(signal_length,), dtype=np.float32)
         for begin, end in prediction:
             begin = int(begin * sample_rate)
             end = int(end * sample_rate)
             prediction_probs[begin:end] = 1
 
-        # p = encoder_num_layers * (encoder_kernel_size - 1) // 2 * hop_size * sample_rate
-        p = 3 * (3 - 1) // 2 * 80
-        p = int(p)
-        print(f"p: {p}")
-        prediction_probs = np.concat(
-            [
-                prediction_probs[p:], prediction_probs[-p:]
-            ],
-            axis=-1
-        )
-
         show_image(signal,
-                   ground_truth_probs,
+                   ground_truth_probs,
+                   prediction_probs,
                    sample_rate=sample_rate,
                    )
         return
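The script rasterizes each (begin, end) segment list into a per-sample 0/1 trace before plotting, which is what lets ground truth and prediction be drawn over the waveform on a common time axis. A self-contained sketch of that conversion:

import numpy as np

sample_rate = 8000
signal_length = 4 * sample_rate                 # 4 s of audio
segments = [(0.50, 1.25), (2.00, 3.10)]         # (begin_sec, end_sec) pairs

probs = np.zeros(shape=(signal_length,), dtype=np.float32)
for begin, end in segments:
    probs[int(begin * sample_rate):int(end * sample_rate)] = 1

print(probs.sum() / sample_rate)  # 1.85 s of speech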
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -127,7 +127,7 @@ def main():
         max_wave_value=32768.0,
         min_snr_db=config.min_snr_db,
         max_snr_db=config.max_snr_db,
-        do_volume_enhancement=
+        do_volume_enhancement=False,
         # skip=225000,
     )
     valid_dataset = VadPaddingJsonlDataset(
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -4,7 +4,7 @@
 
 bash run.sh --stage 3 --stop_stage 5 --system_version centos \
 --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
---final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
+--final_model_name silero-vad-by-webrtcvad-nx2-dns3-20250813 \
 --noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
 --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
 /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
examples/silero_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -127,7 +127,7 @@ def main():
         max_wave_value=32768.0,
         min_snr_db=config.min_snr_db,
         max_snr_db=config.max_snr_db,
-        do_volume_enhancement=
+        do_volume_enhancement=False,
         # skip=225000,
     )
     valid_dataset = VadPaddingJsonlDataset(
@@ -271,7 +271,8 @@ def main():
             dice_loss = dice_loss_fn.forward(probs, targets)
             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
+            # loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 1.0 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -352,7 +353,8 @@ def main():
             dice_loss = dice_loss_fn.forward(probs, targets)
             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
+            # loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 1.0 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
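Raising the LSNR weight from 0.3 to 1.0 puts the SNR-estimation objective on equal footing with the BCE and Dice terms, in both the train and validation loops. A toy check of the weighted sum (the three loss values below are placeholders, not outputs of the repo's loss functions):

bce_loss, dice_loss, lsnr_loss = 0.40, 0.25, 0.10

loss_old = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss  # 0.68
loss_new = 1.0 * bce_loss + 1.0 * dice_loss + 1.0 * lsnr_loss  # 0.75
print(loss_old, loss_new)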
examples/silero_vad_by_webrtcvad/step_5_export_model.py CHANGED
@@ -94,7 +94,7 @@ def main():
             "new_lstm_hidden_state": {2: "batch_size"},
         })
 
-    ort_session = ort.InferenceSession("
+    ort_session = ort.InferenceSession("model.onnx")
     input_feed = {
         "inputs": inputs.numpy(),
         "encoder_in_cache": encoder_in_cache.numpy(),
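Before wiring up input_feed it can help to confirm the exported graph's I/O names and dynamic axes. A generic onnxruntime sketch, run from the directory holding model.onnx:

import onnxruntime as ort

ort_session = ort.InferenceSession("model.onnx")
for i in ort_session.get_inputs():
    print("input:", i.name, i.shape)    # batch_size axes stay symbolic
for o in ort_session.get_outputs():
    print("output:", o.name, o.shape)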
log.py CHANGED
@@ -15,8 +15,43 @@ def get_converter(tz_info: str = "Asia/Shanghai"):
     return converter
 
 
+def setup_stream(tz_info: str = "Asia/Shanghai"):
+    fmt = "%(asctime)s|%(name)s|%(levelname)s|%(filename)s|%(lineno)d|%(message)s"
+
+    formatter = logging.Formatter(
+        fmt=fmt,
+        datefmt="%Y-%m-%d %H:%M:%S %z"
+    )
+    formatter.converter = get_converter(tz_info)
+
+    stream_handler = logging.StreamHandler()
+    stream_handler.setLevel(logging.INFO)
+    stream_handler.setFormatter(formatter)
+
+    # main
+    main_logger = logging.getLogger("main")
+    main_logger.addHandler(stream_handler)
+
+    # http
+    http_logger = logging.getLogger("http")
+    http_logger.addHandler(stream_handler)
+
+    # api
+    api_logger = logging.getLogger("api")
+    api_logger.addHandler(stream_handler)
+
+    logging.basicConfig(
+        level=logging.DEBUG,
+        datefmt="%a, %d %b %Y %H:%M:%S",
+        handlers=[
+
+        ]
+    )
+    return
+
+
 def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
-    fmt = "%(asctime)s
+    fmt = "%(asctime)s|%(name)s|%(levelname)s|%(filename)s|%(lineno)d|%(message)s"
 
     formatter = logging.Formatter(
         fmt=fmt,
@@ -38,11 +73,12 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
         backupCount=2,
     )
     main_info_file_handler.setLevel(logging.INFO)
-    main_info_file_handler.setFormatter(
+    main_info_file_handler.setFormatter(formatter)
     main_logger.addHandler(main_info_file_handler)
 
     # http
     http_logger = logging.getLogger("http")
+    http_logger.addHandler(stream_handler)
     http_file_handler = RotatingFileHandler(
         filename=os.path.join(log_directory, "http.log"),
         maxBytes=100*1024*1024, # 100MB
@@ -50,11 +86,12 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
         backupCount=2,
     )
     http_file_handler.setLevel(logging.DEBUG)
-    http_file_handler.setFormatter(
+    http_file_handler.setFormatter(formatter)
     http_logger.addHandler(http_file_handler)
 
     # api
     api_logger = logging.getLogger("api")
+    api_logger.addHandler(stream_handler)
     api_file_handler = RotatingFileHandler(
         filename=os.path.join(log_directory, "api.log"),
         maxBytes=10*1024*1024, # 10MB
@@ -62,7 +99,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
         backupCount=2,
     )
     api_file_handler.setLevel(logging.DEBUG)
-    api_file_handler.setFormatter(
+    api_file_handler.setFormatter(formatter)
     api_logger.addHandler(api_file_handler)
 
     # alarm
@@ -74,7 +111,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
         backupCount=2,
     )
     alarm_file_handler.setLevel(logging.DEBUG)
-    alarm_file_handler.setFormatter(
+    alarm_file_handler.setFormatter(formatter)
     alarm_logger.addHandler(alarm_file_handler)
 
     debug_file_handler = RotatingFileHandler(
@@ -84,7 +121,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
         backupCount=2,
    )
     debug_file_handler.setLevel(logging.DEBUG)
-    debug_file_handler.setFormatter(
+    debug_file_handler.setFormatter(formatter)
 
     info_file_handler = RotatingFileHandler(
         filename=os.path.join(log_directory, "info.log"),
@@ -93,7 +130,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
         backupCount=2,
     )
     info_file_handler.setLevel(logging.INFO)
-    info_file_handler.setFormatter(
+    info_file_handler.setFormatter(formatter)
 
     error_file_handler = RotatingFileHandler(
         filename=os.path.join(log_directory, "error.log"),
@@ -102,7 +139,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
         backupCount=2,
     )
     error_file_handler.setLevel(logging.ERROR)
-    error_file_handler.setFormatter(
+    error_file_handler.setFormatter(formatter)
 
     logging.basicConfig(
         level=logging.DEBUG,
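setup_stream mirrors setup_size_rotating's pipe-delimited format but attaches a single console handler to the main/http/api loggers instead of rotating files. A minimal usage sketch, assuming log.py is importable from the project root (as main.py already assumes for setup_size_rotating):

import logging

import log

log.setup_stream(tz_info="Asia/Shanghai")
logging.getLogger("main").info("hello")  # pipe-delimited line on stderr, no files written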
main.py CHANGED
@@ -25,8 +25,10 @@ from project_settings import environment, project_path, log_directory, time_zone
 from toolbox.os.command import Command
 from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
 from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
+from toolbox.torchaudio.models.vad.native_silero_vad.inference_native_silero_vad_onnx import InferenceNativeSileroVadOnnx
 from toolbox.torchaudio.utils.visualization import process_speech_probs
 from toolbox.vad.utils import PostProcess
+from toolbox.pydub.volume import get_volume
 
 log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
 
@@ -93,9 +95,11 @@ def shell(cmd: str):
 
 
 def get_infer_cls_by_model_name(model_name: str):
-    if model_name.__contains__("
+    if model_name.__contains__("native_silero_vad"):
+        infer_cls = InferenceNativeSileroVadOnnx
+    elif model_name.__contains__("fsmn-vad"):
         infer_cls = InferenceFSMNVadOnnx
-    elif model_name.__contains__("silero"):
+    elif model_name.__contains__("silero-vad"):
         infer_cls = InferenceSileroVad
     else:
         raise AssertionError
@@ -158,8 +162,8 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
     vad_info = infer_engine.infer(audio)
     time_cost = time.time() - begin
 
-    probs = vad_info["probs"]
-    lsnr = vad_info["lsnr"]
+    probs: np.ndarray = vad_info["probs"]
+    lsnr: np.ndarray = vad_info["lsnr"]
     # lsnr = lsnr / np.max(np.abs(lsnr))
     lsnr = lsnr / 30
 
@@ -197,13 +201,17 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
         ] for v in vad_segments
     ]
 
+    # volume
+    volume_map: dict = get_volume(audio, sample_rate)
+
     # message
     rtf = time_cost / audio_duration
     info = {
         "vad_segments": vad_segments,
         "time_cost": round(time_cost, 4),
         "duration": round(audio_duration, 4),
-        "rtf": round(rtf, 4)
+        "rtf": round(rtf, 4),
+        **volume_map
     }
     message = json.dumps(info, ensure_ascii=False, indent=4)
 
@@ -239,8 +247,8 @@ def main():
         }
         for filename in (project_path / "trained_models").glob("*.zip")
         if filename.name not in (
-            "cnn-vad-by-webrtcvad-nx-dns3.zip",
-            "fsmn-vad-by-webrtcvad-nx-dns3.zip",
+            # "cnn-vad-by-webrtcvad-nx-dns3.zip",
+            # "fsmn-vad-by-webrtcvad-nx-dns3.zip",
             "examples.zip",
             "sound-2-ch32.zip",
             "sound-3-ch32.zip",
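The order of checks in get_infer_cls_by_model_name matters: "native_silero_vad" contains the substring "silero", so the old __contains__("silero") branch would have routed the new engine to InferenceSileroVad. Testing the most specific name first, and tightening the other branches to "fsmn-vad"/"silero-vad", avoids that. A standalone sketch with stub classes in place of the real inference classes:

class InferenceNativeSileroVadOnnx: ...
class InferenceFSMNVadOnnx: ...
class InferenceSileroVad: ...

def get_infer_cls_by_model_name(model_name: str):
    if "native_silero_vad" in model_name:
        return InferenceNativeSileroVadOnnx
    elif "fsmn-vad" in model_name:
        return InferenceFSMNVadOnnx
    elif "silero-vad" in model_name:
        return InferenceSileroVad
    raise AssertionError(model_name)

assert get_infer_cls_by_model_name("native_silero_vad") is InferenceNativeSileroVadOnnx
assert get_infer_cls_by_model_name("silero-vad-by-webrtcvad-nx2-dns3") is InferenceSileroVad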
toolbox/pydub/volume.py CHANGED
@@ -76,6 +76,45 @@ def set_volume(waveform: np.ndarray, sample_rate: int = 8000, volume: int = 0):
     return samples
 
 
+def get_volume(waveform: np.ndarray, sample_rate: int = 8000):
+    if np.min(waveform) < -1 or np.max(waveform) > 1:
+        raise AssertionError(f"waveform type: {type(waveform)}, dtype: {waveform.dtype}")
+    waveform = np.array(waveform * (1 << 15), dtype=np.int16)
+    raw_data = waveform.tobytes()
+
+    audio_segment = AudioSegment(
+        data=raw_data,
+        sample_width=2,
+        frame_rate=sample_rate,
+        channels=1
+    )
+
+    map_list = [
+        [0, -150],
+        [10, -40],
+        [50, -12],
+        [75, -6],
+        [100, 0],
+    ]
+    scores = [a for a, b in map_list]
+    stages = [b for a, b in map_list]
+
+    audio_dbfs = audio_segment.dBFS
+
+    # compute the target volume
+    volume = score_transform(
+        x=audio_dbfs,
+        stages=list(reversed(stages)),
+        scores=list(reversed(scores)),
+    )
+
+    result = {
+        "dbfs": audio_dbfs,
+        "volume": volume,
+    }
+    return result
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
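get_volume maps the measured dBFS onto a 0-100 volume score through the breakpoints in map_list. score_transform itself is defined elsewhere in the repo; assuming it performs piecewise-linear interpolation between those breakpoints, np.interp reproduces the idea:

import numpy as np

stages = [-150, -40, -12, -6, 0]   # dBFS breakpoints from map_list
scores = [0, 10, 50, 75, 100]      # volume scores at those breakpoints

def volume_from_dbfs(dbfs: float) -> float:
    return float(np.interp(dbfs, stages, scores))

print(volume_from_dbfs(-12.0))  # 50.0
print(volume_from_dbfs(-9.0))   # 62.5, halfway between the -12 and -6 points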
toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py CHANGED
@@ -139,8 +139,9 @@ class VadPaddingJsonlDataset(IterableDataset):
         speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
 
         # volume enhancement
-
-
+        if self.do_volume_enhancement:
+            volume = random.randint(10, 80)
+            speech_wave_np = set_volume(speech_wave_np, sample_rate=self.expected_sample_rate, volume=volume)
 
         noise_wave_list = list()
         for noise in noise_list:
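With the flag on, each speech clip receives a random target volume in [10, 80] before noise mixing. set_volume is the pydub-based helper from toolbox/pydub/volume.py; the stub below only rescales the waveform's peak so the sketch runs standalone, and is not the real dBFS mapping:

import random

import numpy as np

def set_volume_stub(waveform: np.ndarray, sample_rate: int, volume: int) -> np.ndarray:
    # hypothetical stand-in: scale the peak toward volume / 100
    peak = float(np.max(np.abs(waveform))) + 1e-8
    return waveform * (volume / 100.0) / peak

speech = np.random.uniform(-0.1, 0.1, size=8000).astype(np.float32)
volume = random.randint(10, 80)                 # same range as the dataset uses
speech = set_volume_stub(speech, sample_rate=8000, volume=volume)
print(float(np.max(np.abs(speech))))            # approximately volume / 100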
toolbox/torchaudio/models/vad/native_silero_vad/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/models/vad/native_silero_vad/inference_native_silero_vad_onnx.py ADDED
@@ -0,0 +1,198 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import logging
+from pathlib import Path
+import shutil
+import tempfile
+import zipfile
+
+from scipy.io import wavfile
+import numpy as np
+import torch
+import onnxruntime as ort
+from torch.nn import functional as F
+
+torch.set_num_threads(1)
+
+from project_settings import project_path
+from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
+from toolbox.torchaudio.configuration_utils import PretrainedConfig
+
+
+logger = logging.getLogger("toolbox")
+
+
+class NativeSileroVadConfig(PretrainedConfig):
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 win_size: int = 256,
+                 hop_size: int = 256,
+                 **kwargs
+                 ):
+        super(NativeSileroVadConfig, self).__init__(**kwargs)
+        # transform
+        self.sample_rate = sample_rate
+        self.win_size = win_size
+        self.hop_size = hop_size
+
+
+class InferenceNativeSileroVadOnnx(object):
+    """
+    code:
+    https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/utils_vad.py
+
+    model:
+    https://github.com/snakers4/silero-vad/tree/master/src/silero_vad/data
+    """
+    def __init__(self,
+                 pretrained_model_path_or_zip_file: str,
+                 device: str = "cpu"
+                 ):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, ort_session = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.ort_session = ort_session
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "cc_vad"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+
+        config = NativeSileroVadConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+
+        opts = ort.SessionOptions()
+        opts.inter_op_num_threads = 1
+        opts.intra_op_num_threads = 1
+
+        ort_session = ort.InferenceSession(
+            (model_path / "silero_vad.onnx").as_posix(),
+            sess_options=opts
+        )
+        shutil.rmtree(model_path)
+        return config, ort_session
+
+    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, n_samples = signal.shape
+        remainder = (n_samples - self.config.win_size) % self.config.hop_size
+        if remainder > 0:
+            n_samples_pad = self.config.hop_size - remainder
+            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
+        return signal
+
+    def forward_chunk(self, chunk: torch.Tensor, context: torch.Tensor, state: torch.Tensor):
+        # chunk shape: [1, chunk_size]
+        num_samples = 512 if self.config.sample_rate == 16000 else 256
+        if chunk.shape[-1] != num_samples:
+            raise ValueError(f"Provided number of samples is {chunk.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
+
+        context_size = 64 if self.config.sample_rate == 16000 else 32
+
+        chunk = torch.cat(tensors=[context, chunk], dim=1)
+        input_feed = {
+            "input": chunk.numpy(),
+            "state": state.numpy(),
+            "sr": np.array(self.config.sample_rate, dtype=np.int64)
+        }
+        ort_outs = self.ort_session.run(output_names=None, input_feed=input_feed)
+        vad_flag, state = ort_outs
+        # vad_flag shape: [b, 1]
+        # state shape: [2, b, 128]
+        vad_flag = torch.from_numpy(vad_flag)
+        state = torch.from_numpy(state)
+        context = chunk[..., -context_size:]
+        return vad_flag, context, state
+
+    def infer(self, signal: np.ndarray) -> np.ndarray:
+        # signal shape: [num_samples,], value between -1 and 1.
+        inputs = torch.tensor(signal, dtype=torch.float32)
+        inputs = torch.unsqueeze(inputs, dim=0)
+        # inputs shape: [1, num_samples]
+
+        n_samples = inputs.shape[-1]
+        inputs = self.signal_prepare(inputs)
+        # inputs shape: [1, 1, num_samples]
+        inputs = torch.squeeze(inputs, dim=1)
+        # inputs shape: [1, num_samples]
+        _, num_samples = inputs.shape
+
+        vad_flags = list()
+
+        context = torch.zeros(0)
+        state = torch.zeros(size=(2, 1, 128), dtype=torch.float32)
+        for i in range(0, num_samples, self.config.hop_size):
+            sub_inputs = inputs[:, i:i+self.config.win_size]
+            vad_flag, context, state = self.forward_chunk(sub_inputs, context, state)
+            vad_flags.append(vad_flag)
+
+        vad_flags = torch.cat(vad_flags, dim=1).cpu()
+        # vad_flags, torch.Tensor, shape: [b, num_chunks]
+        vad_flags = vad_flags.numpy()
+        # vad_flags, np.ndarray, shape: [b, num_chunks]
+        vad_flags = vad_flags[0]
+        # vad_flags shape: [num_chunk,]
+
+        result = {
+            "probs": vad_flags,
+            "lsnr": np.zeros_like(vad_flags),
+        }
+        return result
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wav_file",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",
+        default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
+        type=str,
+    )
+    args = parser.parse_args()
+    return args
+
+
+SAMPLE_RATE = 8000
+
+
+def main():
+    args = get_args()
+
+    sample_rate, signal = wavfile.read(args.wav_file)
+    if SAMPLE_RATE != sample_rate:
+        raise AssertionError
+    signal = signal / (1 << 15)
+
+    infer = InferenceNativeSileroVadOnnx(
+        pretrained_model_path_or_zip_file=(project_path / "trained_models/native_silero_vad.zip").as_posix(),
+    )
+
+    vad_info = infer.infer(signal)
+    speech_probs = vad_info["probs"]
+    # speech_probs, np.ndarray shape: [num_chunk,]
+
+    speech_probs = process_speech_probs(
+        signal=signal,
+        speech_probs=speech_probs,
+        frame_step=infer.config.hop_size,
+    )
+
+    # plot
+    make_visualization(signal, speech_probs, SAMPLE_RATE)
+    return
+
+
+if __name__ == "__main__":
+    main()
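InferenceNativeSileroVadOnnx streams the stock Silero model: at 8 kHz each 256-sample chunk is prefixed with the previous chunk's last 32 samples of context (64 samples and 512-sample chunks at 16 kHz), and the (2, 1, 128) LSTM state tensor is threaded through successive forward_chunk calls. A minimal usage sketch (the zip path mirrors the default in main() above and assumes your local checkout):

import numpy as np

from project_settings import project_path
from toolbox.torchaudio.models.vad.native_silero_vad.inference_native_silero_vad_onnx import (
    InferenceNativeSileroVadOnnx,
)

infer = InferenceNativeSileroVadOnnx(
    pretrained_model_path_or_zip_file=(project_path / "trained_models/native_silero_vad.zip").as_posix(),
)

signal = np.random.uniform(-0.3, 0.3, size=8000).astype(np.float32)  # 1 s at 8 kHz
vad_info = infer.infer(signal)
print(vad_info["probs"].shape)  # one speech probability per 256-sample hop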
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad_onnx.py CHANGED
@@ -109,9 +109,6 @@ class InferenceSileroVadOnnx(object):
         }
         return result
 
-    def post_process(self, probs: List[float]):
-        return
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -157,7 +154,7 @@ def main():
         raise AssertionError
     signal = signal / (1 << 15)
 
-    infer =
+    infer = InferenceSileroVadOnnx(
         # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
         pretrained_model_path_or_zip_file = (project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
     )