"""
Gradio app (ZeroGPU) that diarizes an audio file with pyannote.audio, transcribes
each speaker turn with Qwen2.5-Omni, and converts the transcript from Simplified
to Traditional Chinese with OpenCC.
"""
import os
import tempfile

import spaces
import torch
import torchaudio
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
from opencc import OpenCC
import gradio as gr
from pyannote.audio import Pipeline as DiarizationPipeline
from pydub import AudioSegment, effects

# Converter from Simplified to Traditional Chinese
cc = OpenCC("s2t")

# Available model IDs
MODEL_IDS = {
    "3B": "Qwen/Qwen2.5-Omni-3B",
    "7B": "Qwen/Qwen2.5-Omni-7B",
}

# Caches for loaded models and processors
_models = {}
_processors = {}


def get_model_and_processor(size: str):
    """Load and cache the model and processor for the given size ("3B" or "7B")."""
    if size not in _models:
        model_id = MODEL_IDS[size]
        # Load the model with device_map="auto" for ZeroGPU compatibility
        m = Qwen2_5OmniForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto",
        )
        m.disable_talker()  # text-only output; the talker (speech) head is not needed
        m.eval()
        p = Qwen2_5OmniProcessor.from_pretrained(model_id)
        _models[size] = m
        _processors[size] = p
    return _models[size], _processors[size]


# Cache the diarization pipeline so we only load it once
_diar_pipe = None


def get_diarization_pipe():
    global _diar_pipe
    if _diar_pipe is None:
        hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
        try:
            _diar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token or True,
            )
        except Exception:
            # Fall back to the older 2.1 pipeline if 3.1 cannot be loaded
            _diar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization@2.1",
                use_auth_token=hf_token or True,
            )
    return _diar_pipe


# Format a list of "[SPEAKER_X] text" snippets into colored HTML
def format_diarization_html(snippets):
    palette = [
        "#e74c3c", "#3498db", "#27ae60", "#e67e22",
        "#9b59b6", "#16a085", "#f1c40f",
    ]
    speaker_colors = {}
    html_lines = []
    last_spk = None

    for s in snippets:
        # Split "[SPEAKER_X] text" into label and text
        if s.startswith("[") and "]" in s:
            spk, txt = s[1:].split("]", 1)
            spk, txt = spk.strip(), txt.strip()
        else:
            spk, txt = "", s.strip()
        if not txt:
            continue
        # Assign each speaker a stable color from the palette
        if spk not in speaker_colors:
            speaker_colors[spk] = palette[len(speaker_colors) % len(palette)]
        color = speaker_colors[spk]
        # Repeat the speaker label only when the speaker changes
        if spk == last_spk:
            display = txt
        else:
            display = f"{spk}: {txt}"
        last_spk = spk
        html_lines.append(
            f"<div style='margin:2px 0;'>"
            f"<span style='color:{color}; font-weight:bold;'>{display}</span>"
            f"</div>"
        )

    return "<div>" + "".join(html_lines) + "</div>"


def _strip_prompts(full_text: str) -> str:
    """Remove system/user/assistant prefixes so only the actual ASR transcript remains."""
    marker = "assistant"
    if marker in full_text:
        return full_text.split(marker, 1)[1].strip()
    else:
        return full_text.strip()


@spaces.GPU
def run_asr(
    audio_path: str,
    user_prompt: str,
    model_size: str,
):
    # Validate inputs
    if not audio_path:
        yield format_diarization_html(["⚠️ Please upload an audio file first."])
        return

    # Load diarization model onto GPU/CPU
    diarizer = get_diarization_pipe()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    diarizer.to(device)

    # Load waveform + sample rate and push to device
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = waveform.to(device)

    # Get the appropriate Qwen model & processor based on selection
    model, processor = get_model_and_processor(model_size)
    model.to(device)

    # Run diarization to get speaker turns
    diary = diarizer({"waveform": waveform, "sample_rate": sample_rate})

    snippets = []

    # Decode the source audio once; each speaker turn is sliced from it below
    full_audio = AudioSegment.from_file(audio_path)

    # For each speaker turn, slice audio, transcribe, convert, accumulate
    for turn, _, speaker in diary.itertracks(yield_label=True):
        start_ms = int(turn.start * 1000)
        end_ms = int(turn.end * 1000)

        # Extract the segment, normalize, export to temp file
        segment = full_audio[start_ms:end_ms]
        segment = effects.normalize(segment)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            segment.export(tmp.name, format="wav")
            tmp_path = tmp.name

        # Build messages for this segment
        sys_prompt = (
            "You are a speech recognition model."
        )
        messages = [
            {"role": "system", "content": [{"type": "text", "text": sys_prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": tmp_path},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]

        # Apply chat template (no tokenization yet)
        text_input = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Preprocess audio (and any images/videos, though here only audio)
        audios, images, videos = process_mm_info(messages, use_audio_in_video=True)

        # Tokenize & move tensors
        inputs = processor(
            text=text_input,
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=True,
        )
        inputs = inputs.to(model.device).to(model.dtype)

        # Generate for this snippet
        output_tokens = model.generate(
            **inputs,
            use_audio_in_video=True,
            return_audio=False,
            thinker_max_new_tokens=512,
            thinker_do_sample=False,
        )

        # Decode (system + user + assistant)
        full_decoded = processor.batch_decode(
            output_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0].strip()

        # Strip prefixes to isolate the ASR transcript
        asr_text = _strip_prompts(full_decoded)

        # Convert to Traditional Chinese
        asr_text = cc.convert(asr_text)

        # Append with speaker label
        snippets.append(f"[{speaker}] {asr_text}")

        # Yield updated HTML so Gradio can stream partial results
        yield format_diarization_html(snippets)

        # Clean up temp file for this segment
        os.unlink(tmp_path)

    return


# -----------------------------
# Gradio UI
# -----------------------------
DEMO_CSS = """
.diar {
    padding: 0.5rem;
    color: #f1f1f1;
    font-family: monospace;
    font-size: 0.9rem;
}
"""

with gr.Blocks(css=DEMO_CSS) as demo:
    gr.Markdown("## Qwen2.5-Omni ASR with Speaker Diarization & S2T Conversion (ZeroGPU)")

    with gr.Row():
        audio_input = gr.Audio(
            label="Upload Audio (WAV/MP3/…)",
            type="filepath",
        )
        user_input = gr.Textbox(
            label="User Prompt",
            value="Transcribe the attached audio to text with punctuation."
        )
        model_selector = gr.Radio(
            choices=["3B", "7B"],
            value="7B",
            label="Model Size",
        )

    # Example audio files
    example_list = [
        ["audio/ads.mp3"],
        ["audio/meeting.mp3"],
        ["audio/news.mp3"],
    ]
    gr.Examples(
        examples=example_list,
        inputs=[audio_input],
        examples_per_page=3,
        label="Try one of these audio files ⤵︎",
    )

    submit_btn = gr.Button("Transcribe")
    diarized_output = gr.HTML(
        label="Speaker-Diarized Transcript (Traditional Chinese)",
        elem_classes=["diar"],
    )

    submit_btn.click(
        fn=run_asr,
        inputs=[audio_input, user_input, model_selector],
        outputs=diarized_output,
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()