# TWASR / app.py
# Source: JacobLinCool's Hugging Face Space, commit c3b43a8
# ("fix: reorder model choices and update example files to use PHI model")
import spaces
import gradio as gr
import logging
from pathlib import Path
import base64
from model import (
MODEL_ID as WHISPER_MODEL_ID,
PHI_MODEL_ID,
transcribe_audio_local,
transcribe_audio_phi,
preload_models,
)
# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)

# Constants
# Directory containing the bundled example audio clips shown in the UI.
EXAMPLES_DIR = Path("./examples")
# Model ID -> human-readable label. Dict insertion order drives the dropdown
# ordering (choices=list(MODEL_CHOICES.keys())), so the Phi model is listed
# first and matches the dropdown's default value.
MODEL_CHOICES = {
    PHI_MODEL_ID: "Phi-4 Model",
    WHISPER_MODEL_ID: "Whisper Model",
}
# [audio_path, model_id] rows fed to gr.Examples; both examples use the PHI model.
EXAMPLE_FILES = [
    [str(EXAMPLES_DIR / "audio1.mp3"), PHI_MODEL_ID],
    [str(EXAMPLES_DIR / "audio2.mp3"), PHI_MODEL_ID],
]
def read_file_as_base64(file_path: str) -> str:
    """
    Read a file and encode its contents as base64.

    Args:
        file_path: Path to the file to read

    Returns:
        Base64 encoded string of file contents

    Raises:
        OSError: If the file cannot be opened or read.
    """
    try:
        # Path.read_bytes opens, reads, and closes the file in one call.
        return base64.b64encode(Path(file_path).read_bytes()).decode()
    except Exception:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception(f"Failed to read file {file_path}")
        raise
def combined_transcription(audio: str, model_choice: str) -> str:
    """
    Transcribe audio using the selected model.

    Args:
        audio: Path to audio file
        model_choice: Full model ID to use for transcription

    Returns:
        Transcription text, or a human-readable error message on failure.
    """
    if not audio:
        return "Please provide an audio file to transcribe."

    # Dispatch table: model ID -> transcription backend.
    backends = {
        PHI_MODEL_ID: transcribe_audio_phi,
        WHISPER_MODEL_ID: transcribe_audio_local,
    }
    transcribe = backends.get(model_choice)
    if transcribe is None:
        logger.error(f"Unknown model choice: {model_choice}")
        return f"Error: Unknown model {model_choice}"

    try:
        return transcribe(audio)
    except Exception as e:
        # logger.exception keeps the traceback; the returned string stays
        # user-friendly for the Gradio textbox.
        logger.exception("Transcription failed")
        return f"Error during transcription: {str(e)}"
def create_demo() -> gr.Blocks:
    """Build and return the Gradio Blocks UI for the transcription app."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # --- Header and usage notes ---
        gr.Markdown("# TWASR: Chinese (Taiwan) Automatic Speech Recognition")
        gr.Markdown(
            "Upload an audio file or record your voice to transcribe it to text."
        )
        gr.Markdown(
            "⚠️ First load may take a while to initialize the model, following requests will be faster."
        )

        # --- Input row: audio on the left, model picker + output on the right ---
        with gr.Row():
            audio_in = gr.Audio(
                label="Audio Input", type="filepath", show_download_button=True
            )
            with gr.Column():
                model_selector = gr.Dropdown(
                    label="Select Model",
                    choices=list(MODEL_CHOICES.keys()),
                    value=PHI_MODEL_ID,
                    info="Select the model for transcription",
                )
                transcript_box = gr.Textbox(label="Transcription Output", lines=5)

        # --- Action buttons ---
        with gr.Row():
            run_btn = gr.Button("🎯 Transcribe", variant="primary")
            reset_btn = gr.Button("🧹 Clear")

        run_btn.click(
            fn=combined_transcription,
            inputs=[audio_in, model_selector],
            outputs=[transcript_box],
            show_progress=True,
        )
        reset_btn.click(
            fn=lambda: (None, ""),
            inputs=[],
            outputs=[audio_in, transcript_box],
        )

        # Clickable examples; results are cached lazily on first run.
        gr.Examples(
            examples=EXAMPLE_FILES,
            inputs=[audio_in, model_selector],
            outputs=[transcript_box],
            fn=combined_transcription,
            cache_examples=True,
            cache_mode="lazy",
            run_on_click=True,
        )

        # --- Footer: links to the model cards on the Hub ---
        gr.Markdown("### Model Information")
        with gr.Accordion("Model Details", open=False):
            for model_id, display_name in MODEL_CHOICES.items():
                gr.Markdown(
                    f"**{display_name}:** [{model_id}](https://huggingface.co/{model_id})"
                )

    return demo
if __name__ == "__main__":
    # Warm the model caches before serving so the first request isn't slow.
    logger.info("Preloading models to reduce cold start time")
    preload_models()

    app = create_demo()
    app.launch(share=False)