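"""Gradio demo for EgypTalk-ASR-v2: single-file and batch speech transcription with a NeMo model."""
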
import gradio as gr
from nemo.core import ModelPT
import torch
import spaces
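# On Hugging Face ZeroGPU Spaces, @spaces.GPU (used below) requests a GPU for the
# decorated call; duration=120 allows up to ~120 seconds of GPU time per call.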
# Pick the compute device and confirm tensor placement with a probe tensor
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.tensor([0]).to(device)
print(f"Device being used: {zero.device}")

# Model name from HuggingFace
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2",
]

# Cache loaded models
LOADED_MODELS = {}

def get_model(model_name):
    if model_name not in LOADED_MODELS:
        print(f"Loading model {model_name} on {device}...")
        # Load model from HuggingFace using ModelPT
        model = ModelPT.from_pretrained(model_name)
        model = model.to(device)
        model.eval()
        LOADED_MODELS[model_name] = model
        print(f"Model {model_name} loaded successfully!")
    return LOADED_MODELS[model_name]

@spaces.GPU(duration=120)
def transcribe_and_score(audio):
    if audio is None:
        return ""
    model = get_model(MODEL_NAMES[0])
    # transcribe() expects a list of audio file paths
    predictions = model.transcribe([audio])
    pred = predictions[0] if isinstance(predictions, list) else predictions
    if not isinstance(pred, str):
        # Some NeMo versions return Hypothesis objects; prefer their .text field
        pred = pred.text if hasattr(pred, "text") else str(pred)
    return pred.strip()

@spaces.GPU(duration=120)
def batch_transcribe(audio_files):
    if not audio_files:
        return []
    model = get_model(MODEL_NAMES[0])
    # transcribe() accepts a list of audio file paths for batch inference
    predictions = model.transcribe(audio_files)
    if isinstance(predictions, list):
        # Some NeMo versions return Hypothesis objects; prefer their .text field
        texts = [p if isinstance(p, str) else getattr(p, "text", str(p)) for p in predictions]
    else:
        texts = [str(predictions)]
    # Return one row per file for the single-column dataframe
    return [[t.strip()] for t in texts]

with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown("""
    # EgypTalk-ASR-v2
    Upload an audio file. This app transcribes audio using EgypTalk-ASR-v2.
    """)
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Audio File")
        transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(transcribe_and_score, inputs=[audio_input], outputs=[pred_output])
    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(batch_transcribe, inputs=[audio_files], outputs=[preds_output])

demo.launch(share=True)
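# Note: when the app runs on Hugging Face Spaces, Gradio ignores share=True;
# the flag only creates a public link for local runs.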