import gradio as gr from gradio_rich_textbox import RichTextbox import torchaudio import re import librosa import torch import numpy as np from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline from whisper.normalizers import EnglishTextNormalizer from whisper import audio, DecodingOptions from whisper.tokenizer import get_tokenizer from whisper.decoding import detect_language from olmoasr import load_model from bs4 import BeautifulSoup device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 hf_model_path = "checkpoints/medium_hf_demo" olmoasr_ckpt = ( "checkpoints/eval_latesttrain_00524288_medium_fsdp-train_grad-acc_bfloat16_inf.pt" ) hf_model = AutoModelForSpeechSeq2Seq.from_pretrained( hf_model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True ) hf_model.to(device).eval() processor = AutoProcessor.from_pretrained(hf_model_path) olmoasr_model = load_model( name=olmoasr_ckpt, device=device, inference=True, in_memory=True ) olmoasr_model.to(device).eval() normalizer = EnglishTextNormalizer() def stereo_to_mono(waveform): # Check if the waveform is stereo if waveform.shape[0] == 2: # Average the two channels to convert to mono mono_waveform = np.mean(waveform, axis=0) return mono_waveform else: # If already mono, return as is return waveform def hf_chunk_transcribe(audio_file, timestamp_text, transcription_text): hf_transcriber = pipeline( "automatic-speech-recognition", model=hf_model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, torch_dtype=torch_dtype, device=device, chunk_length_s=30, ) waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False) waveform = stereo_to_mono(waveform) print(waveform.shape) if sample_rate != 16000: waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000) result = hf_transcriber(waveform, return_timestamps=True) print(f"{result['text']=}\n") print(f"{result['chunks']=}\n") # text = result["text"].strip().replace("\n", " ") # text = re.sub(r"(foreign|foreign you|you)\s*$", "", text) chunks, text = hf_process_chunks(result["chunks"]) print(f"{chunks=}\n") print(f"{text=}\n") # Edit components transSoup = BeautifulSoup(transcription_text, "html.parser") transText = transSoup.find(id="transcriptionText") if transText: transText.clear() transText.append(BeautifulSoup(text, "html.parser")) timeSoup = BeautifulSoup(timestamp_text, "html.parser") timeText = timeSoup.find(id="timestampText") if timeText: timeText.clear() timeText.append(BeautifulSoup(chunks, "html.parser")) return str(timeSoup), str(transSoup) def olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text): waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False) waveform = stereo_to_mono(waveform) print(waveform.shape) if sample_rate != 16000: waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000) options = dict( task="transcribe", language="en", without_timestamps=False, beam_size=5, best_of=5, ) result = olmoasr_model.transcribe(waveform, verbose=False, **options) print(f"{result['text']=}\n") print(f"{result['segments']=}\n") # text = result["text"].strip().replace("\n", " ") # text = re.sub(r"(foreign|foreign you|Thank you for watching!|. you)\s*$", "", text) chunks, text = olmoasr_process_chunks(result["segments"]) print(f"{chunks=}\n") print(f"{text=}\n") # Edit components transSoup = BeautifulSoup(transcription_text, "html.parser") transText = transSoup.find(id="transcriptionText") if transText: transText.clear() transText.append(BeautifulSoup(text, "html.parser")) timeSoup = BeautifulSoup(timestamp_text, "html.parser") timeText = timeSoup.find(id="timestampText") if timeText: timeText.clear() timeText.append(BeautifulSoup(chunks, "html.parser")) return str(timeSoup), str(transSoup) def hf_seq_transcribe(audio_file, timestamp_text, transcription_text): hf_transcriber = pipeline( "automatic-speech-recognition", model=hf_model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, torch_dtype=torch_dtype, device=device, ) waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False) waveform = stereo_to_mono(waveform) print(waveform.shape) if sample_rate != 16000: waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000) result = hf_transcriber( waveform, return_timestamps=True, ) print(f"{result['text']=}\n") print(f"{result['chunks']=}\n") # text = result["text"].strip().replace("\n", " ") # text = re.sub(r"(foreign|foreign you|you)\s*$", "", text) chunks, text = hf_seq_process_chunks(result["chunks"]) print(f"{text=}\n") print(f"{chunks=}\n") # Edit components transSoup = BeautifulSoup(transcription_text, "html.parser") transText = transSoup.find(id="transcriptionText") if transText: transText.clear() transText.append(BeautifulSoup(text, "html.parser")) timeSoup = BeautifulSoup(timestamp_text, "html.parser") timeText = timeSoup.find(id="timestampText") if timeText: timeText.clear() timeText.append(BeautifulSoup(chunks, "html.parser")) return str(timeSoup), str(transSoup) def main_transcribe(inference_strategy, audio_file, timestamp_text, transcription_text): if inference_strategy == "HuggingFace Chunking": return hf_chunk_transcribe(audio_file, timestamp_text, transcription_text) elif inference_strategy == "OLMoASR Sequential": return olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text) elif inference_strategy == "HuggingFace Sequential": return hf_seq_transcribe(audio_file, timestamp_text, transcription_text) def olmoasr_process_chunks(chunks): processed_chunks = [] processed_chunks_text = [] for chunk in chunks: text = chunk["text"].strip() if not re.match( r"\s*(foreign you|foreign|Thank you for watching!|you there|you)\s*$", text ): if text.strip() == "": continue start = chunk["start"] end = chunk["end"] pattern = r"\n(?!\d+\.\d+\s*-->)" text = re.sub(pattern, "", text) processed_chunks_text.append(text.strip()) processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text}
") else: break print(f"{processed_chunks=}\n") print(f"{processed_chunks_text=}\n") print( re.search(r"\s*foreign\s*$", processed_chunks_text[-1]) if processed_chunks_text else None ) if processed_chunks_text and re.search( r"\s*foreign\s*$", processed_chunks_text[-1] ): processed_chunks_text[-1] = re.sub( r"\s*foreign\s*$", "", processed_chunks_text[-1] ) processed_chunks[-1] = re.sub(r"foreign\s*
", "
", processed_chunks[-1]) return "\n".join(processed_chunks), " ".join(processed_chunks_text) def hf_process_chunks(chunks): processed_chunks = [] processed_chunks_text = [] for chunk in chunks: text = chunk["text"].strip() if not re.match(r"(foreign you|foreign|you there|you)\s*$", text): if text.strip() == "": continue start = chunk["timestamp"][0] end = chunk["timestamp"][1] pattern = r"\n(?!\d+\.\d+\s*-->)" text = re.sub(pattern, "", text) processed_chunks_text.append(text.strip()) processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()}
") else: break print(f"{processed_chunks=}\n") print(f"{processed_chunks_text=}\n") print( re.search(r"\s*foreign\s*$", processed_chunks_text[-1]) if processed_chunks_text else None ) if processed_chunks_text and re.search( r"\s*foreign\s*$", processed_chunks_text[-1] ): processed_chunks_text[-1] = re.sub( r"\s*foreign\s*$", "", processed_chunks_text[-1] ) processed_chunks[-1] = re.sub(r"foreign\s*
", "
", processed_chunks[-1]) return "\n".join(processed_chunks), " ".join(processed_chunks_text) def hf_seq_process_chunks(chunks): processed_chunks = [] processed_chunks_text = [] delta_time = 0.0 global_start = chunks[0]["timestamp"][0] prev_end = -1.0 prev_dur = 0.0 accumulate_ts = False for chunk in chunks: text = chunk["text"].strip() if not re.match(r"\s*(foreign you|foreign|you there|you)\s*$", text): if text.strip() == "": continue start = chunk["timestamp"][0] if start < prev_end: accumulate_ts = True end = chunk["timestamp"][1] if start < prev_end: prev_dur += delta_time # print(f"{prev_dur=}") delta_time = end - global_start # print(f"{delta_time=}") prev_end = end # print(f"{prev_end=}") if accumulate_ts: start += prev_dur if accumulate_ts: end += prev_dur # print(f"{start=}, {end=}, {prev_dur=}") pattern = r"\n(?!\d+\.\d+\s*-->)" text = re.sub(pattern, "", text) processed_chunks_text.append(text.strip()) processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()}
") else: break print(f"{processed_chunks=}\n") print(f"{processed_chunks_text=}\n") print( re.search(r"\s*foreign\s*$", processed_chunks_text[-1]) if processed_chunks_text else None ) if processed_chunks_text and re.search( r"\s*foreign\s*$", processed_chunks_text[-1] ): processed_chunks_text[-1] = re.sub( r"\s*foreign\s*$", "", processed_chunks_text[-1] ) processed_chunks[-1] = re.sub(r"foreign\s*
", "
", processed_chunks[-1]) return "\n".join(processed_chunks), " ".join(processed_chunks_text) original_timestamp_html = """
Timestamp Text
""" original_transcription_html = """
Transcription Text
""" def reset(): return original_timestamp_html, original_transcription_html event_process_js = """ """ demo = gr.Blocks( head=event_process_js, theme=gr.themes.Default(primary_hue="emerald", secondary_hue="green"), ) with demo: audio = gr.Audio(sources=["upload", "microphone"], type="filepath") inf_strategy = gr.Dropdown( label="Inference Strategy", choices=[ "HuggingFace Chunking", "HuggingFace Sequential", "OLMoASR Sequential", ], value="HuggingFace Chunking", multiselect=False, info="Select the inference strategy for transcription.", elem_id="inf_strategy", ) main_transcribe_button = gr.Button( "Transcribe", variant="primary", ) with gr.Row(): timestampText = gr.HTML(original_timestamp_html) transcriptionText = gr.HTML(original_transcription_html) inf_strategy.change( fn=reset, inputs=[], outputs=[timestampText, transcriptionText], ) main_transcribe_button.click( fn=main_transcribe, inputs=[inf_strategy, audio, timestampText, transcriptionText], outputs=[timestampText, transcriptionText], ) demo.launch(share=True)