HoneyTian committed on
Commit
d9015be
·
1 Parent(s): ef3c782
examples/evaluation/step_1_run_evaluation.py CHANGED
@@ -26,7 +26,14 @@ def get_args():
26
  )
27
  parser.add_argument(
28
  "--output_file",
29
- default=r"evaluation.jsonl",
 
 
 
 
 
 
 
30
  type=str
31
  )
32
  parser.add_argument("--expected_sample_rate", default=8000, type=int)
@@ -110,8 +117,7 @@ def main():
110
  min_silence_length=6,
111
  max_speech_length=100000,
112
  min_speech_length=15,
113
- # engine="fsmn-vad-by-webrtcvad-nx2-dns3",
114
- engine="silero-vad-by-webrtcvad-nx2-dns3",
115
  api_name="/when_click_vad_button"
116
  )
117
  js = json.loads(message)
 
26
  )
27
  parser.add_argument(
28
  "--output_file",
29
+ default=r"native_silero_vad.jsonl",
30
+ type=str
31
+ )
32
+ parser.add_argument(
33
+ "--vad_engine",
34
+ # default="fsmn-vad-by-webrtcvad-nx2-dns3",
35
+ # default="silero-vad-by-webrtcvad-nx2-dns3",
36
+ default="native_silero_vad",
37
  type=str
38
  )
39
  parser.add_argument("--expected_sample_rate", default=8000, type=int)
 
117
  min_silence_length=6,
118
  max_speech_length=100000,
119
  min_speech_length=15,
120
+ engine=args.vad_engine,
 
121
  api_name="/when_click_vad_button"
122
  )
123
  js = json.loads(message)
examples/evaluation/step_2_show_metrics.py CHANGED
@@ -3,6 +3,7 @@
3
  import argparse
4
  import json
5
  import os
 
6
  import sys
7
 
8
  pwd = os.path.abspath(os.path.dirname(__file__))
@@ -16,53 +17,77 @@ def get_args():
16
 
17
  parser.add_argument(
18
  "--eval_file",
19
- default=r"evaluation.jsonl",
20
  type=str
21
  )
22
  args = parser.parse_args()
23
  return args
24
 
25
 
 
 
 
 
 
 
 
26
  def main():
