import install_dependencies

install_dependencies.install_private_repos()

import gradio as gr
from gradio_rich_textbox import RichTextbox
import torchaudio
import re
import librosa
import torch
import numpy as np
import os
import tempfile
import subprocess
import sys
from pathlib import Path
from huggingface_hub import hf_hub_download
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from whisper.normalizers import EnglishTextNormalizer
from whisper import audio, DecodingOptions
from whisper.tokenizer import get_tokenizer
from whisper.decoding import detect_language
from olmoasr import load_model
from bs4 import BeautifulSoup

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
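
# Checkpoint configuration: the tiny.en OLMoASR weights live in the allenai/OLMoASR
# repo on the Hugging Face Hub; they are cached locally and, when needed, converted
# into a Transformers-compatible copy under HF_MODEL_DIR.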

OLMOASR_REPO = "allenai/OLMoASR"
CHECKPOINT_FILENAME = "tiny.en.pt"
LOCAL_CHECKPOINT_DIR = "checkpoint_tiny"
HF_MODEL_DIR = "tiny_hf"


def ensure_checkpoint_dir():
    """Ensure the checkpoint and converted-model directories exist."""
    Path(LOCAL_CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
    Path(HF_MODEL_DIR).mkdir(parents=True, exist_ok=True)

def download_olmoasr_checkpoint():
    """Download the OLMoASR checkpoint from the Hugging Face Hub if it is not cached."""
    ensure_checkpoint_dir()

    local_checkpoint_path = os.path.join(LOCAL_CHECKPOINT_DIR, CHECKPOINT_FILENAME)

    if os.path.exists(local_checkpoint_path):
        print(f"Checkpoint already exists at {local_checkpoint_path}")
        return local_checkpoint_path

    try:
        print(f"Downloading checkpoint from {OLMOASR_REPO}")
        downloaded_path = hf_hub_download(
            repo_id=OLMOASR_REPO,
            filename=CHECKPOINT_FILENAME,
            local_dir=LOCAL_CHECKPOINT_DIR,
            local_dir_use_symlinks=False,
            token=os.getenv("HF_TOKEN"),
        )
        print(f"Downloaded checkpoint to {downloaded_path}")
        return downloaded_path
    except Exception as e:
        print(f"Error downloading checkpoint: {e}")
        # Re-raise instead of silently returning None so start-up fails loudly.
        raise
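
# Conversion is delegated to the convert_openai_to_hf.py script that sits next to
# this file and is run in a subprocess; it is expected to write config.json
# (checked below) plus the model weights and preprocessor files into HF_MODEL_DIR.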

def convert_checkpoint_to_hf(checkpoint_path):
    """Convert OLMoASR checkpoint to HuggingFace format using subprocess."""
    if os.path.exists(os.path.join(HF_MODEL_DIR, "config.json")):
        print(f"HuggingFace model already exists at {HF_MODEL_DIR}")
        return HF_MODEL_DIR

    try:
        print(f"Converting checkpoint {checkpoint_path} to HuggingFace format")

        script_path = os.path.join(os.path.dirname(__file__), "convert_openai_to_hf.py")

        cmd = [
            sys.executable,
            script_path,
            "--checkpoint_path",
            checkpoint_path,
            "--pytorch_dump_folder_path",
            HF_MODEL_DIR,
            "--convert_preprocessor",
            "True",
        ]

        print(f"Running conversion command: {' '.join(cmd)}")

        result = subprocess.run(cmd, capture_output=True, text=True, check=True)

        print("Conversion output:")
        print(result.stdout)

        if result.stderr:
            print("Conversion warnings/errors:")
            print(result.stderr)

        if os.path.exists(os.path.join(HF_MODEL_DIR, "config.json")):
            print(f"Model successfully converted and saved to {HF_MODEL_DIR}")
            return HF_MODEL_DIR
        else:
            raise Exception("Conversion completed but config.json not found")

    except subprocess.CalledProcessError as e:
        print(f"Conversion script failed with return code {e.returncode}")
        print(f"stdout: {e.stdout}")
        print(f"stderr: {e.stderr}")
        raise
    except Exception as e:
        print(f"Error converting checkpoint: {e}")
        raise
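
# Two inference backends are kept side by side: the converted Hugging Face model
# (driven through the transformers ASR pipeline) and the native OLMoASR model
# loaded directly from the .pt checkpoint.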

def initialize_models():
    """Initialize both HuggingFace and OLMoASR models."""
    # Make sure the checkpoint is downloaded and converted before loading.
    olmoasr_ckpt = download_olmoasr_checkpoint()
    hf_model_path = convert_checkpoint_to_hf(olmoasr_ckpt)

    hf_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        hf_model_path,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    hf_model.to(device).eval()
    processor = AutoProcessor.from_pretrained(hf_model_path)

    olmoasr_model = load_model(
        name=olmoasr_ckpt, device=device, inference=True, in_memory=True
    )
    olmoasr_model.to(device).eval()

    return hf_model, processor, olmoasr_model


print("Initializing models...")
hf_model, processor, olmoasr_model = initialize_models()
print("Models initialized successfully!")

normalizer = EnglishTextNormalizer()
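
# Whisper's EnglishTextNormalizer is instantiated for optional text clean-up;
# the demo below displays the raw transcription and does not apply it.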

def stereo_to_mono(waveform):
    """Downmix a (channels, samples) array to mono; pass 1-D audio through unchanged."""
    if waveform.ndim == 2:
        # Average the channels returned by librosa.load(..., mono=False).
        mono_waveform = np.mean(waveform, axis=0)
        return mono_waveform
    else:
        return waveform
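
# Three transcription strategies are exposed in the UI:
#   * hf_chunk_transcribe    - transformers pipeline with 30 s chunking
#   * hf_seq_transcribe      - transformers pipeline, sequential long-form decoding
#   * olmoasr_seq_transcribe - native OLMoASR transcribe() with beam search
# Each returns updated HTML for the timestamp and transcription panels.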

def hf_chunk_transcribe(audio_file, timestamp_text, transcription_text):
    """Transcribe with the transformers pipeline using 30-second chunking."""
    hf_transcriber = pipeline(
        "automatic-speech-recognition",
        model=hf_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        chunk_length_s=30,
    )

    waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
    waveform = stereo_to_mono(waveform)
    print(waveform.shape)

    if sample_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)

    result = hf_transcriber(waveform, return_timestamps=True)
    print(f"{result['text']=}\n")
    print(f"{result['chunks']=}\n")

    chunks, text = hf_process_chunks(result["chunks"])
    print(f"{chunks=}\n")
    print(f"{text=}\n")

    # Write the results into the HTML panels by id so the page layout is preserved.
    transSoup = BeautifulSoup(transcription_text, "html.parser")
    transText = transSoup.find(id="transcriptionText")
    if transText:
        transText.clear()
        transText.append(BeautifulSoup(text, "html.parser"))

    timeSoup = BeautifulSoup(timestamp_text, "html.parser")
    timeText = timeSoup.find(id="timestampText")
    if timeText:
        timeText.clear()
        timeText.append(BeautifulSoup(chunks, "html.parser"))

    return str(timeSoup), str(transSoup)

def olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text):
    """Transcribe sequentially with the native OLMoASR model."""
    waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
    waveform = stereo_to_mono(waveform)
    print(waveform.shape)

    if sample_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)

    options = dict(
        task="transcribe",
        language="en",
        without_timestamps=False,
        beam_size=5,
        best_of=5,
    )
    result = olmoasr_model.transcribe(waveform, verbose=False, **options)
    print(f"{result['text']=}\n")
    print(f"{result['segments']=}\n")

    chunks, text = olmoasr_process_chunks(result["segments"])
    print(f"{chunks=}\n")
    print(f"{text=}\n")

    transSoup = BeautifulSoup(transcription_text, "html.parser")
    transText = transSoup.find(id="transcriptionText")
    if transText:
        transText.clear()
        transText.append(BeautifulSoup(text, "html.parser"))

    timeSoup = BeautifulSoup(timestamp_text, "html.parser")
    timeText = timeSoup.find(id="timestampText")
    if timeText:
        timeText.clear()
        timeText.append(BeautifulSoup(chunks, "html.parser"))

    return str(timeSoup), str(transSoup)

def hf_seq_transcribe(audio_file, timestamp_text, transcription_text):
    """Transcribe with the transformers pipeline without chunking (sequential decoding)."""
    hf_transcriber = pipeline(
        "automatic-speech-recognition",
        model=hf_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
    waveform = stereo_to_mono(waveform)
    print(waveform.shape)

    if sample_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)

    result = hf_transcriber(
        waveform,
        return_timestamps=True,
    )
    print(f"{result['text']=}\n")
    print(f"{result['chunks']=}\n")

    chunks, text = hf_seq_process_chunks(result["chunks"])
    print(f"{text=}\n")
    print(f"{chunks=}\n")

    transSoup = BeautifulSoup(transcription_text, "html.parser")
    transText = transSoup.find(id="transcriptionText")
    if transText:
        transText.clear()
        transText.append(BeautifulSoup(text, "html.parser"))

    timeSoup = BeautifulSoup(timestamp_text, "html.parser")
    timeText = timeSoup.find(id="timestampText")
    if timeText:
        timeText.clear()
        timeText.append(BeautifulSoup(chunks, "html.parser"))

    return str(timeSoup), str(transSoup)

def main_transcribe(inference_strategy, audio_file, timestamp_text, transcription_text):
    """Dispatch to the transcription function selected in the dropdown."""
    if inference_strategy == "HuggingFace Chunking":
        return hf_chunk_transcribe(audio_file, timestamp_text, transcription_text)
    elif inference_strategy == "OLMoASR Sequential":
        return olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text)
    elif inference_strategy == "HuggingFace Sequential":
        return hf_seq_transcribe(audio_file, timestamp_text, transcription_text)
    # Unknown strategy: leave both panels unchanged.
    return timestamp_text, transcription_text
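
# Chunk post-processing: each helper below drops empty segments, stops at common
# hallucination fillers ("foreign", "you", ...), and renders the remaining
# segments as "start --> end: text <br>" lines plus a plain-text transcript.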

def olmoasr_process_chunks(chunks):
    processed_chunks = []
    processed_chunks_text = []
    for chunk in chunks:
        text = chunk["text"].strip()
        if not re.match(
            r"\s*(foreign you|foreign|Thank you for watching!|you there|you)\s*$", text
        ):
            if text.strip() == "":
                continue
            start = chunk["start"]
            end = chunk["end"]
            # Join any line breaks that are not followed by a timestamp marker.
            pattern = r"\n(?!\d+\.\d+\s*-->)"
            text = re.sub(pattern, "", text)
            processed_chunks_text.append(text.strip())
            processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text} <br>")
        else:
            # Stop at the first filler-only segment; everything after it is dropped.
            break
    print(f"{processed_chunks=}\n")
    print(f"{processed_chunks_text=}\n")
    print(
        re.search(r"\s*foreign\s*$", processed_chunks_text[-1])
        if processed_chunks_text
        else None
    )

    # Strip a trailing "foreign" hallucination from the last kept segment.
    if processed_chunks_text and re.search(
        r"\s*foreign\s*$", processed_chunks_text[-1]
    ):
        processed_chunks_text[-1] = re.sub(
            r"\s*foreign\s*$", "", processed_chunks_text[-1]
        )
        processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
    return "\n".join(processed_chunks), " ".join(processed_chunks_text)

def hf_process_chunks(chunks):
    processed_chunks = []
    processed_chunks_text = []
    for chunk in chunks:
        text = chunk["text"].strip()
        if not re.match(r"(foreign you|foreign|you there|you)\s*$", text):
            if text.strip() == "":
                continue
            start = chunk["timestamp"][0]
            end = chunk["timestamp"][1]
            pattern = r"\n(?!\d+\.\d+\s*-->)"
            text = re.sub(pattern, "", text)
            processed_chunks_text.append(text.strip())
            processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()} <br>")
        else:
            break
    print(f"{processed_chunks=}\n")
    print(f"{processed_chunks_text=}\n")
    print(
        re.search(r"\s*foreign\s*$", processed_chunks_text[-1])
        if processed_chunks_text
        else None
    )

    if processed_chunks_text and re.search(
        r"\s*foreign\s*$", processed_chunks_text[-1]
    ):
        processed_chunks_text[-1] = re.sub(
            r"\s*foreign\s*$", "", processed_chunks_text[-1]
        )
        processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
    return "\n".join(processed_chunks), " ".join(processed_chunks_text)

def hf_seq_process_chunks(chunks):
    if not chunks:
        return "", ""
    processed_chunks = []
    processed_chunks_text = []
    # The pipeline's timestamps can reset partway through long audio, so track a
    # running offset and shift later segments onto a global timeline.
    delta_time = 0.0
    global_start = chunks[0]["timestamp"][0]
    prev_end = -1.0
    prev_dur = 0.0
    accumulate_ts = False
    for chunk in chunks:
        text = chunk["text"].strip()
        if not re.match(r"\s*(foreign you|foreign|you there|you)\s*$", text):
            if text.strip() == "":
                continue
            start = chunk["timestamp"][0]
            if start < prev_end:
                # Timestamps went backwards: start accumulating an offset.
                accumulate_ts = True
            end = chunk["timestamp"][1]
            if start < prev_end:
                prev_dur += delta_time

            delta_time = end - global_start
            prev_end = end

            if accumulate_ts:
                start += prev_dur
            if accumulate_ts:
                end += prev_dur

            pattern = r"\n(?!\d+\.\d+\s*-->)"
            text = re.sub(pattern, "", text)
            processed_chunks_text.append(text.strip())
            processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()} <br>")
        else:
            break
    print(f"{processed_chunks=}\n")
    print(f"{processed_chunks_text=}\n")
    print(
        re.search(r"\s*foreign\s*$", processed_chunks_text[-1])
        if processed_chunks_text
        else None
    )

    if processed_chunks_text and re.search(
        r"\s*foreign\s*$", processed_chunks_text[-1]
    ):
        processed_chunks_text[-1] = re.sub(
            r"\s*foreign\s*$", "", processed_chunks_text[-1]
        )
        processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
    return "\n".join(processed_chunks), " ".join(processed_chunks_text)
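
# Static HTML for the two result panels. The inner divs carry the ids
# (timestampText / transcriptionText) that the transcribe functions and the
# injected JavaScript below look up when updating or highlighting text.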

original_timestamp_html = """
<div style="background: white; border: 1px solid #d1d5db; border-radius: 8px; padding: 16px; width: 100%; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); flex: 1; margin-right: 10px;">
<div style="color: #374151; font-size: 14px; font-weight: 500; margin-bottom: 8px;">Timestamp Text</div>
<div id="timestampText" style="color: #6b7280; font-size: 14px; line-height: 1.5; min-height: 100px; font-family: system-ui, sans-serif;"></div>
</div>
"""

original_transcription_html = """
<div style="background: white; border: 1px solid #d1d5db; border-radius: 8px; padding: 16px; width: 100%; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); flex: 1; margin-right: 10px;">
<div style="color: #374151; font-size: 14px; font-weight: 500; margin-bottom: 8px;">Transcription Text</div>
<div id="transcriptionText" style="color: #6b7280; font-size: 14px; line-height: 1.5; min-height: 100px; font-family: system-ui, sans-serif;"></div>
</div>
"""

def reset():
    """Restore both HTML panels to their empty initial state."""
    return original_timestamp_html, original_transcription_html
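
# Injected into the page <head>. Every 50 ms the script reads the elapsed time
# from the audio player (it relies on the element with id "time" that the player
# renders), finds the matching "start --> end" line in the timestamp panel, and
# highlights it; when the player disappears it clears both panels.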

event_process_js = """
<script>
function getTime() {
    let lastIndex = -1;
    setInterval(() => {
        const time = document.getElementById('time');
        const timestampText = document.getElementById('timestampText');
        if (timestampText && timestampText.innerText != '') {
            if (time == null) {
                timestampText.innerText = '';
                const transcriptionText = document.getElementById('transcriptionText');
                if (transcriptionText) {
                    transcriptionText.innerText = '';
                }
                lastIndex = -1;
                return;
            }
            const timeContent = time.textContent;
            const parts = timeContent.split(":").map(Number);
            const currTime = parts[0] * 60 + parts[1];
            const currText = timestampText.innerText;
            const matches = [...currText.matchAll(/([\\d.]+)\\s*-->/g)];
            const startTimestamps = matches.map(m => parseFloat(m[1]));

            if (startTimestamps.length != 0) {
                let correctIndex = 0;
                for (let i = 0; i < startTimestamps.length; i++) {
                    if (startTimestamps[i] <= currTime) {
                        correctIndex = i;
                    } else {
                        break;
                    }
                }
                if (lastIndex != correctIndex) {
                    lastIndex = correctIndex;
                    const lines = currText.split('\\n');
                    lines[correctIndex] = '<span style="background-color: #ff69b4; padding: 3px 8px; font-weight: 500; border-radius: 4px; color: white; box-shadow: 0 0 10px rgba(255, 105, 180, 0.5);">' + lines[correctIndex] + '</span>';
                    try {
                        timestampText.innerHTML = lines.join('<br>');
                    } catch (e) {
                        console.log('Not Updating!');
                    }
                }
            }
        } else {
            lastIndex = -1;
        }
    }, 50);
}
setTimeout(getTime, 1000);
</script>
"""

demo = gr.Blocks(
    head=event_process_js,
    theme=gr.themes.Default(primary_hue="emerald", secondary_hue="green"),
)
with demo:
    audio = gr.Audio(sources=["upload", "microphone"], type="filepath")
    inf_strategy = gr.Dropdown(
        label="Inference Strategy",
        choices=[
            "HuggingFace Chunking",
            "HuggingFace Sequential",
            "OLMoASR Sequential",
        ],
        value="HuggingFace Chunking",
        multiselect=False,
        info="Select the inference strategy for transcription.",
        elem_id="inf_strategy",
    )
    main_transcribe_button = gr.Button(
        "Transcribe",
        variant="primary",
    )
    with gr.Row():
        timestampText = gr.HTML(original_timestamp_html)
        transcriptionText = gr.HTML(original_transcription_html)
    inf_strategy.change(
        fn=reset,
        inputs=[],
        outputs=[timestampText, transcriptionText],
    )
    main_transcribe_button.click(
        fn=main_transcribe,
        inputs=[inf_strategy, audio, timestampText, transcriptionText],
        outputs=[timestampText, transcriptionText],
    )

demo.launch(share=True)