# Source: litagin's Hugging Face repository, commit 5f43bee ("Use float16 for gpu"), 4.44 kB.
import os
import tempfile
import time
import warnings
from pathlib import Path

import gradio as gr
import huggingface_hub
import librosa
import spaces
import torch
from loguru import logger
from transformers import pipeline
warnings.filterwarnings("ignore")

# Authenticate with the Hugging Face Hub only when a token is configured.
# huggingface_hub.login() without a token falls back to an interactive prompt,
# which would block or fail on a headless server; without a token, model
# downloads proceed anonymously instead (enough for public repos).
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    huggingface_hub.login(token=_hf_token)
else:
    logger.warning("HF_TOKEN is not set; continuing without Hugging Face login")

# True when the SYSTEM environment variable equals "spaces"
# (presumably set by the Hugging Face Spaces runtime — confirm).
is_hf = os.getenv("SYSTEM") == "spaces"
# Decoding options forwarded to every pipeline call: decode as Japanese,
# greedy search (no sampling, a single beam), no n-gram repetition
# constraint (0 disables it), and at most 64 newly generated tokens.
generate_kwargs = dict(
    language="Japanese",
    do_sample=False,
    num_beams=1,
    no_repeat_ngram_size=0,
    max_new_tokens=64,
)

# Display name used throughout the app -> Hugging Face Hub repo id.
model_dict = {
    # Comparison baseline 1
    "whisper-large-v3-turbo": "openai/whisper-large-v3-turbo",
    # Comparison baseline 2
    "kotoba-whisper-v2.0": "kotoba-tech/kotoba-whisper-v2.0",
    # Model showcased by this demo
    "anime-whisper": "litagin/anime-whisper",
}
logger.info("Initializing pipelines...")

# Decide device and precision once rather than per model: float16 when CUDA
# is available, float32 on CPU (where half precision is typically slow or
# unsupported).
_use_cuda = torch.cuda.is_available()
_device = "cuda" if _use_cuda else "cpu"
_torch_dtype = torch.float16 if _use_cuda else torch.float32
logger.info(f"Device: {_device}, dtype: {_torch_dtype}")

# One ASR pipeline per entry in model_dict, keyed by the same display name.
pipe_dict = {
    name: pipeline(
        "automatic-speech-recognition",
        model=repo_id,
        device=_device,
        torch_dtype=_torch_dtype,
    )
    for name, repo_id in model_dict.items()
}
logger.success("Pipelines initialized!")
SAMPLE_RATE = 16_000  # Hz; every input is resampled to this rate before inference
MAX_DURATION_SEC = 15  # demo limit on input length (also stated in the page text)


def _load_audio(path: str, sr: int = SAMPLE_RATE):
    """Decode *path* into a mono waveform resampled to *sr* Hz.

    librosa is tried first. If it cannot decode the file, pydub converts it
    to WAV in a private temporary file, which is loaded with librosa and then
    deleted. A unique temp file (rather than a fixed ``temp.wav`` in the
    working directory) keeps concurrent requests from overwriting each
    other's audio and leaves nothing behind.

    Returns:
        ``(waveform, sample_rate)`` exactly as returned by ``librosa.load``.

    Raises:
        Whatever pydub or librosa raise if the fallback conversion also fails.
    """
    try:
        return librosa.load(path, mono=True, sr=sr)
    except Exception as e:
        logger.warning(f"Error reading file with librosa, converting via pydub: {e}")

    from pydub import AudioSegment  # imported lazily: only needed for this fallback

    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # pydub opens the path itself
    try:
        # export() returns the file object it wrote to; close it before reading.
        AudioSegment.from_file(path).export(wav_path, format="wav").close()
        return librosa.load(wav_path, mono=True, sr=sr)
    finally:
        os.remove(wav_path)