27
- args = get_args()
28
-
29
- total = 0
30
- total_duration = 0
31
- total_accuracy = 0
32
- total_precision = 0
33
- total_recall = 0
34
- total_f1 = 0
35
- progress_bar = tqdm(desc="evaluation")
36
- with open(args.eval_file, "r", encoding="utf-8") as f:
37
- for row in f:
38
- row = json.loads(row)
39
- duration = row["duration"]
40
- accuracy = row["accuracy"]
41
- precision = row["precision"]
42
- recall = row["recall"]
43
- f1 = row["f1"]
44
-
45
- total += 1
46
- total_duration += duration
47
- total_accuracy += accuracy * duration
48
- total_precision += precision * duration
49
- total_recall += recall * duration
50
- total_f1 += f1 * duration
51
-
52
- average_accuracy = total_accuracy / total_duration
53
- average_precision = total_precision / total_duration
54
- average_recall = total_recall / total_duration
55
- average_f1 = total_f1 / total_duration
56
-
57
- progress_bar.update(1)
58
- progress_bar.set_postfix({
59
- "total": total,
60
- "accuracy": average_accuracy,
61
- "precision": average_precision,
62
- "recall": average_recall,
63
- "f1": average_f1,
64
- "total_duration": f"{round(total_duration / 60, 4)}min",
65
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  return
67
 
68
 
 
3
  import argparse
4
  import json
5
  import os
6
+ from pathlib import Path
7
  import sys
8
 
9
  pwd = os.path.abspath(os.path.dirname(__file__))
 
17
 
18
  parser.add_argument(
19
  "--eval_file",
20
+ # default=r"native_silero_vad.jsonl",
21
  type=str
22
  )
23
  args = parser.parse_args()
24
  return args
25
 
26
 
27
evaluation_files = [
    "native_silero_vad.jsonl",
    "fsmn-vad.jsonl",
    "silero-vad.jsonl"
]


def main():
    """Print duration-weighted average metrics for each evaluation jsonl file.

    Each line of an evaluation file is a JSON object with "duration",
    "accuracy", "precision", "recall" and "f1" fields; every metric is
    weighted by the sample's duration before averaging.
    """
    # args = get_args()

    for eval_file in evaluation_files:
        eval_file = Path(eval_file)
        total = 0
        total_duration = 0
        total_accuracy = 0
        total_precision = 0
        total_recall = 0
        total_f1 = 0

        average_accuracy = 0
        average_precision = 0
        average_recall = 0
        average_f1 = 0

        # progress_bar = tqdm(desc=eval_file.name)
        with open(eval_file.as_posix(), "r", encoding="utf-8") as f:
            for row in f:
                row = json.loads(row)
                duration = row["duration"]
                accuracy = row["accuracy"]
                precision = row["precision"]
                recall = row["recall"]
                f1 = row["f1"]

                total += 1
                total_duration += duration
                total_accuracy += accuracy * duration
                total_precision += precision * duration
                total_recall += recall * duration
                total_f1 += f1 * duration

                average_accuracy = total_accuracy / total_duration
                average_precision = total_precision / total_duration
                average_recall = total_recall / total_duration
                average_f1 = total_f1 / total_duration

                # progress_bar.update(1)
                # progress_bar.set_postfix({
                #     "total": total,
                #     "accuracy": average_accuracy,
                #     "precision": average_precision,
                #     "recall": average_recall,
                #     "f1": average_f1,
                #     "total_duration": f"{round(total_duration / 60, 4)}min",
                # })

        # NOTE: nesting an f-string with the same quote character inside an
        # f-string replacement field is a SyntaxError before Python 3.12
        # (PEP 701), so the duration string is formatted separately.
        total_duration_str = f"{round(total_duration / 60, 4)}min"
        summary = (f"{eval_file.name}, "
                   f"total: {total}, "
                   f"accuracy: {average_accuracy}, "
                   f"precision: {average_precision}, "
                   f"recall: {average_recall}, "
                   f"f1: {average_f1}, "
                   f"total_duration: {total_duration_str}, "
                   )
        print(summary)
    return
92
 
93
 
examples/evaluation/step_3_show_vad.py CHANGED
@@ -51,10 +51,17 @@ def show_image(signal: np.ndarray,
51
  plt.show()
52
 
53
 
 
 
 
 
 
 
 
54
  def main():
55
- args = get_args()
56
 
57
- with open(args.eval_file, "r", encoding="utf-8") as f:
58
  for row in f:
59
  row = json.loads(row)
60
  filename = row["filename"]
@@ -77,25 +84,16 @@ def main():
77
  begin = int(begin * sample_rate)
78
  end = int(end * sample_rate)
79
  ground_truth_probs[begin:end] = 1
 
80
  prediction_probs = np.zeros(shape=(signal_length,), dtype=np.float32)
81
  for begin, end in prediction:
82
  begin = int(begin * sample_rate)
83
  end = int(end * sample_rate)
84
  prediction_probs[begin:end] = 1
85
 
86
- # p = encoder_num_layers * (encoder_kernel_size - 1) // 2 * hop_size * sample_rate
87
- p = 3 * (3 - 1) // 2 * 80
88
- p = int(p)
89
- print(f"p: {p}")
90
- prediction_probs = np.concat(
91
- [
92
- prediction_probs[p:], prediction_probs[-p:]
93
- ],
94
- axis=-1
95
- )
96
-
97
  show_image(signal,
98
- ground_truth_probs, prediction_probs,
 
99
  sample_rate=sample_rate,
100
  )
101
  return
 
51
  plt.show()
52
 
53
 
54
+ evaluation_files = [
55
+ # "native_silero_vad.jsonl",
56
+ "fsmn-vad.jsonl",
57
+ "silero-vad.jsonl"
58
+ ]
59
+
60
+
61
  def main():
62
+ # args = get_args()
63
 
64
+ with open(evaluation_files[0], "r", encoding="utf-8") as f:
65
  for row in f:
66
  row = json.loads(row)
67
  filename = row["filename"]
 
84
  begin = int(begin * sample_rate)
85
  end = int(end * sample_rate)
86
  ground_truth_probs[begin:end] = 1
87
+
88
  prediction_probs = np.zeros(shape=(signal_length,), dtype=np.float32)
89
  for begin, end in prediction:
90
  begin = int(begin * sample_rate)
91
  end = int(end * sample_rate)
92
  prediction_probs[begin:end] = 1
93
 
 
 
 
 
 
 
 
 
 
 
 
94
  show_image(signal,
95
+ ground_truth_probs,
96
+ prediction_probs,
97
  sample_rate=sample_rate,
98
  )
99
  return
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -127,7 +127,7 @@ def main():
127
  max_wave_value=32768.0,
128
  min_snr_db=config.min_snr_db,
129
  max_snr_db=config.max_snr_db,
130
- do_volume_enhancement=True,
131
  # skip=225000,
132
  )
133
  valid_dataset = VadPaddingJsonlDataset(
 
127
  max_wave_value=32768.0,
128
  min_snr_db=config.min_snr_db,
129
  max_snr_db=config.max_snr_db,
130
+ do_volume_enhancement=False,
131
  # skip=225000,
132
  )
133
  valid_dataset = VadPaddingJsonlDataset(
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -4,7 +4,7 @@
4
 
5
  bash run.sh --stage 3 --stop_stage 5 --system_version centos \
6
  --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
7
- --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
8
  --noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
9
  --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
10
  /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
 
4
 
5
  bash run.sh --stage 3 --stop_stage 5 --system_version centos \
6
  --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
7
+ --final_model_name silero-vad-by-webrtcvad-nx2-dns3-20250813 \
8
  --noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
9
  --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
10
  /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
examples/silero_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -127,7 +127,7 @@ def main():
127
  max_wave_value=32768.0,
128
  min_snr_db=config.min_snr_db,
129
  max_snr_db=config.max_snr_db,
130
- do_volume_enhancement=True,
131
  # skip=225000,
132
  )
133
  valid_dataset = VadPaddingJsonlDataset(
@@ -271,7 +271,8 @@ def main():
271
  dice_loss = dice_loss_fn.forward(probs, targets)
272
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
273
 
274
- loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
 
275
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
276
  logger.info(f"find nan or inf in loss. continue.")
277
  continue
@@ -352,7 +353,8 @@ def main():
352
  dice_loss = dice_loss_fn.forward(probs, targets)
353
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
354
 
355
- loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
 
356
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
357
  logger.info(f"find nan or inf in loss. continue.")
358
  continue
 
127
  max_wave_value=32768.0,
128
  min_snr_db=config.min_snr_db,
129
  max_snr_db=config.max_snr_db,
130
+ do_volume_enhancement=False,
131
  # skip=225000,
132
  )
133
  valid_dataset = VadPaddingJsonlDataset(
 
271
  dice_loss = dice_loss_fn.forward(probs, targets)
272
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
273
 
274
+ # loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
275
+ loss = 1.0 * bce_loss + 1.0 * dice_loss + 1.0 * lsnr_loss
276
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
277
  logger.info(f"find nan or inf in loss. continue.")
278
  continue
 
353
  dice_loss = dice_loss_fn.forward(probs, targets)
354
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
355
 
356
+ # loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
357
+ loss = 1.0 * bce_loss + 1.0 * dice_loss + 1.0 * lsnr_loss
358
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
359
  logger.info(f"find nan or inf in loss. continue.")
360
  continue
examples/silero_vad_by_webrtcvad/step_5_export_model.py CHANGED
@@ -94,7 +94,7 @@ def main():
94
  "new_lstm_hidden_state": {2: "batch_size"},
95
  })
96
 
97
- ort_session = ort.InferenceSession("silero_vad.onnx")
98
  input_feed = {
99
  "inputs": inputs.numpy(),
100
  "encoder_in_cache": encoder_in_cache.numpy(),
 
94
  "new_lstm_hidden_state": {2: "batch_size"},
95
  })
