# Install private dependencies first import install_dependencies install_dependencies.install_private_repos() import gradio as gr from gradio_rich_textbox import RichTextbox import torchaudio import re import librosa import torch import numpy as np import os import tempfile import subprocess import sys from pathlib import Path from huggingface_hub import hf_hub_download from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline from whisper.normalizers import EnglishTextNormalizer from whisper import audio, DecodingOptions from whisper.tokenizer import get_tokenizer from whisper.decoding import detect_language from olmoasr import load_model from bs4 import BeautifulSoup device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Configuration for model download and conversion OLMOASR_REPO = "allenai/OLMoASR" # Temporary model link as requested CHECKPOINT_FILENAME = "tiny.en.pt" # Adjust based on actual filename in the repo LOCAL_CHECKPOINT_DIR = "checkpoint_tiny" HF_MODEL_DIR = "tiny_hf" def ensure_checkpoint_dir(): """Ensure the checkpoint directory exists.""" Path(LOCAL_CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True) Path(HF_MODEL_DIR).mkdir(parents=True, exist_ok=True) def download_olmoasr_checkpoint(): """Download OLMoASR checkpoint from HuggingFace hub.""" ensure_checkpoint_dir() local_checkpoint_path = os.path.join(LOCAL_CHECKPOINT_DIR, CHECKPOINT_FILENAME) # Check if checkpoint already exists if os.path.exists(local_checkpoint_path): print(f"Checkpoint already exists at {local_checkpoint_path}") return local_checkpoint_path try: print(f"Downloading checkpoint from {OLMOASR_REPO}") downloaded_path = hf_hub_download( repo_id=OLMOASR_REPO, filename=CHECKPOINT_FILENAME, local_dir=LOCAL_CHECKPOINT_DIR, local_dir_use_symlinks=False, token=os.getenv("HF_TOKEN"), ) print(f"Downloaded checkpoint to {downloaded_path}") return downloaded_path except Exception as e: print(f"Error downloading checkpoint: {e}") def convert_checkpoint_to_hf(checkpoint_path): """Convert OLMoASR checkpoint to HuggingFace format using subprocess.""" if os.path.exists(os.path.join(HF_MODEL_DIR, "config.json")): print(f"HuggingFace model already exists at {HF_MODEL_DIR}") return HF_MODEL_DIR try: print(f"Converting checkpoint {checkpoint_path} to HuggingFace format") # Path to the conversion script script_path = os.path.join(os.path.dirname(__file__), "convert_openai_to_hf.py") # Run the conversion script using subprocess cmd = [ sys.executable, script_path, "--checkpoint_path", checkpoint_path, "--pytorch_dump_folder_path", HF_MODEL_DIR, "--convert_preprocessor", "True", ] print(f"Running conversion command: {' '.join(cmd)}") # Execute the conversion script result = subprocess.run(cmd, capture_output=True, text=True, check=True) print("Conversion output:") print(result.stdout) if result.stderr: print("Conversion warnings/errors:") print(result.stderr) # Verify that the conversion was successful if os.path.exists(os.path.join(HF_MODEL_DIR, "config.json")): print(f"Model successfully converted and saved to {HF_MODEL_DIR}") return HF_MODEL_DIR else: raise Exception("Conversion completed but config.json not found") except subprocess.CalledProcessError as e: print(f"Conversion script failed with return code {e.returncode}") print(f"stdout: {e.stdout}") print(f"stderr: {e.stderr}") raise e except Exception as e: print(f"Error converting checkpoint: {e}") raise e def initialize_models(): """Initialize both HuggingFace and OLMoASR models.""" # Download and convert HuggingFace model #checkpoint_path = download_olmoasr_checkpoint() #hf_model_path = convert_checkpoint_to_hf(checkpoint_path) olmoasr_ckpt = os.path.join(LOCAL_CHECKPOINT_DIR, CHECKPOINT_FILENAME) hf_model_path = HF_MODEL_DIR # Load HuggingFace model hf_model = AutoModelForSpeechSeq2Seq.from_pretrained( hf_model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, ) hf_model.to(device).eval() processor = AutoProcessor.from_pretrained(hf_model_path) # Load OLMoASR model olmoasr_model = load_model( name=olmoasr_ckpt, device=device, inference=True, in_memory=True ) olmoasr_model.to(device).eval() return hf_model, processor, olmoasr_model # Initialize models print("Initializing models...") hf_model, processor, olmoasr_model = initialize_models() print("Models initialized successfully!") normalizer = EnglishTextNormalizer() def stereo_to_mono(waveform): # Check if the waveform is stereo if waveform.shape[0] == 2: # Average the two channels to convert to mono mono_waveform = np.mean(waveform, axis=0) return mono_waveform else: # If already mono, return as is return waveform def hf_chunk_transcribe(audio_file, timestamp_text, transcription_text): hf_transcriber = pipeline( "automatic-speech-recognition", model=hf_model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, torch_dtype=torch_dtype, device=device, chunk_length_s=30, ) waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False) waveform = stereo_to_mono(waveform) print(waveform.shape) if sample_rate != 16000: waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000) result = hf_transcriber(waveform, return_timestamps=True) print(f"{result['text']=}\n") print(f"{result['chunks']=}\n") # text = result["text"].strip().replace("\n", " ") # text = re.sub(r"(foreign|foreign you|you)\s*$", "", text) chunks, text = hf_process_chunks(result["chunks"]) print(f"{chunks=}\n") print(f"{text=}\n") # Edit components transSoup = BeautifulSoup(transcription_text, "html.parser") transText = transSoup.find(id="transcriptionText") if transText: transText.clear() transText.append(BeautifulSoup(text, "html.parser")) timeSoup = BeautifulSoup(timestamp_text, "html.parser") timeText = timeSoup.find(id="timestampText") if timeText: timeText.clear() timeText.append(BeautifulSoup(chunks, "html.parser")) return str(timeSoup), str(transSoup) def olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text): waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False) waveform = stereo_to_mono(waveform) print(waveform.shape) if sample_rate != 16000: waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000) options = dict( task="transcribe", language="en", without_timestamps=False, beam_size=5, best_of=5, ) result = olmoasr_model.transcribe(waveform, verbose=False, **options) print(f"{result['text']=}\n") print(f"{result['segments']=}\n") # text = result["text"].strip().replace("\n", " ") # text = re.sub(r"(foreign|foreign you|Thank you for watching!|. you)\s*$", "", text) chunks, text = olmoasr_process_chunks(result["segments"]) print(f"{chunks=}\n") print(f"{text=}\n") # Edit components transSoup = BeautifulSoup(transcription_text, "html.parser") transText = transSoup.find(id="transcriptionText") if transText: transText.clear() transText.append(BeautifulSoup(text, "html.parser")) timeSoup = BeautifulSoup(timestamp_text, "html.parser") timeText = timeSoup.find(id="timestampText") if timeText: timeText.clear() timeText.append(BeautifulSoup(chunks, "html.parser")) return str(timeSoup), str(transSoup) def hf_seq_transcribe(audio_file, timestamp_text, transcription_text): hf_transcriber = pipeline( "automatic-speech-recognition", model=hf_model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, torch_dtype=torch_dtype, device=device, ) waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False) waveform = stereo_to_mono(waveform) print(waveform.shape) if sample_rate != 16000: waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000) result = hf_transcriber( waveform, return_timestamps=True, ) print(f"{result['text']=}\n") print(f"{result['chunks']=}\n") # text = result["text"].strip().replace("\n", " ") # text = re.sub(r"(foreign|foreign you|you)\s*$", "", text) chunks, text = hf_seq_process_chunks(result["chunks"]) print(f"{text=}\n") print(f"{chunks=}\n") # Edit components transSoup = BeautifulSoup(transcription_text, "html.parser") transText = transSoup.find(id="transcriptionText") if transText: transText.clear() transText.append(BeautifulSoup(text, "html.parser")) timeSoup = BeautifulSoup(timestamp_text, "html.parser") timeText = timeSoup.find(id="timestampText") if timeText: timeText.clear() timeText.append(BeautifulSoup(chunks, "html.parser")) return str(timeSoup), str(transSoup) def main_transcribe(inference_strategy, audio_file, timestamp_text, transcription_text): if inference_strategy == "HuggingFace Chunking": return hf_chunk_transcribe(audio_file, timestamp_text, transcription_text) elif inference_strategy == "OLMoASR Sequential": return olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text) elif inference_strategy == "HuggingFace Sequential": return hf_seq_transcribe(audio_file, timestamp_text, transcription_text) def olmoasr_process_chunks(chunks): processed_chunks = [] processed_chunks_text = [] for chunk in chunks: text = chunk["text"].strip() if not re.match( r"\s*(foreign you|foreign|Thank you for watching!|you there|you)\s*$", text ): if text.strip() == "": continue start = chunk["start"] end = chunk["end"] pattern = r"\n(?!\d+\.\d+\s*-->)" text = re.sub(pattern, "", text) processed_chunks_text.append(text.strip()) processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text}
") else: break print(f"{processed_chunks=}\n") print(f"{processed_chunks_text=}\n") print( re.search(r"\s*foreign\s*$", processed_chunks_text[-1]) if processed_chunks_text else None ) if processed_chunks_text and re.search( r"\s*foreign\s*$", processed_chunks_text[-1] ): processed_chunks_text[-1] = re.sub( r"\s*foreign\s*$", "", processed_chunks_text[-1] ) processed_chunks[-1] = re.sub(r"foreign\s*
", "
", processed_chunks[-1]) return "\n".join(processed_chunks), " ".join(processed_chunks_text) def hf_process_chunks(chunks): processed_chunks = [] processed_chunks_text = [] for chunk in chunks: text = chunk["text"].strip() if not re.match(r"(foreign you|foreign|you there|you)\s*$", text): if text.strip() == "": continue start = chunk["timestamp"][0] end = chunk["timestamp"][1] pattern = r"\n(?!\d+\.\d+\s*-->)" text = re.sub(pattern, "", text) processed_chunks_text.append(text.strip()) processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()}
") else: break print(f"{processed_chunks=}\n") print(f"{processed_chunks_text=}\n") print( re.search(r"\s*foreign\s*$", processed_chunks_text[-1]) if processed_chunks_text else None ) if processed_chunks_text and re.search( r"\s*foreign\s*$", processed_chunks_text[-1] ): processed_chunks_text[-1] = re.sub( r"\s*foreign\s*$", "", processed_chunks_text[-1] ) processed_chunks[-1] = re.sub(r"foreign\s*
", "
", processed_chunks[-1]) return "\n".join(processed_chunks), " ".join(processed_chunks_text) def hf_seq_process_chunks(chunks): processed_chunks = [] processed_chunks_text = [] delta_time = 0.0 global_start = chunks[0]["timestamp"][0] prev_end = -1.0 prev_dur = 0.0 accumulate_ts = False for chunk in chunks: text = chunk["text"].strip() if not re.match(r"\s*(foreign you|foreign|you there|you)\s*$", text): if text.strip() == "": continue start = chunk["timestamp"][0] if start < prev_end: accumulate_ts = True end = chunk["timestamp"][1] if start < prev_end: prev_dur += delta_time # print(f"{prev_dur=}") delta_time = end - global_start # print(f"{delta_time=}") prev_end = end # print(f"{prev_end=}") if accumulate_ts: start += prev_dur if accumulate_ts: end += prev_dur # print(f"{start=}, {end=}, {prev_dur=}") pattern = r"\n(?!\d+\.\d+\s*-->)" text = re.sub(pattern, "", text) processed_chunks_text.append(text.strip()) processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()}
") else: break print(f"{processed_chunks=}\n") print(f"{processed_chunks_text=}\n") print( re.search(r"\s*foreign\s*$", processed_chunks_text[-1]) if processed_chunks_text else None ) if processed_chunks_text and re.search( r"\s*foreign\s*$", processed_chunks_text[-1] ): processed_chunks_text[-1] = re.sub( r"\s*foreign\s*$", "", processed_chunks_text[-1] ) processed_chunks[-1] = re.sub(r"foreign\s*
", "
", processed_chunks[-1]) return "\n".join(processed_chunks), " ".join(processed_chunks_text) original_timestamp_html = """
Timestamp Text
""" original_transcription_html = """
Transcription Text
""" def reset(): return original_timestamp_html, original_transcription_html event_process_js = """ """ demo = gr.Blocks( head=event_process_js, theme=gr.themes.Default(primary_hue="emerald", secondary_hue="green"), ) with demo: audio = gr.Audio(sources=["upload", "microphone"], type="filepath") inf_strategy = gr.Dropdown( label="Inference Strategy", choices=[ "HuggingFace Chunking", "HuggingFace Sequential", "OLMoASR Sequential", ], value="HuggingFace Chunking", multiselect=False, info="Select the inference strategy for transcription.", elem_id="inf_strategy", ) main_transcribe_button = gr.Button( "Transcribe", variant="primary", ) with gr.Row(): timestampText = gr.HTML(original_timestamp_html) transcriptionText = gr.HTML(original_transcription_html) inf_strategy.change( fn=reset, inputs=[], outputs=[timestampText, transcriptionText], ) main_transcribe_button.click( fn=main_transcribe, inputs=[inf_strategy, audio, timestampText, transcriptionText], outputs=[timestampText, transcriptionText], ) demo.launch(share=True)