@spaces.GPU
def transcribe_common(audio: str, model: str) -> str:
    """Transcribe a short audio file with the pipeline registered under *model*.

    Decorated with ``spaces.GPU`` (presumably so a GPU is attached for the
    duration of the call on Hugging Face ZeroGPU Spaces — confirm).

    Args:
        audio: Path to the uploaded audio file; empty or None when nothing
            was uploaded.
        model: Key into ``pipe_dict`` selecting the ASR pipeline.

    Returns:
        The transcribed text, or a human-readable message when no audio was
        given or the audio is longer than ``MAX_DURATION_SEC`` seconds.

    Raises:
        KeyError: if *model* is not a key of ``pipe_dict``.
    """
    if not audio:
        return "No audio file"

    logger.info(f"Model: {model}")
    logger.info(f"Audio: {Path(audio).name}")

    y, sr = _load_audio(audio)

    duration = librosa.get_duration(y=y, sr=sr)
    logger.info(f"Duration: {duration:.2f}s")
    if duration > MAX_DURATION_SEC:
        # A user input problem, not a server error: report it back to the UI.
        message = f"Audio too long, limit is {MAX_DURATION_SEC} seconds, got {duration:.2f}s"
        logger.warning(message)
        return message

    start = time.perf_counter()  # monotonic clock, suited to measuring elapsed time
    result = pipe_dict[model](y, generate_kwargs=generate_kwargs)["text"]
    elapsed = time.perf_counter() - start
    logger.success(f"Finished in {elapsed:.2f}s\n{result}")
    return result
def transcribe_others(audio) -> tuple[str, str]:
    """Transcribe *audio* with both comparison baselines.

    Returns:
        ``(whisper-large-v3-turbo text, kotoba-whisper-v2.0 text)``, in the
        order of the Gradio output boxes they feed.
    """
    baselines = ("whisper-large-v3-turbo", "kotoba-whisper-v2.0")
    turbo_text, kotoba_text = (transcribe_common(audio, name) for name in baselines)
    return turbo_text, kotoba_text
def transcribe_anime_whisper(audio) -> str:
    """Transcribe *audio* with the fine-tuned anime-whisper model."""
    model_key = "anime-whisper"
    return transcribe_common(audio, model_key)
# Markdown header shown at the top of the demo page (in Japanese). It restates
# the 15-second limit enforced in transcribe_common and copies the
# generate_kwargs literal verbatim, so update it whenever either changes.
# Blank lines separate the heading, paragraphs, list and code fence so each
# renders as its own Markdown block.
initial_md = """
# Anime-Whisper Demo

[**Anime Whisper**](https://huggingface.co/litagin/anime-whisper): 5千時間以上のアニメ調セリフと台本でファインチューニングされた音声認識モデルです。

- ベースモデル: [kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0)
- デモでは**音声は15秒まで**しか受け付けません
- 日本語のみ対応 (Japanese only)
- 比較のために [openai/whisper-large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) と [kotoba-tech/kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0) も用意しています

pipeに渡しているkwargsは以下:

```python
generate_kwargs = {
    "language": "Japanese",
    "do_sample": False,
    "num_beams": 1,
    "no_repeat_ngram_size": 0,
    "max_new_tokens": 64, # 結果が長いときは途中で打ち切る
}
```
"""
with gr.Blocks() as app:
    gr.Markdown(initial_md)
    audio = gr.Audio(type="filepath")

    # Showcased model: its own button and result box.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Anime-Whisper")
            anime_button = gr.Button("Transcribe with Anime-Whisper")
            anime_output = gr.Textbox(label="Result")

    # Baselines: one button runs both, results shown side by side.
    gr.Markdown("### Comparison")
    compare_button = gr.Button("Transcribe with other models")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Whisper-Large-V3-Turbo")
            turbo_output = gr.Textbox(label="Result")
        with gr.Column():
            gr.Markdown("### Kotoba-Whisper-V2.0")
            kotoba_output = gr.Textbox(label="Result")

    # Event wiring: every handler reads the same uploaded audio component.
    anime_button.click(
        fn=transcribe_anime_whisper,
        inputs=[audio],
        outputs=[anime_output],
    )
    compare_button.click(
        fn=transcribe_others,
        inputs=[audio],
        outputs=[turbo_output, kotoba_output],
    )

app.launch(inbrowser=True)