"""Gradio demo that transcribes audio with the EgypTalk-ASR-v2 NeMo model.

The app exposes two tabs: one for transcribing a single audio file and one
for transcribing a batch of uploaded files.
"""

import os

import gradio as gr
from nemo.core import ModelPT
import torch
import spaces

# Check for GPU support and configure appropriately
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")

# Model name from HuggingFace
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2",
]

# Cache loaded models
LOADED_MODELS = {}


def get_model(model_name):
    """Return the model for *model_name*, loading and caching it on first use.

    The model is loaded with ``ModelPT.from_pretrained``, moved to the
    module-level ``device``, switched to eval mode and stored in
    ``LOADED_MODELS`` so later calls reuse the same instance.
    """
    if model_name not in LOADED_MODELS:
        print(f"Loading model {model_name} on {device}...")
        # Load model from HuggingFace using ModelPT
        model = ModelPT.from_pretrained(model_name)
        model = model.to(device)
        model.eval()
        LOADED_MODELS[model_name] = model
        print(f"Model {model_name} loaded successfully!")
    return LOADED_MODELS[model_name]


def _prediction_to_text(prediction):
    """Return the stripped transcript text for a single prediction.

    Plain strings are used as-is. Objects exposing a string ``text``
    attribute (NOTE(review): expected to be NeMo ``Hypothesis`` objects --
    confirm against the installed NeMo version) contribute that attribute;
    anything else falls back to ``str()``.
    """
    if isinstance(prediction, str):
        return prediction.strip()
    text = getattr(prediction, "text", None)
    if isinstance(text, str):
        return text.strip()
    return str(prediction).strip()


def _predictions_to_texts(predictions):
    """Normalize the return value of ``model.transcribe`` to a list of strings."""
    if isinstance(predictions, tuple):
        # NOTE(review): some NeMo models are documented to return
        # ``(best_hypotheses, all_hypotheses)``; keep only the best ones.
        if predictions and isinstance(predictions[0], list):
            predictions = predictions[0]
        else:
            predictions = list(predictions)
    elif not isinstance(predictions, list):
        predictions = [predictions]
    return [_prediction_to_text(p) for p in predictions]


def _to_audio_path(item):
    """Return a filesystem path for an uploaded file entry.

    Strings and path-like objects are converted with ``os.fspath``. Other
    objects are assumed to be file wrappers carrying the path in ``.name``
    (as delivered by some Gradio versions -- TODO confirm); if there is no
    such attribute the item is passed through unchanged.
    """
    if isinstance(item, (str, os.PathLike)):
        return os.fspath(item)
    return getattr(item, "name", item)


@spaces.GPU(duration=120)
def transcribe_and_score(audio):
    """Transcribe a single audio file and return the text.

    Despite its name, this function only transcribes; no score is computed.

    Args:
        audio: Path of the audio file, or ``None`` when nothing was uploaded.

    Returns:
        The stripped transcription, or ``""`` when *audio* is ``None`` or the
        model produced no prediction.
    """
    if audio is None:
        return ""
    model = get_model(MODEL_NAMES[0])
    texts = _predictions_to_texts(model.transcribe([audio]))
    # An empty prediction list yields an empty transcript instead of raising.
    return texts[0] if texts else ""


@spaces.GPU(duration=120)
def batch_transcribe(audio_files):
    """Transcribe several audio files in one model call.

    Args:
        audio_files: Uploaded files (paths or file wrappers); may be empty
            or ``None``.

    Returns:
        One single-element row ``[text]`` per prediction, suitable for a
        single-column dataframe. Empty input yields ``[]``.
    """
    if not audio_files:
        return []
    paths = [_to_audio_path(item) for item in audio_files]
    model = get_model(MODEL_NAMES[0])
    texts = _predictions_to_texts(model.transcribe(paths))
    # Return as rows for a single-column dataframe
    return [[text] for text in texts]


with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown("""
    # EgypTalk-ASR-v2
    Upload an audio file. This app transcribes audio using EgypTalk-ASR-v2.
    """)
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Audio File")
        transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(
            transcribe_and_score,
            inputs=[audio_input],
            outputs=[pred_output],
        )
    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(
            batch_transcribe,
            inputs=[audio_files],
            outputs=[preds_output],
        )

demo.launch(share=True)