import gradio as gr
from nemo.collections.asr.models import ASRModel
import torch
import spaces
from huggingface_hub import hf_hub_download


# Prefer GPU when available. On ZeroGPU Spaces, CUDA only becomes visible
# inside @spaces.GPU-decorated functions, so the device is re-resolved at
# model-load time rather than fixed at import.
def get_device():
    return "cuda" if torch.cuda.is_available() else "cpu"


print(f"Device at startup: {get_device()}")

# Model name from HuggingFace
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2"
]

# Cache loaded models
LOADED_MODELS = {}

# Filename of the NeMo checkpoint inside the HF repo
NEMO_FILENAME = "asr-egyptian-nemo-v2.0.nemo"


def get_model(model_name):
    """Download, restore, and cache the NeMo ASR model (one per repo id)."""
    if model_name not in LOADED_MODELS:
        device = get_device()
        print(f"Loading model {model_name} on {device}...")
        # Download the .nemo artifact from the HF Hub and restore it
        nemo_path = hf_hub_download(repo_id=model_name, filename=NEMO_FILENAME)
        model = ASRModel.restore_from(nemo_path)
        model = model.to(device)
        model.eval()
        LOADED_MODELS[model_name] = model
        print(f"Model {model_name} loaded successfully!")
    return LOADED_MODELS[model_name]


def _extract_text(prediction):
    """Return plain text from a transcribe() result item, whatever its type."""
    if hasattr(prediction, "text"):
        return prediction.text
    if isinstance(prediction, dict) and "text" in prediction:
        return prediction["text"]
    if isinstance(prediction, str):
        return prediction
    return str(prediction)


@spaces.GPU(duration=120)
def transcribe_single(audio):
    """Transcribe a single uploaded or recorded audio file."""
    if audio is None:
        return ""
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe([audio])
    item = predictions[0] if isinstance(predictions, list) else predictions
    return _extract_text(item).strip()


@spaces.GPU(duration=120)
def batch_transcribe(audio_files):
    """Transcribe a list of audio files; returns rows for the results table."""
    if not audio_files:
        return []
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe(audio_files)
    if not isinstance(predictions, list):
        predictions = [predictions]
    # One single-column row per file for the gr.Dataframe output
    return [[_extract_text(p).strip()] for p in predictions]


with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown(
        """
        # EgypTalk-ASR-v2
        Upload or record an audio file. This app transcribes audio using EgypTalk-ASR-v2.
        """
    )
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone", "upload"], type="filepath", label="Audio"
            )
        transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(
            transcribe_single, inputs=[audio_input], outputs=[pred_output]
        )
    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(batch_transcribe, inputs=[audio_files], outputs=[preds_output])

demo.launch(share=True)