96
 
97
+ ort_session = ort.InferenceSession("model.onnx")
98
  input_feed = {
99
  "inputs": inputs.numpy(),
100
  "encoder_in_cache": encoder_in_cache.numpy(),
log.py CHANGED
@@ -15,8 +15,43 @@ def get_converter(tz_info: str = "Asia/Shanghai"):
15
  return converter
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
19
- fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
20
 
21
  formatter = logging.Formatter(
22
  fmt=fmt,
@@ -38,11 +73,12 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
38
  backupCount=2,
39
  )
40
  main_info_file_handler.setLevel(logging.INFO)
41
- main_info_file_handler.setFormatter(logging.Formatter(fmt))
42
  main_logger.addHandler(main_info_file_handler)
43
 
44
  # http
45
  http_logger = logging.getLogger("http")
 
46
  http_file_handler = RotatingFileHandler(
47
  filename=os.path.join(log_directory, "http.log"),
48
  maxBytes=100*1024*1024, # 100MB
@@ -50,11 +86,12 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
50
  backupCount=2,
51
  )
52
  http_file_handler.setLevel(logging.DEBUG)
53
- http_file_handler.setFormatter(logging.Formatter(fmt))
54
  http_logger.addHandler(http_file_handler)
55
 
56
  # api
57
  api_logger = logging.getLogger("api")
 
58
  api_file_handler = RotatingFileHandler(
59
  filename=os.path.join(log_directory, "api.log"),
60
  maxBytes=10*1024*1024, # 10MB
@@ -62,7 +99,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
62
  backupCount=2,
63
  )
64
  api_file_handler.setLevel(logging.DEBUG)
65
- api_file_handler.setFormatter(logging.Formatter(fmt))
66
  api_logger.addHandler(api_file_handler)
67
 
68
  # alarm
@@ -74,7 +111,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
74
  backupCount=2,
75
  )
