Omartificial-Intelligence-Space's picture
Create app.py
3b9ff86 verified
raw
history blame
6.29 kB
import gradio as gr
import nemo.collections.asr as nemo_asr
from nemo.core import ModelPT
import torch
from evaluate import load as hf_load
import os
import spaces
# Select the compute device: CUDA whenever torch reports it as available,
# otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Place a one-element tensor on the chosen device and report where it ended up
# (NOTE(review): presumably the Hugging Face ZeroGPU startup idiom — confirm).
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")

# Hugging Face Hub model repositories offered in the model selectors.
MODEL_NAMES = ["NAMAA-Space/EgypTalk-ASR-v2"]

# Cache of already-loaded models, keyed by model name; filled lazily.
LOADED_MODELS = {}

# Word- and character-error-rate metrics from the `evaluate` library,
# loaded once at import time (WER first, then CER).
hf_wer, hf_cer = hf_load("wer"), hf_load("cer")
def get_model(model_name):
    """Return the model called *model_name*, loading it on first request.

    A cache hit in ``LOADED_MODELS`` is returned immediately. Otherwise the
    model is loaded with ``ModelPT.from_pretrained``, moved to ``device``,
    switched to eval mode, stored in ``LOADED_MODELS`` and returned.
    """
    # Fast path: the model was already loaded by an earlier call.
    if model_name in LOADED_MODELS:
        return LOADED_MODELS[model_name]

    print(f"Loading model {model_name} on {device}...")
    # ModelPT.from_pretrained resolves the concrete model class from the checkpoint.
    loaded = ModelPT.from_pretrained(model_name).to(device)
    loaded.eval()  # inference mode (disables dropout etc.)
    LOADED_MODELS[model_name] = loaded
    print(f"Model {model_name} loaded successfully!")
    return loaded
def _first_transcript(predictions):
    """Return the first transcript of a NeMo ``transcribe()`` result as a string.

    Depending on the NeMo version and model type, ``transcribe()`` may return a
    list of strings, a list of ``Hypothesis`` objects (text in ``.text``), or a
    tuple whose first element holds the best hypotheses. Blindly calling
    ``str()`` on a hypothesis or tuple yields an object repr, not the text.
    An empty result yields ``""``.
    """
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if isinstance(predictions, list):
        predictions = predictions[0] if predictions else ""
    text = getattr(predictions, "text", predictions)
    return text if isinstance(text, str) else str(text)


@spaces.GPU(duration=120)
def transcribe_and_score(audio, ground_truth, model_name):
    """Transcribe one audio file and score the transcript against a reference.

    Args:
        audio: Path of the audio file to transcribe, or ``None``.
        ground_truth: Reference transcript; non-strings are converted with ``str()``.
        model_name: Name of the model to use; must be listed in ``MODEL_NAMES``.

    Returns:
        A ``(prediction, wer, cer)`` tuple. Returns ``("", None, None)`` when an
        input is missing or the model name is unknown. WER and CER are both
        ``1.0`` when the stripped reference or prediction is empty.
    """
    if audio is None or ground_truth is None or model_name not in MODEL_NAMES:
        return "", None, None

    model = get_model(model_name)
    pred = _first_transcript(model.transcribe([audio])).strip()
    ground_truth = str(ground_truth).strip()

    # Debug output
    print(f"[DEBUG] Model: {model_name}")
    print(f"[DEBUG] Ground truth: '{ground_truth}' (length: {len(ground_truth)})")
    print(f"[DEBUG] Prediction: '{pred}' (length: {len(pred)})")
    print(f"[DEBUG] Are they equal? {ground_truth == pred}")
    print(f"[DEBUG] Ground truth bytes: {repr(ground_truth)}")
    print(f"[DEBUG] Prediction bytes: {repr(pred)}")

    # The metrics cannot meaningfully score an empty side; report the worst case.
    if not ground_truth or not pred:
        print("[DEBUG] Empty ground truth or prediction, returning 1.0")
        return pred, 1.0, 1.0

    wer_score = hf_wer.compute(predictions=[pred], references=[ground_truth])
    cer_score = hf_cer.compute(predictions=[pred], references=[ground_truth])
    print(f"[DEBUG] WER: {wer_score}, CER: {cer_score}")
    return pred, wer_score, cer_score
def _batch_transcripts(predictions, count):
    """Normalize a NeMo batch ``transcribe()`` result into ``count`` stripped strings.

    Accepts plain strings, ``Hypothesis``-like objects (text in ``.text``) and a
    tuple whose first element holds the best hypotheses. A non-list result is
    reused for every item, as before.

    Raises:
        ValueError: if the model returned a different number of transcripts.
    """
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if not isinstance(predictions, list):
        predictions = [predictions] * count
    if len(predictions) != count:
        raise ValueError(
            f"Expected {count} transcripts from the model, got {len(predictions)}."
        )
    return [str(getattr(item, "text", item)).strip() for item in predictions]


