# Author: Omartificial-Intelligence-Space
# Commit: 59b71d1 "Update app.py" (verified)
import gradio as gr
from nemo.collections.asr.models import ASRModel
import torch
import os
import spaces
from huggingface_hub import hf_hub_download
# Pick the compute device once at import time: CUDA when PyTorch reports an
# available GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Place a one-element tensor on the chosen device and log where it ended up.
zero = torch.Tensor([0])
zero = zero.to(device)
print(f"Device being used: {zero.device}")

# Hugging Face repo ids of the ASR models this app can serve; the
# transcription handlers load MODEL_NAMES[0].
MODEL_NAMES = ["NAMAA-Space/EgypTalk-ASR-v2"]

# In-process cache of restored models, keyed by repo id (filled by get_model).
LOADED_MODELS = dict()

# Name of the NeMo checkpoint file stored inside the model repo.
NEMO_FILENAME = "asr-egyptian-nemo-v2.0.nemo"
def get_model(model_name):
    """Return the NeMo ASR model for a Hugging Face repo id, loading it once.

    On the first call for a given ``model_name`` the ``NEMO_FILENAME``
    checkpoint is fetched from that repo, restored with
    ``ASRModel.restore_from``, moved to ``device`` and put in eval mode.
    The model is then memoised in ``LOADED_MODELS``, so later calls return
    the same object without reloading.

    Args:
        model_name: Hugging Face repo id that contains the ``.nemo`` file.

    Returns:
        The cached ``ASRModel`` instance for ``model_name``.
    """
    if model_name in LOADED_MODELS:
        return LOADED_MODELS[model_name]

    print(f"Loading model {model_name} on {device}...")
    # Local filesystem path of the downloaded checkpoint.
    checkpoint_path = hf_hub_download(repo_id=model_name, filename=NEMO_FILENAME)
    asr_model = ASRModel.restore_from(checkpoint_path).to(device)
    # Switch to inference (eval) mode before serving requests.
    asr_model.eval()
    LOADED_MODELS[model_name] = asr_model
    print(f"Model {model_name} loaded successfully!")
    return asr_model
@spaces.GPU(duration=120)
def transcribe_and_score(audio):
    """Transcribe a single audio clip with the first model in ``MODEL_NAMES``.

    Despite its name, this function performs no scoring; it only returns the
    transcription text shown in the "Single Test" tab.

    Args:
        audio: Path to the recorded/uploaded audio file (the single-test
            ``gr.Audio`` component uses ``type="filepath"``), or ``None``
            when the user submitted nothing.

    Returns:
        The whitespace-stripped transcription, or ``""`` when there is no
        audio or the model produced no prediction.
    """
    if audio is None:
        return ""
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe([audio])
    # Transducer (RNNT) and hybrid NeMo models return a
    # (best_hypotheses, all_hypotheses) tuple in some NeMo releases; without
    # this unwrap the tuple's repr would be returned as the "transcription".
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if isinstance(predictions, list):
        if not predictions:
            return ""
        item = predictions[0]
    else:
        item = predictions
    # Extract plain text regardless of return type: hypothesis-like objects
    # expose ``.text``; other versions return dicts or bare strings.
    if hasattr(item, "text"):
        text = item.text
    elif isinstance(item, dict) and "text" in item:
        text = item["text"]
    elif isinstance(item, str):
        text = item
    else:
        text = str(item)
    return text.strip()
@spaces.GPU(duration=120)
def batch_transcribe(audio_files):
    """Transcribe several audio files in a single ``model.transcribe`` call.

    Args:
        audio_files: The value of the ``gr.Files`` component. Presumably this
            is a list of local file paths (TODO confirm for the installed
            Gradio version). It is ``None``/empty when nothing was uploaded.

    Returns:
        A list of single-element rows ``[[text], ...]``, one per prediction
        in model output order, for the one-column "Transcription"
        ``gr.Dataframe``. Returns ``[]`` when no files were given.
    """
    if not audio_files:
        return []
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe(audio_files)
    # Transducer (RNNT) and hybrid NeMo models return a
    # (best_hypotheses, all_hypotheses) tuple in some NeMo releases. Without
    # this unwrap the whole batch collapsed into one stringified row.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # Normalise a lone (non-list) result so it goes through the same text
    # extraction as list items instead of being blindly str()-ed.
    if not isinstance(predictions, list):
        predictions = [predictions]
    rows = []
    for p in predictions:
        if hasattr(p, "text"):
            t = p.text
        elif isinstance(p, dict) and "text" in p:
            t = p["text"]
        elif isinstance(p, str):
            t = p
        else:
            t = str(p)
        rows.append([t.strip()])
    return rows
with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    # Page header rendered above the tabs.
    gr.Markdown(
        "\n"
        "# EgypTalk-ASR-v2\n"
        "Upload or record an audio file. This app transcribes audio using EgypTalk-ASR-v2.\n"
    )

    # Tab 1: transcribe one recorded or uploaded clip.
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio",
            )
        transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(
            fn=transcribe_and_score,
            inputs=[audio_input],
            outputs=[pred_output],
        )

    # Tab 2: transcribe many files at once; results land in a one-column table.
    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(
            fn=batch_transcribe,
            inputs=[audio_files],
            outputs=[preds_output],
        )

# Start the server.
demo.launch(share=True)