update
Browse files
.gitignore
CHANGED
|
@@ -23,3 +23,4 @@
|
|
| 23 |
**/*.wav
|
| 24 |
**/*.xlsx
|
| 25 |
**/*.jsonl
|
|
|
|
|
|
| 23 |
**/*.wav
|
| 24 |
**/*.xlsx
|
| 25 |
**/*.jsonl
|
| 26 |
+
**/*.onnx
|
examples/silero_vad_by_webrtcvad/run.sh
CHANGED
|
@@ -126,13 +126,11 @@ fi
|
|
| 126 |
|
| 127 |
|
| 128 |
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
| 129 |
-
$verbose && echo "stage 4:
|
| 130 |
cd "${work_dir}" || exit 1
|
| 131 |
-
python3
|
| 132 |
-
--valid_dataset "${valid_dataset}" \
|
| 133 |
--model_dir "${file_dir}/best" \
|
| 134 |
-
--
|
| 135 |
-
--limit "${limit}" \
|
| 136 |
|
| 137 |
fi
|
| 138 |
|
|
@@ -144,7 +142,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|
| 144 |
mkdir -p ${final_model_dir}
|
| 145 |
|
| 146 |
cp "${file_dir}/best"/* "${final_model_dir}"
|
| 147 |
-
cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
|
| 148 |
|
| 149 |
cd "${final_model_dir}/.." || exit 1;
|
| 150 |
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
# Stage 4: export the best checkpoint to ONNX.
# The exported file is written next to the checkpoint (same directory is used
# for --model_dir and --output_dir).
if [ "${stage}" -le 4 ] && [ "${stop_stage}" -ge 4 ]; then
  $verbose && echo "stage 4: export model"
  cd "${work_dir}" || exit 1
  # No trailing backslash on the last argument: a dangling continuation would
  # silently swallow whatever line is placed directly after it.
  python3 step_5_export_model.py \
  --model_dir "${file_dir}/best" \
  --output_dir "${file_dir}/best"

fi
|
| 136 |
|
|
|
|
| 142 |
mkdir -p ${final_model_dir}
|
| 143 |
|
| 144 |
cp "${file_dir}/best"/* "${final_model_dir}"
|
|
|
|
| 145 |
|
| 146 |
cd "${final_model_dir}/.." || exit 1;
|
| 147 |
|
examples/silero_vad_by_webrtcvad/step_5_export_model.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 9 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
| 10 |
+
|
| 11 |
+
import onnxruntime as ort
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadModel, SileroVadModelExport, SileroVadPretrainedModel
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_args():
    """
    Parse the command line options of the export script.

    Options:
        --model_dir:  directory handed to ``from_pretrained`` by the caller.
        --output_dir: directory the caller derives the ONNX output path from.

    :return: argparse.Namespace with ``model_dir`` and ``output_dir`` (both str).
    """
    # NOTE(review): both options fall back to the same absolute path on one
    # developer's Windows drive; on any other machine the options have to be
    # passed explicitly (run.sh does so).
    fallback_dir = r"D:\Users\tianx\HuggingSpaces\cc_vad\trained_models\fsmn-vad-by-webrtcvad-nx2-dns3\fsmn-vad-by-webrtcvad-nx2-dns3"

    parser = argparse.ArgumentParser()
    for option in ("--model_dir", "--output_dir"):
        parser.add_argument(option, default=fallback_dir, type=str)

    return parser.parse_args()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main():
    """
    Export the trained model to ONNX and smoke-test the exported file.

    Steps:
        1. load the checkpoint from ``--model_dir`` and wrap it in ``SileroVadModelExport``;
        2. build one dummy signal plus zero-initialised streaming state;
        3. run the wrapper once eagerly (fails early if the shapes are inconsistent);
        4. export to ``<output_dir>/model.onnx``;
        5. load the exported file with onnxruntime and run it once on the same inputs.

    The outputs of steps 3 and 5 are not compared; both calls only prove that
    the graph executes.
    """
    args = get_args()

    output_dir = Path(args.output_dir)
    # The export must land in --output_dir so later pipeline stages that copy
    # that directory pick the ONNX file up.
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "model.onnx"

    model = SileroVadPretrainedModel.from_pretrained(args.model_dir)
    model.eval()
    config = model.config

    model_export = SileroVadModelExport(model)

    # The caches hold 2*p frames, where p is derived from the encoder kernel size.
    p = (config.encoder_kernel_size - 1) // 2

    b = 1
    inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)

    # encoder_in_cache shape: [b, 2p, encoder_in_channels]
    encoder_in_cache = torch.zeros(
        size=(b, 2 * p, config.encoder_in_channels), dtype=torch.float32
    )
    # encoder_hidden_cache_list shape: [encoder_num_layers, b, 2p, encoder_hidden_channels]
    encoder_hidden_cache_list = torch.zeros(
        size=(config.encoder_num_layers, b, 2 * p, config.encoder_hidden_channels),
        dtype=torch.float32,
    )
    # lstm_hidden_state shape: [2, decoder_num_layers, b, decoder_hidden_size]
    lstm_hidden_state = torch.zeros(
        size=(2, config.decoder_num_layers, b, config.decoder_hidden_size),
        dtype=torch.float32,
    )

    input_names = ["inputs", "encoder_in_cache", "encoder_hidden_cache_list", "lstm_hidden_state"]
    output_names = [
        "logits", "probs", "lsnr",
        "new_encoder_in_cache",
        "new_encoder_hidden_cache_list",
        "new_lstm_hidden_state"
    ]
    example_inputs = (inputs, encoder_in_cache, encoder_hidden_cache_list, lstm_hidden_state)

    # Eager smoke test before tracing.
    with torch.no_grad():
        model_export.forward(*example_inputs)

    torch.onnx.export(model_export,
                      args=example_inputs,
                      f=str(output_file),
                      input_names=input_names,
                      output_names=output_names,
                      dynamic_axes={
                          "inputs": {0: "batch_size", 2: "num_samples"},
                          # batch is axis 0 of the [b, 2p, f] input cache
                          "encoder_in_cache": {0: "batch_size"},
                          "encoder_hidden_cache_list": {1: "batch_size"},
                          "lstm_hidden_state": {2: "batch_size"},
                          "logits": {0: "batch_size"},
                          "probs": {0: "batch_size"},
                          "lsnr": {0: "batch_size"},
                          "new_encoder_in_cache": {0: "batch_size"},
                          "new_encoder_hidden_cache_list": {1: "batch_size"},
                          "new_lstm_hidden_state": {2: "batch_size"},
                      })

    # Reload the file that was just written and run it once on the same inputs.
    ort_session = ort.InferenceSession(str(output_file))
    input_feed = {
        name: tensor.numpy()
        for name, tensor in zip(input_names, example_inputs)
    }
    ort_session.run(output_names, input_feed)
    return
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
main()
|
toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py
CHANGED
|
@@ -14,6 +14,8 @@ class SileroVadConfig(PretrainedConfig):
|
|
| 14 |
win_type: str = "hann",
|
| 15 |
|
| 16 |
encoder_in_channels: int = 64,
|
|
|
|
|
|
|
| 17 |
encoder_kernel_size: int = 3,
|
| 18 |
encoder_num_layers: int = 3,
|
| 19 |
|
|
@@ -52,6 +54,8 @@ class SileroVadConfig(PretrainedConfig):
|
|
| 52 |
|
| 53 |
# encoder
|
| 54 |
self.encoder_in_channels = encoder_in_channels
|
|
|
|
|
|
|
| 55 |
self.encoder_kernel_size = encoder_kernel_size
|
| 56 |
self.encoder_num_layers = encoder_num_layers
|
| 57 |
|
|
|
|
| 14 |
win_type: str = "hann",
|
| 15 |
|
| 16 |
encoder_in_channels: int = 64,
|
| 17 |
+
encoder_hidden_channels: int = 128,
|
| 18 |
+
encoder_out_channels: int = 64,
|
| 19 |
encoder_kernel_size: int = 3,
|
| 20 |
encoder_num_layers: int = 3,
|
| 21 |
|
|
|
|
| 54 |
|
| 55 |
# encoder
|
| 56 |
self.encoder_in_channels = encoder_in_channels
|
| 57 |
+
self.encoder_hidden_channels = encoder_hidden_channels
|
| 58 |
+
self.encoder_out_channels = encoder_out_channels
|
| 59 |
self.encoder_kernel_size = encoder_kernel_size
|
| 60 |
self.encoder_num_layers = encoder_num_layers
|
| 61 |
|
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py
CHANGED
|
@@ -62,6 +62,7 @@ class Encoder(nn.Module):
|
|
| 62 |
num_layers: int = 3,
|
| 63 |
):
|
| 64 |
super(Encoder, self).__init__()
|
|
|
|
| 65 |
|
| 66 |
self.layers = nn.ModuleList(modules=[])
|
| 67 |
for i in range(num_layers):
|
|
@@ -96,23 +97,33 @@ class EncoderExport(nn.Module):
|
|
| 96 |
def __init__(self, model: Encoder):
|
| 97 |
super(EncoderExport, self).__init__()
|
| 98 |
self.layers = model.layers
|
|
|
|
| 99 |
|
| 100 |
-
def forward(self, x: torch.Tensor,
|
| 101 |
# x shape: [b, t, f]
|
| 102 |
-
#
|
|
|
|
| 103 |
|
| 104 |
-
|
|
|
|
| 105 |
for idx, layer in enumerate(self.layers):
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
x_pad = torch.concat(tensors=[cache, x], dim=1)
|
| 108 |
x = layer.forward(x_pad)
|
| 109 |
|
| 110 |
_, twop, _ = cache.shape
|
| 111 |
new_cache = x_pad[:, -twop:, :]
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
return x,
|
| 116 |
|
| 117 |
|
| 118 |
class SileroVadModel(nn.Module):
|
|
@@ -123,6 +134,8 @@ class SileroVadModel(nn.Module):
|
|
| 123 |
hop_size: int,
|
| 124 |
win_type: int,
|
| 125 |
encoder_in_channels: int,
|
|
|
|
|
|
|
| 126 |
encoder_kernel_size: int,
|
| 127 |
encoder_num_layers: int,
|
| 128 |
decoder_hidden_size: int,
|
|
@@ -139,6 +152,8 @@ class SileroVadModel(nn.Module):
|
|
| 139 |
self.win_type = win_type
|
| 140 |
|
| 141 |
self.encoder_in_channels = encoder_in_channels
|
|
|
|
|
|
|
| 142 |
self.encoder_kernel_size = encoder_kernel_size
|
| 143 |
self.encoder_num_layers = encoder_num_layers
|
| 144 |
|
|
@@ -180,8 +195,8 @@ class SileroVadModel(nn.Module):
|
|
| 180 |
|
| 181 |
self.encoder = Encoder(
|
| 182 |
in_channels=self.encoder_in_channels,
|
| 183 |
-
hidden_channels=self.
|
| 184 |
-
out_channels=self.
|
| 185 |
kernel_size=self.encoder_kernel_size,
|
| 186 |
num_layers=self.encoder_num_layers,
|
| 187 |
)
|
|
@@ -298,6 +313,8 @@ class SileroVadPretrainedModel(SileroVadModel):
|
|
| 298 |
hop_size=config.hop_size,
|
| 299 |
win_type=config.win_type,
|
| 300 |
encoder_in_channels=config.encoder_in_channels,
|
|
|
|
|
|
|
| 301 |
encoder_kernel_size=config.encoder_kernel_size,
|
| 302 |
encoder_num_layers=config.encoder_num_layers,
|
| 303 |
decoder_hidden_size=config.decoder_hidden_size,
|
|
@@ -362,10 +379,12 @@ class SileroVadModelExport(nn.Module):
|
|
| 362 |
|
| 363 |
def forward(self,
|
| 364 |
signal: torch.Tensor,
|
| 365 |
-
|
|
|
|
| 366 |
lstm_hidden_state: torch.Tensor,
|
| 367 |
):
|
| 368 |
-
#
|
|
|
|
| 369 |
# lstm_hidden_state shape: [2, num_layers, b, h]
|
| 370 |
|
| 371 |
# signal shape [b, 1, num_samples]
|
|
@@ -382,7 +401,9 @@ class SileroVadModelExport(nn.Module):
|
|
| 382 |
# x = self.tpad.forward(x)
|
| 383 |
# x shape: [b, t+p, f']
|
| 384 |
|
| 385 |
-
x,
|
|
|
|
|
|
|
| 386 |
# x shape: [b, t, f']
|
| 387 |
|
| 388 |
x, new_lstm_hidden_state = self.lstm.forward(x, (lstm_hidden_state[0], lstm_hidden_state[1]))
|
|
@@ -397,7 +418,7 @@ class SileroVadModelExport(nn.Module):
|
|
| 397 |
lsnr = self.lsnr_fc.forward(x) * self.lsnr_scale + self.lsnr_offset
|
| 398 |
# lsnr shape: [b, t, 1]
|
| 399 |
|
| 400 |
-
return logits, probs, lsnr,
|
| 401 |
|
| 402 |
|
| 403 |
def main1():
|
|
@@ -425,6 +446,7 @@ def main2():
|
|
| 425 |
encoder_num_layers = config.encoder_num_layers
|
| 426 |
p = (config.encoder_kernel_size - 1) // 2
|
| 427 |
encoder_in_channels = config.encoder_in_channels
|
|
|
|
| 428 |
|
| 429 |
decoder_num_layers = config.decoder_num_layers
|
| 430 |
decoder_hidden_size = config.decoder_hidden_size
|
|
@@ -432,49 +454,60 @@ def main2():
|
|
| 432 |
b = 1
|
| 433 |
inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
|
| 434 |
|
| 435 |
-
|
| 436 |
-
|
|
|
|
| 437 |
] * encoder_num_layers
|
| 438 |
-
|
| 439 |
|
| 440 |
lstm_hidden_state = [
|
| 441 |
torch.zeros(size=(decoder_num_layers, b, decoder_hidden_size), dtype=torch.float32)
|
| 442 |
] * 2
|
| 443 |
lstm_hidden_state = torch.stack(lstm_hidden_state, dim=0)
|
| 444 |
|
| 445 |
-
logits, probs, lsnr,
|
|
|
|
|
|
|
| 446 |
print(f"logits.shape: {logits.shape}")
|
| 447 |
-
print(f"
|
|
|
|
| 448 |
print(f"new_lstm_hidden_state.shape: {new_lstm_hidden_state.shape}")
|
| 449 |
|
| 450 |
torch.onnx.export(model_export,
|
| 451 |
-
args=(inputs,
|
| 452 |
f="silero_vad.onnx",
|
| 453 |
-
input_names=["inputs", "
|
| 454 |
-
output_names=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
dynamic_axes={
|
| 456 |
"inputs": {0: "batch_size", 2: "num_samples"},
|
| 457 |
-
"
|
|
|
|
| 458 |
"lstm_hidden_state": {2: "batch_size"},
|
| 459 |
"logits": {0: "batch_size"},
|
| 460 |
"probs": {0: "batch_size"},
|
| 461 |
"lsnr": {0: "batch_size"},
|
| 462 |
-
"
|
|
|
|
| 463 |
"new_lstm_hidden_state": {2: "batch_size"},
|
| 464 |
})
|
| 465 |
|
| 466 |
ort_session = ort.InferenceSession("silero_vad.onnx")
|
| 467 |
input_feed = {
|
| 468 |
"inputs": inputs.numpy(),
|
| 469 |
-
"
|
|
|
|
| 470 |
"lstm_hidden_state": lstm_hidden_state.numpy(),
|
| 471 |
}
|
| 472 |
output_names = [
|
| 473 |
-
"logits", "probs", "lsnr", "
|
| 474 |
]
|
| 475 |
-
logits, probs, lsnr,
|
| 476 |
print(f"probs.shape: {probs.shape}")
|
| 477 |
-
print(f"new_encoder_cache_list.shape: {new_encoder_cache_list.shape}")
|
| 478 |
return
|
| 479 |
|
| 480 |
|
|
|
|
| 62 |
num_layers: int = 3,
|
| 63 |
):
|
| 64 |
super(Encoder, self).__init__()
|
| 65 |
+
self.num_layers = num_layers
|
| 66 |
|
| 67 |
self.layers = nn.ModuleList(modules=[])
|
| 68 |
for i in range(num_layers):
|
|
|
|
| 97 |
def __init__(self, model: Encoder):
    """
    Wrap an existing encoder for export.

    The wrapper does not create layers of its own: it keeps a reference to the
    wrapped encoder's ``layers`` and copies its ``num_layers`` value.

    :param model: the encoder whose layers are reused.
    """
    super().__init__()
    self.num_layers = model.num_layers
    self.layers = model.layers
|
| 101 |
|
| 102 |
+
def forward(self, x: torch.Tensor, in_cache: torch.Tensor, hidden_cache_list: torch.Tensor):
    """
    Run every encoder layer on one chunk, prepending a per-layer cache on the time axis.

    For each layer the cache is concatenated in front of the layer input along
    dim 1, the layer is applied, and the last ``2p`` frames of the concatenated
    input become that layer's new cache.

    :param x: chunk features, shape [b, t, f].
    :param in_cache: cache for layer 0, shape [b, 2p, f1].
    :param hidden_cache_list: caches for the remaining layers, shape
        [num_layers, b, 2p, fi]. Entry ``idx`` belongs to layer ``idx``;
        entry 0 is never read because layer 0 uses ``in_cache``.
    :return: tuple ``(x, new_in_cache, new_hidden_cache_list)``.
        ``new_hidden_cache_list`` has the same layout as ``hidden_cache_list``
        (entry 0 is passed through unchanged), so the returned state can be fed
        straight back in for the next chunk.
    """
    # x shape: [b, t, f]
    # in_cache shape: [b, 2p, f1]
    # hidden_cache_list shape: [num_layers, b, 2p, fi]

    new_in_cache = None
    # Keep slot 0 so the output lines up index-for-index with the input;
    # this also keeps torch.stack from receiving an empty list when there
    # is only one layer.
    new_hidden_cache_list = [hidden_cache_list[0]]
    for idx, layer in enumerate(self.layers):
        cache = in_cache if idx == 0 else hidden_cache_list[idx]

        x_pad = torch.concat(tensors=[cache, x], dim=1)
        x = layer.forward(x_pad)

        # The new cache has the same number of frames as the old one.
        _, twop, _ = cache.shape
        new_cache = x_pad[:, -twop:, :]
        if idx == 0:
            new_in_cache = new_cache
        else:
            new_hidden_cache_list.append(new_cache)

    new_hidden_cache_list = torch.stack(tensors=new_hidden_cache_list, dim=0)
    return x, new_in_cache, new_hidden_cache_list
|
| 127 |
|
| 128 |
|
| 129 |
class SileroVadModel(nn.Module):
|
|
|
|
| 134 |
hop_size: int,
|
| 135 |
win_type: int,
|
| 136 |
encoder_in_channels: int,
|
| 137 |
+
encoder_hidden_channels: int,
|
| 138 |
+
encoder_out_channels: int,
|
| 139 |
encoder_kernel_size: int,
|
| 140 |
encoder_num_layers: int,
|
| 141 |
decoder_hidden_size: int,
|
|
|
|
| 152 |
self.win_type = win_type
|
| 153 |
|
| 154 |
self.encoder_in_channels = encoder_in_channels
|
| 155 |
+
self.encoder_hidden_channels = encoder_hidden_channels
|
| 156 |
+
self.encoder_out_channels = encoder_out_channels
|
| 157 |
self.encoder_kernel_size = encoder_kernel_size
|
| 158 |
self.encoder_num_layers = encoder_num_layers
|
| 159 |
|
|
|
|
| 195 |
|
| 196 |
self.encoder = Encoder(
|
| 197 |
in_channels=self.encoder_in_channels,
|
| 198 |
+
hidden_channels=self.encoder_hidden_channels,
|
| 199 |
+
out_channels=self.encoder_out_channels,
|
| 200 |
kernel_size=self.encoder_kernel_size,
|
| 201 |
num_layers=self.encoder_num_layers,
|
| 202 |
)
|
|
|
|
| 313 |
hop_size=config.hop_size,
|
| 314 |
win_type=config.win_type,
|
| 315 |
encoder_in_channels=config.encoder_in_channels,
|
| 316 |
+
encoder_hidden_channels=config.encoder_hidden_channels,
|
| 317 |
+
encoder_out_channels=config.encoder_out_channels,
|
| 318 |
encoder_kernel_size=config.encoder_kernel_size,
|
| 319 |
encoder_num_layers=config.encoder_num_layers,
|
| 320 |
decoder_hidden_size=config.decoder_hidden_size,
|
|
|
|
| 379 |
|
| 380 |
def forward(self,
|
| 381 |
signal: torch.Tensor,
|
| 382 |
+
encoder_in_cache: torch.Tensor,
|
| 383 |
+
encoder_hidden_cache_list: torch.Tensor,
|
| 384 |
lstm_hidden_state: torch.Tensor,
|
| 385 |
):
|
| 386 |
+
# encoder_in_cache shape: [b, 2p, f]
|
| 387 |
+
# encoder_hidden_cache_list shape: [num_layers, b, 2p, f]
|
| 388 |
# lstm_hidden_state shape: [2, num_layers, b, h]
|
| 389 |
|
| 390 |
# signal shape [b, 1, num_samples]
|
|
|
|
| 401 |
# x = self.tpad.forward(x)
|
| 402 |
# x shape: [b, t+p, f']
|
| 403 |
|
| 404 |
+
x, new_encoder_in_cache, new_encoder_hidden_cache_list = self.encoder.forward(
|
| 405 |
+
x, in_cache=encoder_in_cache, hidden_cache_list=encoder_hidden_cache_list
|
| 406 |
+
)
|
| 407 |
# x shape: [b, t, f']
|
| 408 |
|
| 409 |
x, new_lstm_hidden_state = self.lstm.forward(x, (lstm_hidden_state[0], lstm_hidden_state[1]))
|
|
|
|
| 418 |
lsnr = self.lsnr_fc.forward(x) * self.lsnr_scale + self.lsnr_offset
|
| 419 |
# lsnr shape: [b, t, 1]
|
| 420 |
|
| 421 |
+
return logits, probs, lsnr, new_encoder_in_cache, new_encoder_hidden_cache_list, new_lstm_hidden_state
|
| 422 |
|
| 423 |
|
| 424 |
def main1():
|
|
|
|
| 446 |
encoder_num_layers = config.encoder_num_layers
|
| 447 |
p = (config.encoder_kernel_size - 1) // 2
|
| 448 |
encoder_in_channels = config.encoder_in_channels
|
| 449 |
+
encoder_hidden_channels = config.encoder_hidden_channels
|
| 450 |
|
| 451 |
decoder_num_layers = config.decoder_num_layers
|
| 452 |
decoder_hidden_size = config.decoder_hidden_size
|
|
|
|
| 454 |
b = 1
|
| 455 |
inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
|
| 456 |
|
| 457 |
+
encoder_in_cache = torch.zeros(size=(b, 2*p, encoder_in_channels), dtype=torch.float32)
|
| 458 |
+
encoder_hidden_cache_list = [
|
| 459 |
+
torch.zeros(size=(b, 2*p, encoder_hidden_channels), dtype=torch.float32)
|
| 460 |
] * encoder_num_layers
|
| 461 |
+
encoder_hidden_cache_list = torch.stack(encoder_hidden_cache_list, dim=0)
|
| 462 |
|
| 463 |
lstm_hidden_state = [
|
| 464 |
torch.zeros(size=(decoder_num_layers, b, decoder_hidden_size), dtype=torch.float32)
|
| 465 |
] * 2
|
| 466 |
lstm_hidden_state = torch.stack(lstm_hidden_state, dim=0)
|
| 467 |
|
| 468 |
+
logits, probs, lsnr, new_encoder_in_cache, new_encoder_hidden_cache_list, new_lstm_hidden_state = model_export.forward(
|
| 469 |
+
inputs, encoder_in_cache, encoder_hidden_cache_list, lstm_hidden_state
|
| 470 |
+
)
|
| 471 |
print(f"logits.shape: {logits.shape}")
|
| 472 |
+
print(f"new_encoder_in_cache.shape: {new_encoder_in_cache.shape}")
|
| 473 |
+
print(f"new_encoder_hidden_cache_list.shape: {new_encoder_hidden_cache_list.shape}")
|
| 474 |
print(f"new_lstm_hidden_state.shape: {new_lstm_hidden_state.shape}")
|
| 475 |
|
| 476 |
torch.onnx.export(model_export,
|
| 477 |
+
args=(inputs, encoder_in_cache, encoder_hidden_cache_list, lstm_hidden_state),
|
| 478 |
f="silero_vad.onnx",
|
| 479 |
+
input_names=["inputs", "encoder_in_cache", "encoder_hidden_cache_list", "lstm_hidden_state"],
|
| 480 |
+
output_names=[
|
| 481 |
+
"logits", "probs", "lsnr",
|
| 482 |
+
"new_encoder_in_cache",
|
| 483 |
+
"new_encoder_hidden_cache_list",
|
| 484 |
+
"new_lstm_hidden_state"
|
| 485 |
+
],
|
| 486 |
dynamic_axes={
|
| 487 |
"inputs": {0: "batch_size", 2: "num_samples"},
|
| 488 |
+
"encoder_in_cache": {1: "batch_size"},
|
| 489 |
+
"encoder_hidden_cache_list": {1: "batch_size"},
|
| 490 |
"lstm_hidden_state": {2: "batch_size"},
|
| 491 |
"logits": {0: "batch_size"},
|
| 492 |
"probs": {0: "batch_size"},
|
| 493 |
"lsnr": {0: "batch_size"},
|
| 494 |
+
"new_encoder_in_cache": {1: "batch_size"},
|
| 495 |
+
"new_encoder_hidden_cache_list": {1: "batch_size"},
|
| 496 |
"new_lstm_hidden_state": {2: "batch_size"},
|
| 497 |
})
|
| 498 |
|
| 499 |
ort_session = ort.InferenceSession("silero_vad.onnx")
|
| 500 |
input_feed = {
|
| 501 |
"inputs": inputs.numpy(),
|
| 502 |
+
"encoder_in_cache": encoder_in_cache.numpy(),
|
| 503 |
+
"encoder_hidden_cache_list": encoder_hidden_cache_list.numpy(),
|
| 504 |
"lstm_hidden_state": lstm_hidden_state.numpy(),
|
| 505 |
}
|
| 506 |
output_names = [
|
| 507 |
+
"logits", "probs", "lsnr", "new_encoder_in_cache", "new_encoder_hidden_cache_list", "new_lstm_hidden_state"
|
| 508 |
]
|
| 509 |
+
logits, probs, lsnr, new_encoder_in_cache, new_encoder_hidden_cache_list, new_lstm_hidden_state = ort_session.run(output_names, input_feed)
|
| 510 |
print(f"probs.shape: {probs.shape}")
|
|
|
|
| 511 |
return
|
| 512 |
|
| 513 |
|