import gradio as gr
from datetime import datetime
import logging
import os
import tempfile
import time

from examples import examples
import librosa
import numpy as np
import sentencepiece as spm
import onnxruntime as ort

# Audio parameters (FIXED)
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80

# Tokenizer parameters
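# (RNNT_BLANK presumably sits one past the 1024-token BPE vocabulary, ids 0-1023)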
RNNT_BLANK = 1024
PAD = 1  # tokenizer.pad_id()

# Training parameters
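# (only the first element, presumably the left attention context in encoder
# frames, is used below to size the streaming k/v caches)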
ATTENTION_CONTEXT_SIZE = (160, 3)

# Whisper-small parameters
N_STATE = 768
N_HEAD = 12
N_LAYER = 12

AVERAGE_TEXT_LENGTH = 100

tokenizer = spm.SentencePieceProcessor(model_file="./tokenizer_spe_bpe_v1024_pad/tokenizer.model")
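
# ONNX sessions for the three RNN-T components: encoder (transcription network),
# decoder (prediction network) and jointer (joint network)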
ort_encoder_session = ort.InferenceSession("./onnx/encoder_160_8.onnx")
ort_decoder_session = ort.InferenceSession("./onnx/decoder_160_8.onnx")
ort_jointer_session = ort.InferenceSession("./onnx/jointer_160_8.onnx")

demo = gr.Blocks()

def build_html_output(s: str, style: str = "result_item_success"):
    return f"""
    <div class='result'>
        <div class='result_item {style}'>
            {s}
        </div>
    </div>
    """

def MyPrint(s):
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
    print(f"{date_time}: {s}")

def process_microphone(
    in_filename: str
):
    if in_filename is None or in_filename == "":
        return "", build_html_output(
            "Please first click 'Record from microphone', speak, "
            "click 'Stop recording', and then "
            "click the button 'Submit for recognition'",
            "result_item_error",
        )

    MyPrint(f"Processing microphone: {in_filename}")
    try:
        return process(
            in_filename=in_filename,
        )
    except Exception as e:
        MyPrint(str(e))
        return "", build_html_output(str(e), "result_item_error")

def process_uploaded_file(
    in_filename: str,
):
    if in_filename is None or in_filename == "":
        return "", build_html_output(
            "Please first upload a file and then click "
            'the button "Submit for recognition"',
            "result_item_error",
        )

    MyPrint(f"Processing uploaded file: {in_filename}")
    try:
        return process(
            in_filename=in_filename
        )
    except Exception as e:
        MyPrint(str(e))
        return "", build_html_output(str(e), "result_item_error")

title = "# Streaming RNN-T with Whisper Encoder"
description = """
Visit <https://github.com/HKAB/rnnt-whisper-tutorial/> for more information.

- This model runs on CPU.
- This model might not work well with your microphone, since it was trained on a fairly clean dataset. Try to speak loudly and clearly 😃
"""

def onnx_online_inference(audio, ort_encoder_session, ort_decoder_session, ort_jointer_session, tokenizer):
    audio = audio.astype(np.float32)
    if audio.ndim == 1:
        audio = np.expand_dims(audio, 0)
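
    # Streaming caches carried across chunks: STFT overlap samples, the
    # convolutional front-end states, and per-layer attention k/v caches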
    audio_cache = np.zeros((1, N_FFT - HOP_LENGTH), dtype=np.float32)
    conv1_cache = np.zeros((1, N_MELS, 1), dtype=np.float32)
    conv2_cache = np.zeros((1, N_STATE, 1), dtype=np.float32)
    conv3_cache = np.zeros((1, N_STATE, 1), dtype=np.float32)
    k_cache = np.zeros((N_LAYER, 1, ATTENTION_CONTEXT_SIZE[0], N_STATE), dtype=np.float32)
    v_cache = np.zeros((N_LAYER, 1, ATTENTION_CONTEXT_SIZE[0], N_STATE), dtype=np.float32)
    cache_len = np.zeros((1,), dtype=np.int32)
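
    # Decoder (prediction network) state: hidden state plus the previously
    # emitted token, initialised with the blank symbol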
    h_n = np.zeros((1, 1, N_STATE), dtype=np.float32)
    token = np.array([[RNNT_BLANK]], dtype=np.int64)

    seq_ids = []
    reset_time = 0
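
    # Step size is HOP_LENGTH * 31 + N_FFT - (N_FFT - HOP_LENGTH) = 32 * HOP_LENGTH
    # = 5120 samples, i.e. 0.32 s of audio per encoder call at 16 kHz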
    for i in range(0, audio.shape[1], HOP_LENGTH * 31 + N_FFT - (N_FFT - HOP_LENGTH)):
        audio_chunk = audio[:, i:i + HOP_LENGTH * 31 + N_FFT - (N_FFT - HOP_LENGTH)]
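        # Zero-pad the final, shorter chunk up to the full chunk length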
        if audio_chunk.shape[1] < HOP_LENGTH * 31 + N_FFT - (N_FFT - HOP_LENGTH):
            audio_chunk = np.pad(audio_chunk, ((0, 0), (0, HOP_LENGTH * 31 + N_FFT - (N_FFT - HOP_LENGTH) - audio_chunk.shape[1])))

        # Very simple reset mechanism
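        # (clear all encoder caches and decoder state once roughly every
        # AVERAGE_TEXT_LENGTH decoded tokens)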
        if len(seq_ids) // AVERAGE_TEXT_LENGTH > reset_time:
            audio_cache = np.zeros((1, N_FFT - HOP_LENGTH), dtype=np.float32)
            conv1_cache = np.zeros((1, N_MELS, 1), dtype=np.float32)
            conv2_cache = np.zeros((1, N_STATE, 1), dtype=np.float32)
            conv3_cache = np.zeros((1, N_STATE, 1), dtype=np.float32)
            k_cache = np.zeros((N_LAYER, 1, ATTENTION_CONTEXT_SIZE[0], N_STATE), dtype=np.float32)
            v_cache = np.zeros((N_LAYER, 1, ATTENTION_CONTEXT_SIZE[0], N_STATE), dtype=np.float32)
            cache_len = np.zeros((1,), dtype=np.int32)
            h_n = np.zeros((1, 1, N_STATE), dtype=np.float32)
            token = np.array([[RNNT_BLANK]], dtype=np.int64)
            reset_time = len(seq_ids) // AVERAGE_TEXT_LENGTH
            # print(f"Reset hidden_state and token at {i / 16000} seconds")

        r = ort_encoder_session.run(
            None,
            {
                "audio_chunk": audio_chunk,
                "audio_cache.1": audio_cache,
                "conv1_cache.1": conv1_cache,
                "conv2_cache.1": conv2_cache,
                "conv3_cache.1": conv3_cache,
                "k_cache.1": k_cache,
                "v_cache.1": v_cache,
                "cache_len.1": cache_len
            }
        )
        enc_out, audio_cache, conv1_cache, conv2_cache, conv3_cache, k_cache, v_cache, cache_len = r
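
        # Greedy RNN-T decoding: for each encoder frame, emit tokens until the
        # blank symbol is predicted, capped at 3 symbols per frame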
        for time_idx in range(enc_out.shape[1]):
            current_seq_enc_out = enc_out[:, time_idx, :].reshape(1, 1, N_STATE)
            not_blank = True
            symbols_added = 0
            while not_blank and symbols_added < 3:
                dec, new_h_n = ort_decoder_session.run(
                    None,
                    {
                        "token": token,
                        "h_n.1": h_n
                    }
                )
                logits = ort_jointer_session.run(
                    None,
                    {
                        "enc": current_seq_enc_out,
                        "dec": dec
                    }
                )[0]
                new_token = int(logits.argmax())
                if new_token == RNNT_BLANK:
                    not_blank = False
                else:
                    symbols_added += 1
                    token = np.array([[new_token]], dtype=np.int64)
                    h_n = new_h_n
                    seq_ids.append(new_token)

    return tokenizer.decode(seq_ids)
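
# Example of calling the streaming decoder on its own ("sample.wav" is a
# hypothetical file name):
#   audio, _ = librosa.load("sample.wav", sr=SAMPLE_RATE)
#   print(onnx_online_inference(audio, ort_encoder_session, ort_decoder_session,
#                               ort_jointer_session, tokenizer))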

def process(
    in_filename: str,
):
    # filename = convert_to_wav(in_filename)
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
    MyPrint(f"Started at {date_time}")

    start = time.time()

    audio, _ = librosa.load(in_filename, sr=SAMPLE_RATE)
    audio = np.pad(audio, (16000, 0))  # add some zeros to the start of the audio for warmup
    duration = len(audio) / SAMPLE_RATE
    audio = np.expand_dims(audio, 0).astype(np.float32)
    text = onnx_online_inference(audio, ort_encoder_session, ort_decoder_session, ort_jointer_session, tokenizer)
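
    # Report wall-clock time and the real-time factor (processing time / audio duration)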
    date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    end = time.time()
    rtf = (end - start) / duration
    MyPrint(f"Finished at {date_time}. Elapsed: {end - start: .3f} s")

    info = f"""
    Wave duration  : {duration: .3f} s <br/>
    Processing time: {end - start: .3f} s <br/>
    RTF: {end - start: .3f}/{duration: .3f} = {rtf:.3f} <br/>
    """
    MyPrint(info)
    MyPrint(f"\nPrediction: {text}")

    return text, build_html_output(info)

with demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.TabItem("Upload from disk"):
            uploaded_file = gr.Audio(
                sources=["upload"],  # Choose between "microphone", "upload"
                type="filepath",
                label="Upload from disk",
            )
            upload_button = gr.Button("Submit for recognition")
            uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
            uploaded_html_info = gr.HTML(label="Info")

            gr.Examples(
                examples=examples,
                inputs=[
                    uploaded_file
                ],
                outputs=[uploaded_output, uploaded_html_info],
                fn=process_uploaded_file,
                label="Cherry-picked examples",
            )

        with gr.TabItem("Record from microphone"):
            microphone = gr.Audio(
                sources=["microphone"],  # Choose between "microphone", "upload"
                type="filepath",
                label="Record from microphone",
            )
            record_button = gr.Button("Submit for recognition")
            recorded_output = gr.Textbox(label="Recognized speech from recordings")
            recorded_html_info = gr.HTML(label="Info")

            gr.Examples(
                examples=examples,
                inputs=[
                    microphone
                ],
                outputs=[recorded_output, recorded_html_info],
                fn=process_microphone,
                label="Cherry-picked examples",
            )

    upload_button.click(
        process_uploaded_file,
        inputs=[
            uploaded_file
        ],
        outputs=[uploaded_output, uploaded_html_info],
    )
    record_button.click(
        process_microphone,
        inputs=[
            microphone,
        ],
        outputs=[recorded_output, recorded_html_info],
    )

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    demo.launch()