76
  alarm_file_handler.setLevel(logging.DEBUG)
77
- alarm_file_handler.setFormatter(logging.Formatter(fmt))
78
  alarm_logger.addHandler(alarm_file_handler)
79
 
80
  debug_file_handler = RotatingFileHandler(
@@ -84,7 +121,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
84
  backupCount=2,
85
  )
86
  debug_file_handler.setLevel(logging.DEBUG)
87
- debug_file_handler.setFormatter(logging.Formatter(fmt))
88
 
89
  info_file_handler = RotatingFileHandler(
90
  filename=os.path.join(log_directory, "info.log"),
@@ -93,7 +130,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
93
  backupCount=2,
94
  )
95
  info_file_handler.setLevel(logging.INFO)
96
- info_file_handler.setFormatter(logging.Formatter(fmt))
97
 
98
  error_file_handler = RotatingFileHandler(
99
  filename=os.path.join(log_directory, "error.log"),
@@ -102,7 +139,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
102
  backupCount=2,
103
  )
104
  error_file_handler.setLevel(logging.ERROR)
105
- error_file_handler.setFormatter(logging.Formatter(fmt))
106
 
107
  logging.basicConfig(
108
  level=logging.DEBUG,
 
15
  return converter
16
 
17
 
18
def setup_stream(tz_info: str = "Asia/Shanghai"):
    """Attach a single console handler to the 'main', 'http' and 'api' loggers.

    Log records are pipe-delimited; timestamps are rendered in the *tz_info*
    time zone via get_converter.
    """
    log_format = "%(asctime)s|%(name)s|%(levelname)s|%(filename)s|%(lineno)d|%(message)s"

    stream_formatter = logging.Formatter(
        fmt=log_format,
        datefmt="%Y-%m-%d %H:%M:%S %z"
    )
    stream_formatter.converter = get_converter(tz_info)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(stream_formatter)

    # one shared handler for every console-visible logger
    for logger_name in ("main", "http", "api"):
        logging.getLogger(logger_name).addHandler(console_handler)

    # root logger: DEBUG level, deliberately no handlers of its own
    logging.basicConfig(
        level=logging.DEBUG,
        datefmt="%a, %d %b %Y %H:%M:%S",
        handlers=[
        ]
    )
    return
51
+
52
+
53
  def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
54
+ fmt = "%(asctime)s|%(name)s|%(levelname)s|%(filename)s|%(lineno)d|%(message)s"
55
 
56
  formatter = logging.Formatter(
57
  fmt=fmt,
 
73
  backupCount=2,
74
  )
75
  main_info_file_handler.setLevel(logging.INFO)
76
+ main_info_file_handler.setFormatter(formatter)
77
  main_logger.addHandler(main_info_file_handler)
78
 
79
  # http
80
  http_logger = logging.getLogger("http")
81
+ http_logger.addHandler(stream_handler)
82
  http_file_handler = RotatingFileHandler(
83
  filename=os.path.join(log_directory, "http.log"),
84
  maxBytes=100*1024*1024, # 100MB
 
86
  backupCount=2,
87
  )
88
  http_file_handler.setLevel(logging.DEBUG)
89
+ http_file_handler.setFormatter(formatter)
90
  http_logger.addHandler(http_file_handler)
91
 
92
  # api
93
  api_logger = logging.getLogger("api")
94
+ api_logger.addHandler(stream_handler)
95
  api_file_handler = RotatingFileHandler(
96
  filename=os.path.join(log_directory, "api.log"),
97
  maxBytes=10*1024*1024, # 10MB
 
99
  backupCount=2,
100
  )
101
  api_file_handler.setLevel(logging.DEBUG)
102
+ api_file_handler.setFormatter(formatter)
103
  api_logger.addHandler(api_file_handler)
104
 
105
  # alarm
 
111
  backupCount=2,
112
  )
