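"""Gradio app comparing OpenAI Whisper and FunASR SenseVoice ASR, with
optional pyannote speaker diarization and OpenCC Simplified-to-Traditional
Chinese conversion of the output."""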
import os
import re
import tempfile

import torch
import gradio as gr
from transformers import pipeline
from pydub import AudioSegment
from pyannote.audio import Pipeline as DiarizationPipeline
import opencc
import spaces  # ZeroGPU support
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
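
# The imports above imply roughly the following requirements.txt (package
# names are assumptions, not pinned anywhere in this file; pydub additionally
# needs ffmpeg installed on the system):
#   torch, gradio, transformers, pydub, pyannote.audio,
#   opencc-python-reimplemented (or opencc), spaces, funasr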
# —————— Model Lists ——————
WHISPER_MODELS = [
    # Base Whisper models
    "openai/whisper-large-v3-turbo",
    "openai/whisper-large-v3",
    "openai/whisper-medium",
    "openai/whisper-small",
    "openai/whisper-base",
    "openai/whisper-tiny",
    # Community fine-tuned Chinese models
    "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW",
    "Jingmiao/whisper-small-zh_tw",
    "DDTChen/whisper-medium-zh-tw",
    "kimbochen/whisper-small-zh-tw",
    # ...etc...
]
SENSEVOICE_MODELS = [
    "FunAudioLLM/SenseVoiceSmall",
    "AXERA-TECH/SenseVoice",
    "alextomcat/SenseVoiceSmall",
    "ChenChenyu/SenseVoiceSmall-finetuned",
    "apinge/sensevoice-small",
]
# —————— Language Options ——————
WHISPER_LANGUAGES = [
    "auto", "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo",
    "br", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es", "et",
    "eu", "fa", "fi", "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi",
    "hr", "ht", "hu", "hy", "id", "is", "it", "ja", "jw", "ka", "kk",
    "km", "kn", "ko", "la", "lb", "ln", "lo", "lt", "lv", "mg", "mi",
    "mk", "ml", "mn", "mr", "ms", "mt", "my", "ne", "nl", "nn", "no",
    "oc", "pa", "pl", "ps", "pt", "ro", "ru", "sa", "sd", "si", "sk",
    "sl", "sn", "so", "sq", "sr", "su", "sv", "sw", "ta", "te", "tg",
    "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi", "yi", "yo",
    "zh", "yue",
]
SENSEVOICE_LANGUAGES = ["auto", "zh", "yue", "en", "ja", "ko", "nospeech"]
# —————— Caches ——————
whisper_pipes = {}
sense_models = {}
dar_pipe = None

# OpenCC converter for Simplified-to-Traditional Chinese
converter = opencc.OpenCC('s2t.json')
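# e.g. converter.convert("汉语") -> "漢語"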
# —————— Helpers ——————
def get_whisper_pipe(model_id: str, device: int):
    key = (model_id, device)
    if key not in whisper_pipes:
        whisper_pipes[key] = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device=device,
            chunk_length_s=30,
            stride_length_s=5,
            return_timestamps=False,
        )
    return whisper_pipes[key]
def get_sense_model(model_id: str):
    if model_id not in sense_models:
        device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
        sense_models[model_id] = AutoModel(
            model=model_id,
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 300000},
            device=device_str,
            hub="hf",
        )
    return sense_models[model_id]
def get_diarization_pipe():
    global dar_pipe
    if dar_pipe is None:
        # Pull token from environment (HF_TOKEN or HUGGINGFACE_TOKEN). Both
        # pyannote pipelines are gated on the Hub, so the token's account must
        # have accepted their usage conditions.
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
        # Try the latest 3.1 pipeline, falling back to 2.1 if it fails to
        # load (e.g. on a gated-model error).
        try:
            dar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=token or True,
            )
        except Exception as e:
            print(f"Failed to load pyannote/speaker-diarization-3.1: {e}\n"
                  "Falling back to pyannote/[email protected].")
            dar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/[email protected]",
                use_auth_token=token or True,
            )
    return dar_pipe
# —————— Transcription Functions ——————
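# On a ZeroGPU Space the GPU is only attached inside functions decorated with
# @spaces.GPU (presumably why `spaces` is imported above); the decorator is a
# no-op when running elsewhere.
@spaces.GPU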
def transcribe_whisper(model_id: str,
                       language: str,
                       audio_path: str,
                       device_sel: str,
                       enable_diar: bool):
    # Select device: 0 for GPU, -1 for CPU.
    use_gpu = (device_sel == "GPU" and torch.cuda.is_available())
    device = 0 if use_gpu else -1
    pipe = get_whisper_pipe(model_id, device)

    # Full transcription.
    result = (pipe(audio_path) if language == "auto"
              else pipe(audio_path, generate_kwargs={"language": language}))
    transcript = result.get("text", "").strip()
    # Convert Simplified Chinese to Traditional.
    transcript = converter.convert(transcript)

    diar_text = ""
    # Optional speaker diarization.
    if enable_diar:
        diarizer = get_diarization_pipe()
        diary = diarizer(audio_path)
        # Load the audio once instead of re-decoding it for every turn.
        audio = AudioSegment.from_file(audio_path)
        snippets = []
        for turn, _, speaker in diary.itertracks(yield_label=True):
            start_ms, end_ms = int(turn.start * 1000), int(turn.end * 1000)
            segment = audio[start_ms:end_ms]
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                segment.export(tmp.name, format="wav")
            seg_out = (pipe(tmp.name) if language == "auto"
                       else pipe(tmp.name, generate_kwargs={"language": language}))
            os.unlink(tmp.name)
            text = seg_out.get("text", "").strip()
            # Convert Simplified Chinese to Traditional.
            text = converter.convert(text)
            snippets.append(f"[{speaker}] {text}")
        diar_text = "\n".join(snippets)

    return transcript, diar_text
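
# As above, request a GPU for the duration of the call on ZeroGPU hardware.
@spaces.GPU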
def transcribe_sense(model_id: str,
                     language: str,
                     audio_path: str,
                     enable_punct: bool,
                     enable_diar: bool):
    model = get_sense_model(model_id)

    # No diarization: transcribe the whole file in one pass.
    if not enable_diar:
        segs = model.generate(
            input=audio_path,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=300,
            merge_vad=True,
            merge_length_s=15,
        )
        text = rich_transcription_postprocess(segs[0]['text'])
        if not enable_punct:
            # Strip punctuation; \w keeps CJK characters as well as ASCII.
            text = re.sub(r"[^\w\s]", "", text)
        # Convert Simplified Chinese to Traditional.
        text = converter.convert(text)
        return text, ""

    # With diarization: transcribe each speaker turn separately.
    diarizer = get_diarization_pipe()
    diary = diarizer(audio_path)
    # Load the audio once instead of re-decoding it for every turn.
    audio = AudioSegment.from_file(audio_path)
    snippets = []
    for turn, _, speaker in diary.itertracks(yield_label=True):
        start_ms, end_ms = int(turn.start * 1000), int(turn.end * 1000)
        segment = audio[start_ms:end_ms]
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            segment.export(tmp.name, format="wav")
        segs = model.generate(
            input=tmp.name,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=300,
            merge_vad=False,
            merge_length_s=0,
        )
        os.unlink(tmp.name)
        txt = rich_transcription_postprocess(segs[0]['text'])
        if not enable_punct:
            txt = re.sub(r"[^\w\s]", "", txt)
        txt = converter.convert(txt)
        snippets.append(f"[{speaker}] {txt}")

    # Also produce a full, non-diarized transcript.
    full = rich_transcription_postprocess(model.generate(
        input=audio_path,
        cache={},
        language=language,
        use_itn=True,
        batch_size_s=300,
        merge_vad=True,
        merge_length_s=15,
    )[0]['text'])
    if not enable_punct:
        full = re.sub(r"[^\w\s]", "", full)
    full = converter.convert(full)
    return full, "\n".join(snippets)
# —————— Gradio UI ——————
demo = gr.Blocks()
with demo:
    gr.Markdown("## Whisper vs. SenseVoice (Language, Device & Diarization with Simplified→Traditional Chinese)")
    audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Whisper ASR")
            whisper_dd = gr.Dropdown(choices=WHISPER_MODELS, value=WHISPER_MODELS[0], label="Whisper Model")
            whisper_lang = gr.Dropdown(choices=WHISPER_LANGUAGES, value="auto", label="Whisper Language")
            device_radio = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
            diar_check = gr.Checkbox(label="Enable Diarization", value=True)
            btn_w = gr.Button("Transcribe with Whisper")
            out_w = gr.Textbox(label="Transcript")
            out_w_d = gr.Textbox(label="Diarized Transcript")
            btn_w.click(fn=transcribe_whisper,
                        inputs=[whisper_dd, whisper_lang, audio_input, device_radio, diar_check],
                        outputs=[out_w, out_w_d])
        with gr.Column():
            gr.Markdown("### FunASR SenseVoice ASR")
            sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
            sense_lang = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
            punct_chk = gr.Checkbox(label="Enable Punctuation", value=True)
            diar_s_chk = gr.Checkbox(label="Enable Diarization", value=True)
            btn_s = gr.Button("Transcribe with SenseVoice")
            out_s = gr.Textbox(label="Transcript")
            out_s_d = gr.Textbox(label="Diarized Transcript")
            btn_s.click(fn=transcribe_sense,
                        inputs=[sense_dd, sense_lang, audio_input, punct_chk, diar_s_chk],
                        outputs=[out_s, out_s_d])

if __name__ == "__main__":
    demo.launch()
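
# Run locally with `python app.py`; on a Hugging Face Space this file is
# launched automatically as app.py.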