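"""Gradio demo for OLMoASR: transcribe uploaded or recorded audio with one of
three inference strategies (HuggingFace chunking, HuggingFace sequential, or
OLMoASR sequential) and render timestamped output that follows playback."""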
import re

import gradio as gr
import librosa
import numpy as np
import torch
from bs4 import BeautifulSoup
from olmoasr import load_model
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# Prefer GPU with half precision; fall back to CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

hf_model_path = "checkpoints/medium_hf_demo"
olmoasr_ckpt = (
    "checkpoints/eval_latesttrain_00524288_medium_fsdp-train_grad-acc_bfloat16_inf.pt"
)

# transformers export of the model, used by the two HuggingFace strategies.
hf_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    hf_model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
hf_model.to(device).eval()
processor = AutoProcessor.from_pretrained(hf_model_path)

# Native OLMoASR checkpoint, used by the OLMoASR sequential strategy.
olmoasr_model = load_model(
    name=olmoasr_ckpt, device=device, inference=True, in_memory=True
)
olmoasr_model.to(device).eval()
def stereo_to_mono(waveform):
    # librosa.load(..., mono=False) returns a (channels, samples) array for
    # multi-channel audio and a 1-D array for mono; average the channels so
    # downstream code always receives a mono signal.
    if waveform.ndim == 2:
        return np.mean(waveform, axis=0)
    return waveform
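# e.g. stereo_to_mono(np.zeros((2, 16000))).shape == (16000,)
#      stereo_to_mono(np.zeros(16000)).shape == (16000,)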
def hf_chunk_transcribe(audio_file, timestamp_text, transcription_text):
    """Transcribe with the transformers ASR pipeline using 30-second chunking."""
    hf_transcriber = pipeline(
        "automatic-speech-recognition",
        model=hf_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        chunk_length_s=30,
    )
    waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
    waveform = stereo_to_mono(waveform)
    print(waveform.shape)
    if sample_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
    result = hf_transcriber(waveform, return_timestamps=True)
    print(f"{result['text']=}\n")
    print(f"{result['chunks']=}\n")
    chunks, text = hf_process_chunks(result["chunks"])
    print(f"{chunks=}\n")
    print(f"{text=}\n")

    # Swap the new transcript into the existing HTML panels.
    transSoup = BeautifulSoup(transcription_text, "html.parser")
    transText = transSoup.find(id="transcriptionText")
    if transText:
        transText.clear()
        transText.append(BeautifulSoup(text, "html.parser"))
    timeSoup = BeautifulSoup(timestamp_text, "html.parser")
    timeText = timeSoup.find(id="timestampText")
    if timeText:
        timeText.clear()
        timeText.append(BeautifulSoup(chunks, "html.parser"))
    return str(timeSoup), str(transSoup)
def olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text):
    """Transcribe with the native OLMoASR model (Whisper-style sequential decoding)."""
    waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
    waveform = stereo_to_mono(waveform)
    print(waveform.shape)
    if sample_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
    options = dict(
        task="transcribe",
        language="en",
        without_timestamps=False,
        beam_size=5,
        best_of=5,
    )
    result = olmoasr_model.transcribe(waveform, verbose=False, **options)
    print(f"{result['text']=}\n")
    print(f"{result['segments']=}\n")
    chunks, text = olmoasr_process_chunks(result["segments"])
    print(f"{chunks=}\n")
    print(f"{text=}\n")

    # Swap the new transcript into the existing HTML panels.
    transSoup = BeautifulSoup(transcription_text, "html.parser")
    transText = transSoup.find(id="transcriptionText")
    if transText:
        transText.clear()
        transText.append(BeautifulSoup(text, "html.parser"))
    timeSoup = BeautifulSoup(timestamp_text, "html.parser")
    timeText = timeSoup.find(id="timestampText")
    if timeText:
        timeText.clear()
        timeText.append(BeautifulSoup(chunks, "html.parser"))
    return str(timeSoup), str(transSoup)
def hf_seq_transcribe(audio_file, timestamp_text, transcription_text):
    """Transcribe with the transformers ASR pipeline without chunking (sequential)."""
    hf_transcriber = pipeline(
        "automatic-speech-recognition",
        model=hf_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
    waveform = stereo_to_mono(waveform)
    print(waveform.shape)
    if sample_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
    result = hf_transcriber(waveform, return_timestamps=True)
    print(f"{result['text']=}\n")
    print(f"{result['chunks']=}\n")
    chunks, text = hf_seq_process_chunks(result["chunks"])
    print(f"{text=}\n")
    print(f"{chunks=}\n")

    # Swap the new transcript into the existing HTML panels.
    transSoup = BeautifulSoup(transcription_text, "html.parser")
    transText = transSoup.find(id="transcriptionText")
    if transText:
        transText.clear()
        transText.append(BeautifulSoup(text, "html.parser"))
    timeSoup = BeautifulSoup(timestamp_text, "html.parser")
    timeText = timeSoup.find(id="timestampText")
    if timeText:
        timeText.clear()
        timeText.append(BeautifulSoup(chunks, "html.parser"))
    return str(timeSoup), str(transSoup)
def main_transcribe(inference_strategy, audio_file, timestamp_text, transcription_text):
    if inference_strategy == "HuggingFace Chunking":
        return hf_chunk_transcribe(audio_file, timestamp_text, transcription_text)
    elif inference_strategy == "OLMoASR Sequential":
        return olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text)
    elif inference_strategy == "HuggingFace Sequential":
        return hf_seq_transcribe(audio_file, timestamp_text, transcription_text)
    # Unknown strategy: leave both panels unchanged.
    return timestamp_text, transcription_text
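# The *_process_chunks helpers below turn model output into "start --> end: text"
# lines for the timestamp panel plus a plain transcript string. Their regexes drop
# trailing filler such as "foreign", "you", and "Thank you for watching!",
# strings these models frequently emit on silence or non-speech audio.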
def olmoasr_process_chunks(chunks):
    processed_chunks = []
    processed_chunks_text = []
    for chunk in chunks:
        text = chunk["text"].strip()
        if not re.match(
            r"\s*(foreign you|foreign|Thank you for watching!|you there|you)\s*$", text
        ):
            if text == "":
                continue
            start = chunk["start"]
            end = chunk["end"]
            # Join hard-wrapped lines unless the next line starts a new timestamp.
            pattern = r"\n(?!\d+\.\d+\s*-->)"
            text = re.sub(pattern, "", text)
            processed_chunks_text.append(text.strip())
            processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text} <br>")
        else:
            # Stop at the first filler-only segment; everything after is suspect.
            break
    print(f"{processed_chunks=}\n")
    print(f"{processed_chunks_text=}\n")
    # Strip a trailing "foreign" that can survive on the final segment.
    if processed_chunks_text and re.search(r"\s*foreign\s*$", processed_chunks_text[-1]):
        processed_chunks_text[-1] = re.sub(
            r"\s*foreign\s*$", "", processed_chunks_text[-1]
        )
        processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
    return "\n".join(processed_chunks), " ".join(processed_chunks_text)
def hf_process_chunks(chunks):
    processed_chunks = []
    processed_chunks_text = []
    for chunk in chunks:
        text = chunk["text"].strip()
        if not re.match(r"\s*(foreign you|foreign|you there|you)\s*$", text):
            if text == "":
                continue
            start, end = chunk["timestamp"]
            if end is None:
                # The pipeline can return None for the final end timestamp.
                end = start
            # Join hard-wrapped lines unless the next line starts a new timestamp.
            pattern = r"\n(?!\d+\.\d+\s*-->)"
            text = re.sub(pattern, "", text)
            processed_chunks_text.append(text.strip())
            processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()} <br>")
        else:
            # Stop at the first filler-only segment; everything after is suspect.
            break
    print(f"{processed_chunks=}\n")
    print(f"{processed_chunks_text=}\n")
    # Strip a trailing "foreign" that can survive on the final segment.
    if processed_chunks_text and re.search(r"\s*foreign\s*$", processed_chunks_text[-1]):
        processed_chunks_text[-1] = re.sub(
            r"\s*foreign\s*$", "", processed_chunks_text[-1]
        )
        processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
    return "\n".join(processed_chunks), " ".join(processed_chunks_text)
def hf_seq_process_chunks(chunks):
    """Like hf_process_chunks, but remaps per-window timestamps to global time.

    Without chunking, the pipeline's timestamps can restart from zero at each
    new decoding window; detect the reset (start < previous end) and add the
    accumulated duration of the earlier windows.
    """
    if not chunks:
        return "", ""
    processed_chunks = []
    processed_chunks_text = []
    delta_time = 0.0
    global_start = chunks[0]["timestamp"][0]
    prev_end = -1.0
    prev_dur = 0.0
    accumulate_ts = False
    for chunk in chunks:
        text = chunk["text"].strip()
        if not re.match(r"\s*(foreign you|foreign|you there|you)\s*$", text):
            if text == "":
                continue
            start, end = chunk["timestamp"]
            if end is None:
                # The pipeline can return None for the final end timestamp.
                end = start
            if start < prev_end:
                # Timestamps reset: fold the previous window's length into the offset.
                accumulate_ts = True
                prev_dur += delta_time
            delta_time = end - global_start
            prev_end = end
            if accumulate_ts:
                start += prev_dur
                end += prev_dur
            # Join hard-wrapped lines unless the next line starts a new timestamp.
            pattern = r"\n(?!\d+\.\d+\s*-->)"
            text = re.sub(pattern, "", text)
            processed_chunks_text.append(text.strip())
            processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()} <br>")
        else:
            # Stop at the first filler-only segment; everything after is suspect.
            break
    print(f"{processed_chunks=}\n")
    print(f"{processed_chunks_text=}\n")
    # Strip a trailing "foreign" that can survive on the final segment.
    if processed_chunks_text and re.search(r"\s*foreign\s*$", processed_chunks_text[-1]):
        processed_chunks_text[-1] = re.sub(
            r"\s*foreign\s*$", "", processed_chunks_text[-1]
        )
        processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
    return "\n".join(processed_chunks), " ".join(processed_chunks_text)
original_timestamp_html = """
<div style="background: white; border: 1px solid #d1d5db; border-radius: 8px; padding: 16px; width: 100%; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); flex: 1; margin-right: 10px;">
    <div style="color: #374151; font-size: 14px; font-weight: 500; margin-bottom: 8px;">Timestamp Text</div>
    <div id="timestampText" style="color: #6b7280; font-size: 14px; line-height: 1.5; min-height: 100px; font-family: system-ui, sans-serif;"></div>
</div>
"""

original_transcription_html = """
<div style="background: white; border: 1px solid #d1d5db; border-radius: 8px; padding: 16px; width: 100%; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); flex: 1; margin-right: 10px;">
    <div style="color: #374151; font-size: 14px; font-weight: 500; margin-bottom: 8px;">Transcription Text</div>
    <div id="transcriptionText" style="color: #6b7280; font-size: 14px; line-height: 1.5; min-height: 100px; font-family: system-ui, sans-serif;"></div>
</div>
"""
def reset():
    # Restore both panels to their empty templates when the strategy changes.
    return original_timestamp_html, original_transcription_html
event_process_js = """
<script>
function getTime() {
    let lastIndex = -1;
    setInterval(() => {
        const time = document.getElementById('time');
        const timestampText = document.getElementById('timestampText');
        if (timestampText && timestampText.innerText != '') {
            if (time == null) {
                // Audio was cleared: blank both panels and reset the highlight.
                timestampText.innerText = '';
                const transcriptionText = document.getElementById('transcriptionText');
                if (transcriptionText) {
                    transcriptionText.innerText = '';
                }
                lastIndex = -1;
                return;
            }
            // Parse the player's mm:ss readout into seconds.
            const parts = time.textContent.split(':').map(Number);
            const currTime = parts[0] * 60 + parts[1];
            const currText = timestampText.innerText;
            const matches = [...currText.matchAll(/([\d.]+)\s*-->/g)];
            const startTimestamps = matches.map(m => parseFloat(m[1]));
            if (startTimestamps.length != 0) {
                // Find the last line whose start time has already passed.
                let correctIndex = 0;
                for (let i = 0; i < startTimestamps.length; i++) {
                    if (startTimestamps[i] <= currTime) {
                        correctIndex = i;
                    } else {
                        break;
                    }
                }
                if (lastIndex != correctIndex) {
                    lastIndex = correctIndex;
                    const lines = currText.split('\\n');
                    lines[correctIndex] = '<span style="background-color: #ff69b4; padding: 3px 8px; font-weight: 500; border-radius: 4px; color: white; box-shadow: 0 0 10px rgba(255, 105, 180, 0.5);">' + lines[correctIndex] + '</span>';
                    try {
                        timestampText.innerHTML = lines.join('<br>');
                    } catch (e) {
                        console.log('Not Updating!');
                    }
                }
            }
        } else {
            lastIndex = -1;
        }
    }, 50);
}
setTimeout(getTime, 1000);
</script>
"""
demo = gr.Blocks(
    head=event_process_js,
    theme=gr.themes.Default(primary_hue="emerald", secondary_hue="green"),
)

with demo:
    audio = gr.Audio(sources=["upload", "microphone"], type="filepath")
    inf_strategy = gr.Dropdown(
        label="Inference Strategy",
        choices=[
            "HuggingFace Chunking",
            "HuggingFace Sequential",
            "OLMoASR Sequential",
        ],
        value="HuggingFace Chunking",
        multiselect=False,
        info="Select the inference strategy for transcription.",
        elem_id="inf_strategy",
    )
    main_transcribe_button = gr.Button(
        "Transcribe",
        variant="primary",
    )
    with gr.Row():
        timestampText = gr.HTML(original_timestamp_html)
        transcriptionText = gr.HTML(original_transcription_html)

    # Changing strategy clears the panels; the button runs the selected pipeline.
    inf_strategy.change(
        fn=reset,
        inputs=[],
        outputs=[timestampText, transcriptionText],
    )
    main_transcribe_button.click(
        fn=main_transcribe,
        inputs=[inf_strategy, audio, timestampText, transcriptionText],
        outputs=[timestampText, transcriptionText],
    )

if __name__ == "__main__":
    # share=True also serves the demo through a temporary public gradio.live URL.
    demo.launch(share=True)