import gradio as gr
from nemo.collections.asr.models import ASRModel
import torch
import spaces
from huggingface_hub import hf_hub_download
# Check for GPU support; on ZeroGPU, CUDA is reported available at startup,
# but the GPU is only attached inside @spaces.GPU-decorated functions
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")
# Model repo id on the Hugging Face Hub
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2",
]

# Cache of loaded models, keyed by repo id
LOADED_MODELS = {}

# Filename of the NeMo checkpoint inside the HF repo
NEMO_FILENAME = "asr-egyptian-nemo-v2.0.nemo"
def get_model(model_name):
    if model_name not in LOADED_MODELS:
        print(f"Loading model {model_name} on {device}...")
        # Download the .nemo artifact from the Hub and restore the model
        nemo_path = hf_hub_download(repo_id=model_name, filename=NEMO_FILENAME)
        model = ASRModel.restore_from(nemo_path)
        model = model.to(device)
        model.eval()
        LOADED_MODELS[model_name] = model
        print(f"Model {model_name} loaded successfully!")
    return LOADED_MODELS[model_name]
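
# Minimal sanity check (hypothetical calls, shown commented out): repeated
# lookups return the same cached instance rather than re-downloading:
#   m1 = get_model(MODEL_NAMES[0])
#   m2 = get_model(MODEL_NAMES[0])
#   assert m1 is m2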
@spaces.GPU  # request a ZeroGPU allocation for the duration of the call
def transcribe_and_score(audio):
    if audio is None:
        return ""
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe([audio])
    item = predictions[0] if isinstance(predictions, list) else predictions
    # Extract plain text regardless of NeMo's return type
    # (str, dict, or Hypothesis-like object with a .text attribute)
    if hasattr(item, "text"):
        text = item.text
    elif isinstance(item, dict) and "text" in item:
        text = item["text"]
    elif isinstance(item, str):
        text = item
    else:
        text = str(item)
    return text.strip()
@spaces.GPU  # batch jobs also run on the ZeroGPU allocation
def batch_transcribe(audio_files):
    if not audio_files:
        return []
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe(audio_files)
    texts = []
    if isinstance(predictions, list):
        for p in predictions:
            if hasattr(p, "text"):
                t = p.text
            elif isinstance(p, dict) and "text" in p:
                t = p["text"]
            elif isinstance(p, str):
                t = p
            else:
                t = str(p)
            texts.append(t)
    else:
        texts.append(str(predictions))
    # One single-column row per file, matching the results Dataframe below
    return [[t.strip()] for t in texts]
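
# Illustrative return shape only (not real model output):
#   batch_transcribe(["a.wav", "b.wav"]) -> [["transcript 1"], ["transcript 2"]]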
with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown("""
    # EgypTalk-ASR-v2
    Upload or record an audio file. This app transcribes audio using EgypTalk-ASR-v2.
    """)
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
        transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(transcribe_and_score, inputs=[audio_input], outputs=[pred_output])
    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(batch_transcribe, inputs=[audio_files], outputs=[preds_output])
# share=True is not supported on Spaces and only triggers a warning there
demo.launch()
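
Once the app is running, the endpoints can also be exercised programmatically. Below is a minimal sketch using gradio_client, assuming the app is reachable at the default local address and that a local file "sample.wav" exists (both are assumptions for illustration); the api_name mirrors the handler function's name, which is Gradio's default naming:

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")
result = client.predict(
    handle_file("sample.wav"),  # hypothetical local audio file
    api_name="/transcribe_and_score",
)
print(result)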