@spaces.GPU(duration=120)
def batch_transcribe_and_score(audio_files, ground_truths, model_name):
    """Transcribe several audio files and score each against its reference.

    Args:
        audio_files: Audio file paths, aligned one-to-one with *ground_truths*.
        ground_truths: Reference transcripts; non-strings are converted with ``str()``.
        model_name: Name of the model to use; must be listed in ``MODEL_NAMES``.

    Returns:
        A ``(rows, avg_wer, avg_cer)`` tuple, where each row is
        ``[prediction, wer, cer]``. The averages are corpus-level scores over all
        pairs with a non-empty reference. Invalid input (missing files or
        references, unknown model, or a count mismatch) yields ``([], None, None)``.
        This is the same arity as the success path, so callers can always
        unpack three values.
    """
    if not audio_files or not ground_truths or model_name not in MODEL_NAMES:
        return [], None, None
    if len(audio_files) != len(ground_truths):
        return [], None, None

    model = get_model(model_name)
    pred_texts = _batch_transcripts(model.transcribe(audio_files), len(audio_files))
    gt_texts = [str(gt).strip() for gt in ground_truths]

    results = []
    for pred, gt in zip(pred_texts, gt_texts):
        if not gt or not pred:
            # An empty side cannot be scored meaningfully; report the worst case.
            wer_score = cer_score = 1.0
        else:
            wer_score = hf_wer.compute(predictions=[pred], references=[gt])
            cer_score = hf_cer.compute(predictions=[pred], references=[gt])
        results.append([pred, wer_score, cer_score])

    # Corpus-level averages use the same cleaned strings as the per-row scores.
    # Pairs with an empty reference are skipped because the metric cannot
    # score them.
    scored = [(pred, gt) for pred, gt in zip(pred_texts, gt_texts) if gt]
    if scored:
        preds = [pred for pred, _ in scored]
        refs = [gt for _, gt in scored]
        avg_wer = hf_wer.compute(predictions=preds, references=refs)
        avg_cer = hf_cer.compute(predictions=preds, references=refs)
    else:
        avg_wer = avg_cer = None
    return results, avg_wer, avg_cer
with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown("""
# EgypTalk-ASR-v2
Upload an audio file and its ground truth text. This app transcribes and evaluates WER/CER using EgypTalk-ASR-v2.
""")

    # Single-file tab: one audio clip plus a typed reference transcript.
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Audio File")
            gt_input = gr.Textbox(label="Ground Truth Text")
        model_choice = gr.Radio(MODEL_NAMES, value=MODEL_NAMES[0], label="Model")
        transcribe_btn = gr.Button("Transcribe & Evaluate")
        with gr.Row():
            pred_output = gr.Textbox(label="Transcription")
            wer_output = gr.Number(label="WER")
            cer_output = gr.Number(label="CER")
        transcribe_btn.click(
            transcribe_and_score,
            inputs=[audio_input, gt_input, model_choice],
            outputs=[pred_output, wer_output, cer_output],
        )

    # Batch tab: many audio files plus a text file with one reference per line.
    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files and a text file with ground truth (one line per audio file, in alphabetical order of the audio filenames). Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        gt_file = gr.File(label="Ground Truths (txt, one per line)")
        model_choice_batch = gr.Radio(MODEL_NAMES, value=MODEL_NAMES[0], label="Model")
        batch_btn = gr.Button("Batch Transcribe & Evaluate")
        preds_output = gr.Dataframe(headers=["Prediction", "WER", "CER"], label="Results")
        avg_wer_output = gr.Number(label="Average WER")
        avg_cer_output = gr.Number(label="Average CER")

        def batch_wrapper(audio_files, gt_file, model_name):
            """Read the references, align them with the uploads, and run the batch scorer.

            Audio files are sorted by basename and paired line-by-line with the
            non-empty lines of the ground-truth file. Returns
            ``(rows, avg_wer, avg_cer)`` for the Dataframe and Number outputs.
            Returns ``([], None, None)`` when either upload is missing.

            Raises:
                gr.Error: if the reference file is empty or the counts differ.
                Gradio shows this to the user instead of failing silently.
            """
            if not audio_files or not gt_file:
                return [], None, None

            # Depending on the Gradio version, uploads arrive as path strings or
            # as file wrappers exposing the path via `.name`; accept both.
            gt_path = getattr(gt_file, "name", gt_file)
            # utf-8-sig drops a leading BOM (common from Windows editors); plain
            # utf-8 would keep it glued to the first reference, since strip()
            # does not remove it.
            with open(gt_path, "r", encoding="utf-8-sig") as f:
                gts = [line.strip() for line in f if line.strip()]

            audio_paths = sorted(
                (getattr(audio, "name", audio) for audio in audio_files),
                key=os.path.basename,
            )

            if not gts:
                raise gr.Error("The ground-truth file contains no non-empty lines.")
            if len(audio_paths) != len(gts):
                raise gr.Error(
                    f"Got {len(audio_paths)} audio files but {len(gts)} ground-truth "
                    "lines; they must match one-to-one."
                )

            results, avg_wer, avg_cer = batch_transcribe_and_score(audio_paths, gts, model_name)
            return results, avg_wer, avg_cer

        batch_btn.click(
            batch_wrapper,
            inputs=[audio_files, gt_file, model_choice_batch],
            outputs=[preds_output, avg_wer_output, avg_cer_output],
        )

demo.launch(share=True)