113
  alarm_file_handler.setLevel(logging.DEBUG)
114
+ alarm_file_handler.setFormatter(formatter)
115
  alarm_logger.addHandler(alarm_file_handler)
116
 
117
  debug_file_handler = RotatingFileHandler(
 
121
  backupCount=2,
122
  )
123
  debug_file_handler.setLevel(logging.DEBUG)
124
+ debug_file_handler.setFormatter(formatter)
125
 
126
  info_file_handler = RotatingFileHandler(
127
  filename=os.path.join(log_directory, "info.log"),
 
130
  backupCount=2,
131
  )
132
  info_file_handler.setLevel(logging.INFO)
133
+ info_file_handler.setFormatter(formatter)
134
 
135
  error_file_handler = RotatingFileHandler(
136
  filename=os.path.join(log_directory, "error.log"),
 
139
  backupCount=2,
140
  )
141
  error_file_handler.setLevel(logging.ERROR)
142
+ error_file_handler.setFormatter(formatter)
143
 
144
  logging.basicConfig(
145
  level=logging.DEBUG,
main.py CHANGED
@@ -25,8 +25,10 @@ from project_settings import environment, project_path, log_directory, time_zone
25
  from toolbox.os.command import Command
26
  from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
27
  from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
 
28
  from toolbox.torchaudio.utils.visualization import process_speech_probs
29
  from toolbox.vad.utils import PostProcess
 
30
 
31
  log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
32
 
@@ -93,9 +95,11 @@ def shell(cmd: str):
93
 
94
 
95
  def get_infer_cls_by_model_name(model_name: str):
96
- if model_name.__contains__("fsmn"):
 
 
97
  infer_cls = InferenceFSMNVadOnnx
98
- elif model_name.__contains__("silero"):
99
  infer_cls = InferenceSileroVad
100
  else:
101
  raise AssertionError
@@ -158,8 +162,8 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
158
  vad_info = infer_engine.infer(audio)
159
  time_cost = time.time() - begin
160
 
161
- probs = vad_info["probs"]
162
- lsnr = vad_info["lsnr"]
163
  # lsnr = lsnr / np.max(np.abs(lsnr))
164
  lsnr = lsnr / 30
165
 
@@ -197,13 +201,17 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
197
  ] for v in vad_segments
198
  ]
199
 
 
 
 
200
  # message
201
  rtf = time_cost / audio_duration
202
  info = {
203
  "vad_segments": vad_segments,
204
  "time_cost": round(time_cost, 4),
205
  "duration": round(audio_duration, 4),
206
- "rtf": round(rtf, 4)
 
207
  }
208
  message = json.dumps(info, ensure_ascii=False, indent=4)
209
 
@@ -239,8 +247,8 @@ def main():
239
  }
240
  for filename in (project_path / "trained_models").glob("*.zip")
241
  if filename.name not in (
242
- "cnn-vad-by-webrtcvad-nx-dns3.zip",
243
- "fsmn-vad-by-webrtcvad-nx-dns3.zip",
244
  "examples.zip",
245
  "sound-2-ch32.zip",
246
  "sound-3-ch32.zip",
 
25
  from toolbox.os.command import Command
26
  from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
27
  from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
28
+ from toolbox.torchaudio.models.vad.native_silero_vad.inference_native_silero_vad_onnx import InferenceNativeSileroVadOnnx
29
  from toolbox.torchaudio.utils.visualization import process_speech_probs
30
  from toolbox.vad.utils import PostProcess
31
+ from toolbox.pydub.volume import get_volume
32
 
33
  log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
34
 
 
95
 
96
 
97
  def get_infer_cls_by_model_name(model_name: str):
98
+ if model_name.__contains__("native_silero_vad"):
99
+ infer_cls = InferenceNativeSileroVadOnnx
100
+ elif model_name.__contains__("fsmn-vad"):
101
  infer_cls = InferenceFSMNVadOnnx
102
+ elif model_name.__contains__("silero-vad"):
103
  infer_cls = InferenceSileroVad
104
  else:
105
  raise AssertionError
 
162
  vad_info = infer_engine.infer(audio)
163
  time_cost = time.time() - begin
164
 
165
+ probs: np.ndarray = vad_info["probs"]
166
+ lsnr: np.ndarray = vad_info["lsnr"]
167
  # lsnr = lsnr / np.max(np.abs(lsnr))
168
  lsnr = lsnr / 30
169
 
 
201
  ] for v in vad_segments
