import gradio as gr
from nemo.collections.asr.models import ASRModel
import torch
import spaces
from huggingface_hub import hf_hub_download
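
# ZeroGPU note: `spaces` should be imported before any CUDA work so it can
# patch torch; keeping it in the import block at the top is the safe pattern.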
# Check for GPU support and configure appropriately
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")
# Model name from HuggingFace
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2",
]
# Cache loaded models
LOADED_MODELS = {}
# Filename of the NeMo checkpoint inside the HF repo
NEMO_FILENAME = "asr-egyptian-nemo-v2.0.nemo"
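

# hf_hub_download() stores the file in the local Hugging Face cache, so
# restarts reuse the checkpoint instead of re-downloading it.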
def get_model(model_name):
    if model_name not in LOADED_MODELS:
        print(f"Loading model {model_name} on {device}...")
        # Download the .nemo artifact from HF and restore
        nemo_path = hf_hub_download(repo_id=model_name, filename=NEMO_FILENAME)
        model = ASRModel.restore_from(nemo_path)
        model = model.to(device)
        model.eval()
        LOADED_MODELS[model_name] = model
        print(f"Model {model_name} loaded successfully!")
    return LOADED_MODELS[model_name]
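

# NeMo's transcribe() has returned plain strings in older releases and
# Hypothesis objects (with a .text attribute) in newer ones; the branches
# below extract text without assuming a single return type.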
@spaces.GPU(duration=120)
def transcribe_and_score(audio):
    if audio is None:
        return ""
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe([audio])
    item = predictions[0] if isinstance(predictions, list) else predictions
    # Extract plain text regardless of return type
    if hasattr(item, "text"):
        text = item.text
    elif isinstance(item, dict) and "text" in item:
        text = item["text"]
    elif isinstance(item, str):
        text = item
    else:
        text = str(item)
    return text.strip()
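

# transcribe() accepts a list of file paths and batches internally; if memory
# becomes an issue, an explicit batch_size argument (supported by NeMo's
# transcribe()) could be passed here.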
@spaces.GPU(duration=120)
def batch_transcribe(audio_files):
    if not audio_files:
        return []
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe(audio_files)
    texts = []
    if isinstance(predictions, list):
        for p in predictions:
            if hasattr(p, "text"):
                t = p.text
            elif isinstance(p, dict) and "text" in p:
                t = p["text"]
            elif isinstance(p, str):
                t = p
            else:
                t = str(p)
            texts.append(t)
    else:
        texts.append(str(predictions))
    return [[t.strip()] for t in texts]
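

# Gradio UI: one tab for a single clip, one for batch files.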
with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown(
        """
        # EgypTalk-ASR-v2
        Upload or record an audio file. This app transcribes audio using EgypTalk-ASR-v2.
        """
    )
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
            transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(transcribe_and_score, inputs=[audio_input], outputs=[pred_output])
    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(batch_transcribe, inputs=[audio_files], outputs=[preds_output])

demo.launch(share=True)
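# share=True matters only for local runs; a hosted Space already exposes the
# app publicly, so the flag is effectively ignored there.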