from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import spaces
import torch
from loguru import logger
from transformers import AutoModelForCTC, AutoProcessor

REPO_ID_LIST = [
    "AndrewMcDowell/wav2vec2-xls-r-1b-japanese-hiragana-katakana",
    "natsuo/ja_hiragana",
    "prj-beatrice/japanese-hubert-base-phoneme-ctc",
    "prj-beatrice/japanese-hubert-base-phoneme-ctc-v2",
    "slplab/wav2vec2-xls-r-300m-japanese-hiragana",
    "snu-nia-12/wav2vec2-xls-r-300m_nia12_phone-hiragana_japanese",
    "thunninoi/wav2vec2-japanese-hiragana-vtuber",
    "vumichien/wav2vec2-large-xlsr-japanese-hiragana",
]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

logger.info("Loading models...")

# Load every model/processor pair once at startup so each request only
# pays for inference, not for loading weights.
MODEL_DICT = {}
for repo_id in REPO_ID_LIST:
    model = AutoModelForCTC.from_pretrained(repo_id).to(DEVICE)
    processor = AutoProcessor.from_pretrained(repo_id)
    MODEL_DICT[repo_id] = (model, processor)
    logger.info(f"Loaded model from {repo_id}")

logger.success("All models loaded successfully")


@spaces.GPU
@torch.inference_mode()
def transcribe(audio_path: str) -> list[str]:
    duration = librosa.get_duration(path=audio_path)
    logger.info(f"audio: {Path(audio_path).name}, duration: {duration:.2f} seconds")
    if duration > 30:
        raise gr.Error("Audio duration exceeds 30 seconds")

    # Resample to the 16 kHz rate the models were trained on, then pad
    # with 1 s of leading and 0.5 s of trailing silence.
    y, sr = librosa.load(audio_path, sr=16000)
    y = np.concatenate([np.zeros(sr, dtype=y.dtype), y, np.zeros(sr // 2, dtype=y.dtype)])

    results = []
    for repo_id in REPO_ID_LIST:
        logger.info(f"Transcribing with model: {repo_id}")
        model, processor = MODEL_DICT[repo_id]
        inputs = processor(y, sampling_rate=sr, return_tensors="pt").to(DEVICE)
        outputs = model(**inputs)
        # Greedy CTC decoding: take the argmax token per frame; decode()
        # collapses repeats and removes blanks.
        predicted_ids = outputs.logits.argmax(-1)
        phonemes = processor.decode(
            predicted_ids[0], spaces_between_special_tokens=True
        )
        logger.info(f"Transcription for {repo_id}: {phonemes}")
        results.append(phonemes)
    return results


md = """
# Japanese Phoneme ASR CTC Comparison

A demo comparing phoneme-level CTC ASR models for Japanese.

"""
for repo_id in REPO_ID_LIST:
    md += f"- [{repo_id}](https://huggingface.co/{repo_id})\n"

with gr.Blocks() as demo:
    gr.Markdown(md)
    audio_input = gr.Audio(label="Audio Input", type="filepath")
    transcribe_button = gr.Button("Transcribe")
    outputs = []
    with gr.Column():
        # One output textbox per model, in the same order as REPO_ID_LIST.
        for repo_id in REPO_ID_LIST:
            output_text = gr.Textbox(label=repo_id)
            outputs.append(output_text)
    transcribe_button.click(transcribe, inputs=[audio_input], outputs=outputs)

demo.launch()