202
  ]
203
 
204
+ # volume
205
+ volume_map: dict = get_volume(audio, sample_rate)
206
+
207
  # message
208
  rtf = time_cost / audio_duration
209
  info = {
210
  "vad_segments": vad_segments,
211
  "time_cost": round(time_cost, 4),
212
  "duration": round(audio_duration, 4),
213
+ "rtf": round(rtf, 4),
214
+ **volume_map
215
  }
216
  message = json.dumps(info, ensure_ascii=False, indent=4)
217
 
 
247
  }
248
  for filename in (project_path / "trained_models").glob("*.zip")
249
  if filename.name not in (
250
+ # "cnn-vad-by-webrtcvad-nx-dns3.zip",
251
+ # "fsmn-vad-by-webrtcvad-nx-dns3.zip",
252
  "examples.zip",
253
  "sound-2-ch32.zip",
254
  "sound-3-ch32.zip",
toolbox/pydub/volume.py CHANGED
@@ -76,6 +76,45 @@ def set_volume(waveform: np.ndarray, sample_rate: int = 8000, volume: int = 0):
76
  return samples
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def get_args():
80
  parser = argparse.ArgumentParser()
81
  parser.add_argument(
 
76
  return samples
77
 
78
 
79
def get_volume(waveform: np.ndarray, sample_rate: int = 8000):
    """Estimate a 0-100 volume score and the dBFS level of a float waveform.

    :param waveform: mono float samples, values in [-1, 1].
    :param sample_rate: sampling rate in Hz.
    :return: dict with keys "dbfs" (pydub dBFS of the signal) and
             "volume" (score interpolated from dBFS via score_transform).
    :raises AssertionError: if any sample lies outside [-1, 1].
    """
    if np.min(waveform) < -1 or np.max(waveform) > 1:
        raise AssertionError(f"waveform type: {type(waveform)}, dtype: {waveform.dtype}")
    # Clip to the int16 range: a sample of exactly +1.0 would otherwise
    # produce 1 << 15 == 32768, which overflows int16 and wraps to -32768.
    waveform = np.clip(waveform * (1 << 15), -32768, 32767).astype(np.int16)
    raw_data = waveform.tobytes()

    audio_segment = AudioSegment(
        data=raw_data,
        sample_width=2,
        frame_rate=sample_rate,
        channels=1
    )

    # piecewise mapping: score (left column) <-> dBFS (right column)
    map_list = [
        [0, -150],
        [10, -40],
        [50, -12],
        [75, -6],
        [100, 0],
    ]
    scores = [a for a, b in map_list]
    stages = [b for a, b in map_list]

    audio_dbfs = audio_segment.dBFS

    # compute the target volume score from the measured dBFS
    volume = score_transform(
        x=audio_dbfs,
        stages=list(reversed(stages)),
        scores=list(reversed(scores)),
    )

    result = {
        "dbfs": audio_dbfs,
        "volume": volume,
    }
    return result
116
+
117
+
118
  def get_args():
119
  parser = argparse.ArgumentParser()
120
  parser.add_argument(
toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py CHANGED
@@ -139,8 +139,9 @@ class VadPaddingJsonlDataset(IterableDataset):
139
  speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
140
 
141
  # volume enhancement
142
- volume = random.randint(10, 80)
143
- speech_wave_np = set_volume(speech_wave_np, sample_rate=self.expected_sample_rate, volume=volume)
 
144
 
145
  noise_wave_list = list()
146
  for noise in noise_list:
 
139
  speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
140
 
141
  # volume enhancement
142
+ if self.do_volume_enhancement:
143
+ volume = random.randint(10, 80)
144
+ speech_wave_np = set_volume(speech_wave_np, sample_rate=self.expected_sample_rate, volume=volume)
145
 
146
  noise_wave_list = list()
147
  for noise in noise_list:
toolbox/torchaudio/models/vad/native_silero_vad/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/vad/native_silero_vad/inference_native_silero_vad_onnx.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+ from pathlib import Path
6
+ import shutil
7
+ import tempfile
8
+ import zipfile
9
+
10
+ from scipy.io import wavfile
11
+ import numpy as np
12
+ import torch
13
+ import onnxruntime as ort
14
+ from torch.nn import functional as F
15
+
16
+ torch.set_num_threads(1)
17
+
18
+ from project_settings import project_path
19
+ from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
20
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
21
+
22
+
23
+ logger = logging.getLogger("toolbox")
24
+
25
+
26
class NativeSileroVadConfig(PretrainedConfig):
    """Configuration for the native Silero VAD ONNX inference wrapper."""

    def __init__(self,
                 sample_rate: int = 8000,
                 win_size: int = 256,
                 hop_size: int = 256,
                 **kwargs
                 ):
        super(NativeSileroVadConfig, self).__init__(**kwargs)
        # signal transform / chunking parameters
        self.sample_rate = sample_rate
        self.win_size = win_size
        self.hop_size = hop_size
+
39
+
40
class InferenceNativeSileroVadOnnx(object):
    """
    ONNX inference wrapper around the official ("native") Silero VAD model.

    Streams fixed-size chunks (256 samples @ 8 kHz, 512 @ 16 kHz) through the
    model, threading the recurrent state and a short sample context between
    calls, mirroring the upstream streaming wrapper.

    code:
    https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/utils_vad.py

    model:
    https://github.com/snakers4/silero-vad/tree/master/src/silero_vad/data
    """
    def __init__(self,
                 pretrained_model_path_or_zip_file: str,
                 device: str = "cpu"
                 ):
        """
        :param pretrained_model_path_or_zip_file: model directory, or a .zip archive of one.
        :param device: torch device name; kept for interface symmetry
                       (inference itself runs through onnxruntime on CPU).
        """
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, ort_session = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.ort_session = ort_session

    def load_models(self, model_path: str):
        """Load the config and ONNX session from a model directory or .zip archive.

        :param model_path: path to a model directory or zip file.
        :return: (config, ort_session) tuple.
        """
        model_path = Path(model_path)
        extracted = False
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "cc_vad"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem
            extracted = True

        config = NativeSileroVadConfig.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )

        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        ort_session = ort.InferenceSession(
            (model_path / "silero_vad.onnx").as_posix(),
            sess_options=opts
        )
        # Only remove files this method extracted itself; the previous
        # unconditional rmtree would delete a user-supplied model directory.
        if extracted:
            shutil.rmtree(model_path)
        return config, ort_session

    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
        """Zero-pad the signal tail so hop_size chunking covers every sample."""
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape
        remainder = (n_samples - self.config.win_size) % self.config.hop_size
        if remainder > 0:
            n_samples_pad = self.config.hop_size - remainder
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal

    def forward_chunk(self, chunk: torch.Tensor, context: torch.Tensor, state: torch.Tensor):
        """Run one chunk through the ONNX model.

        :param chunk: [1, chunk_size] samples.
        :param context: [1, context_size] tail of the previous chunk
                        (an empty tensor on the first call).
        :param state: [2, 1, 128] recurrent state.
        :return: (vad_flag [b, 1], new context, new state).
        """
        # chunk shape: [1, chunk_size]
        num_samples = 512 if self.config.sample_rate == 16000 else 256
        if chunk.shape[-1] != num_samples:
            raise ValueError(f"Provided number of samples is {chunk.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

        context_size = 64 if self.config.sample_rate == 16000 else 32
        if context.numel() == 0:
            # First chunk: the upstream silero-vad wrapper zero-pads the
            # initial context so the model always receives
            # num_samples + context_size samples.
            context = torch.zeros(size=(chunk.shape[0], context_size), dtype=chunk.dtype)

        chunk = torch.cat(tensors=[context, chunk], dim=1)
        input_feed = {
            "input": chunk.numpy(),
            "state": state.numpy(),
            "sr": np.array(self.config.sample_rate, dtype=np.int64)
        }
        ort_outs = self.ort_session.run(output_names=None, input_feed=input_feed)
        vad_flag, state = ort_outs
        # vad_flag shape: [b, 1]
        # state shape: [2, b, 128]
        vad_flag = torch.from_numpy(vad_flag)
        state = torch.from_numpy(state)
        context = chunk[..., -context_size:]
        return vad_flag, context, state

    def infer(self, signal: np.ndarray) -> np.ndarray:
        """Run streaming VAD over a whole signal.

        :param signal: [num_samples,] float samples, values in [-1, 1].
        :return: dict with "probs" (per-chunk speech probabilities, [num_chunk,])
                 and "lsnr" (zeros; kept for interface parity with other engines).
        """
        inputs = torch.tensor(signal, dtype=torch.float32)
        inputs = torch.unsqueeze(inputs, dim=0)
        # inputs shape: [1, num_samples]

        inputs = self.signal_prepare(inputs)
        # inputs shape: [1, 1, num_samples]
        inputs = torch.squeeze(inputs, dim=1)
        # inputs shape: [1, num_samples]
        _, num_samples = inputs.shape

        vad_flags = list()

        context = torch.zeros(0)
        state = torch.zeros(size=(2, 1, 128), dtype=torch.float32)
        for i in range(0, num_samples, self.config.hop_size):
            sub_inputs = inputs[:, i:i+self.config.win_size]
            vad_flag, context, state = self.forward_chunk(sub_inputs, context, state)
            vad_flags.append(vad_flag)

        vad_flags = torch.cat(vad_flags, dim=1).cpu()
        # vad_flags, torch.Tensor, shape: [b, num_chunks]
        vad_flags = vad_flags.numpy()
        # vad_flags, np.ndarray, shape: [b, num_chunks]
        vad_flags = vad_flags[0]
        # vad_flags shape: [num_chunk,]

        result = {
            "probs": vad_flags,
            "lsnr": np.zeros_like(vad_flags),
        }
        return result
153
+
154
+
155
+ def get_args():
156
+ parser = argparse.ArgumentParser()
157
+ parser.add_argument(
158
+ "--wav_file",
159
+ # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",
160
+ default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
161
+ type=str,
162
+ )
163
+ args = parser.parse_args()
164
+ return args
165
+
166
+
167
+ SAMPLE_RATE = 8000
168
+
169
+
170
+ def main():
171
+ args = get_args()
172
+
173
+ sample_rate, signal = wavfile.read(args.wav_file)
174
+ if SAMPLE_RATE != sample_rate:
175
+ raise AssertionError
176
+ signal = signal / (1 << 15)
177
+
178
+ infer = InferenceNativeSileroVadOnnx(
179
+ pretrained_model_path_or_zip_file=(project_path / "trained_models/native_silero_vad.zip").as_posix(),
180
+ )
181
+
182
+ vad_info = infer.infer(signal)
183
+ speech_probs = vad_info["probs"]
184
+ # speech_probs, np.ndarray shape: [num_chunk,]
185
+
186
+ speech_probs = process_speech_probs(
187
+ signal=signal,
188
+ speech_probs=speech_probs,
189
+ frame_step=infer.config.hop_size,
190
+ )
191
+
192
+ # plot
193
+ make_visualization(signal, speech_probs, SAMPLE_RATE)
194
+ return
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad_onnx.py CHANGED
@@ -109,9 +109,6 @@ class InferenceSileroVadOnnx(object):
109
  }
110
  return result
111
 
112
- def post_process(self, probs: List[float]):
113
- return
114
-
115
 
116
  def get_args():
117
  parser = argparse.ArgumentParser()
@@ -157,7 +154,7 @@ def main():
157
  raise AssertionError
158
  signal = signal / (1 << 15)
159
 
160
- infer = InferenceFSMNVadOnnx(
161
  # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
162
  pretrained_model_path_or_zip_file = (project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
163
  )
 
109
  }
110
  return result
111
 
 
 
 
112
 
113
  def get_args():
114
  parser = argparse.ArgumentParser()
 
154
  raise AssertionError
155
  signal = signal / (1 << 15)
156
 
157
+ infer = InferenceSileroVadOnnx(
158
  # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
159
  pretrained_model_path_or_zip_file = (project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
160
  )