# Hugging Face file-view header captured with this source (not code):
# author: litagin
# commit 22efcb6 — "Add prj-beatrice/japanese-hubert-base-phoneme-ctc-v2"
# file size: 2.69 kB
from pathlib import Path
import gradio as gr
import librosa
import numpy as np
import spaces
import torch
from loguru import logger
from transformers import AutoModelForCTC, AutoProcessor
# Hugging Face Hub repos of the Japanese CTC models being compared. This order
# is also the order of transcribe()'s results and of the output textboxes.
REPO_ID_LIST = [
    "AndrewMcDowell/wav2vec2-xls-r-1b-japanese-hiragana-katakana",
    "natsuo/ja_hiragana",
    "prj-beatrice/japanese-hubert-base-phoneme-ctc",
    "prj-beatrice/japanese-hubert-base-phoneme-ctc-v2",
    "slplab/wav2vec2-xls-r-300m-japanese-hiragana",
    "snu-nia-12/wav2vec2-xls-r-300m_nia12_phone-hiragana_japanese",
    "thunninoi/wav2vec2-japanese-hiragana-vtuber",
    "vumichien/wav2vec2-large-xlsr-japanese-hiragana",
]

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load every (model, processor) pair once at import time, keyed by repo id,
# so that each request only has to run inference.
logger.info("Loading models...")
MODEL_DICT = {}
for repo_id in REPO_ID_LIST:
    MODEL_DICT[repo_id] = (
        AutoModelForCTC.from_pretrained(repo_id).to(DEVICE),
        AutoProcessor.from_pretrained(repo_id),
    )
    logger.info(f"Loaded model from {repo_id}")
logger.success("All models loaded successfully")
@spaces.GPU
@torch.inference_mode()
def transcribe(audio_path: Path) -> list[str]:
    """Transcribe one audio clip with every model in ``REPO_ID_LIST``.

    Args:
        audio_path: Path of the uploaded audio file. ``gr.Audio(type="filepath")``
            supplies it, and it is ``None`` when nothing has been uploaded.

    Returns:
        One decoded transcription per model, in ``REPO_ID_LIST`` order, so the
        list lines up with the output textboxes.

    Raises:
        gr.Error: If no audio was provided or the clip is longer than 30 seconds.
    """
    # Clicking "Transcribe" before uploading anything yields None; report it
    # to the user instead of letting librosa fail with an opaque traceback.
    if audio_path is None:
        raise gr.Error("No audio provided")

    duration = librosa.get_duration(path=audio_path)
    logger.info(f"audio: {Path(audio_path).name}, duration: {duration:.2f} seconds")
    if duration > 30:
        raise gr.Error("Audio duration exceeds 30 seconds")

    # Resample to 16 kHz; the same rate is passed to every processor below.
    y, sr = librosa.load(audio_path, sr=16000)
    # Pad with 1 s of leading and 0.5 s of trailing silence. The zeros use the
    # waveform's dtype so the concatenation does not upcast it to float64.
    y = np.concatenate(
        [np.zeros(sr, dtype=y.dtype), y, np.zeros(sr // 2, dtype=y.dtype)]
    )

    results = []
    for repo_id in REPO_ID_LIST:
        logger.info(f"Transcribing with model: {repo_id}")
        model, processor = MODEL_DICT[repo_id]
        inputs = processor(y, sampling_rate=sr, return_tensors="pt").to(DEVICE)
        # No extra torch.no_grad() is needed: the @torch.inference_mode()
        # decorator already disables gradient tracking for the whole call.
        outputs = model(**inputs)
        # Greedy decoding: take the most likely token id per frame, then let
        # the processor convert the id sequence of the single batch item to text.
        predicted_ids = outputs.logits.argmax(-1)
        phonemes = processor.decode(
            predicted_ids[0], spaces_between_special_tokens=True
        )
        logger.info(f"Transcription for {repo_id}: {phonemes}")
        results.append(phonemes)
    return results
# Page header: title, a one-line description, then a link to each model card.
md = (
    "\n"
    "# Japanese Phoneme ASR CTC Comparison\n"
    "日本語の音素単位のCTC ASRモデルの比較デモ。\n"
) + "".join(
    f"- [{repo_id}](https://huggingface.co/{repo_id})\n" for repo_id in REPO_ID_LIST
)

with gr.Blocks() as demo:
    gr.Markdown(md)
    audio_input = gr.Audio(label="Audio Input", type="filepath")
    transcribe_button = gr.Button("Transcribe")
    # One textbox per model, in REPO_ID_LIST order, matching the order of the
    # list that transcribe() returns.
    with gr.Column():
        outputs = [gr.Textbox(label=repo_id) for repo_id in REPO_ID_LIST]
    transcribe_button.click(transcribe, inputs=[audio_input], outputs=outputs)

demo.launch()