import gradio as gr
import nemo.collections.asr as nemo_asr  # NeMo ASR collection; importing it makes the ASR model classes available
from nemo.core import ModelPT
import torch
from evaluate import load as hf_load
import os
import spaces
# Check for GPU support and configure appropriately
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.zeros(1).to(device)
print(f"Device being used: {zero.device}")

# Model name from HuggingFace
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2"
]

# Cache loaded models
LOADED_MODELS = {}

# Load WER and CER metrics from HuggingFace evaluate
hf_wer = hf_load("wer")
hf_cer = hf_load("cer")
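# WER = (substitutions + deletions + insertions) / number of reference words;
# CER is the same edit-distance ratio computed over characters instead of words.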

def get_model(model_name):
    if model_name not in LOADED_MODELS:
        print(f"Loading model {model_name} on {device}...")
        # Load model from HuggingFace using ModelPT
        model = ModelPT.from_pretrained(model_name)
        model = model.to(device)
        model.eval()
        LOADED_MODELS[model_name] = model
        print(f"Model {model_name} loaded successfully!")
    return LOADED_MODELS[model_name]
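
# e.g. get_model("NAMAA-Space/EgypTalk-ASR-v2") downloads the checkpoint once;
# subsequent calls are served from the LOADED_MODELS cache.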

@spaces.GPU  # assumption: the Space runs on ZeroGPU hardware, where inference must happen inside a decorated function
def transcribe_and_score(audio, ground_truth, model_name):
    if audio is None or ground_truth is None or model_name not in MODEL_NAMES:
        return "", None, None
    model = get_model(model_name)
    # Run the transcribe API on a single-file batch
    predictions = model.transcribe([audio])
    pred = predictions[0] if isinstance(predictions, list) else predictions
    if hasattr(pred, "text"):
        # Some NeMo versions return Hypothesis objects rather than strings
        pred = pred.text
    # Ensure both are strings and not empty
    if not isinstance(ground_truth, str):
        ground_truth = str(ground_truth)
    if not isinstance(pred, str):
        pred = str(pred)
    # Strip whitespace
    ground_truth = ground_truth.strip()
    pred = pred.strip()
    # Debug output
    print(f"[DEBUG] Model: {model_name}")
    print(f"[DEBUG] Ground truth: '{ground_truth}' (length: {len(ground_truth)})")
    print(f"[DEBUG] Prediction: '{pred}' (length: {len(pred)})")
    print(f"[DEBUG] Are they equal? {ground_truth == pred}")
    print(f"[DEBUG] Ground truth repr: {repr(ground_truth)}")
    print(f"[DEBUG] Prediction repr: {repr(pred)}")
    if not ground_truth or not pred:
        print("[DEBUG] Empty ground truth or prediction, returning 1.0")
        return pred, 1.0, 1.0
    # Calculate WER and CER for this single pair
    wer_score = hf_wer.compute(predictions=[pred], references=[ground_truth])
    cer_score = hf_cer.compute(predictions=[pred], references=[ground_truth])
    print(f"[DEBUG] WER: {wer_score}, CER: {cer_score}")
    return pred, wer_score, cer_score
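
# Worked example of the metric arithmetic (illustrative values only):
# reference "the cat sat" vs. prediction "the cat sat down" has one inserted
# word over three reference words, so WER = 1/3 ≈ 0.33; at the character
# level the five inserted characters (including the space) over the eleven
# reference characters give CER = 5/11 ≈ 0.45.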

@spaces.GPU  # same ZeroGPU assumption as above
def batch_transcribe_and_score(audio_files, ground_truths, model_name):
    # Both early exits must match the (results, avg_wer, avg_cer) shape
    # that callers unpack
    if not audio_files or not ground_truths or model_name not in MODEL_NAMES:
        return [], None, None
    if len(audio_files) != len(ground_truths):
        return [], None, None
    model = get_model(model_name)
    # Transcribe the whole batch in one call
    predictions = model.transcribe(audio_files)
    results = []
    pred_texts = []
    gt_texts = []
    for i, gt in enumerate(ground_truths):
        pred = predictions[i] if isinstance(predictions, list) else predictions
        if hasattr(pred, "text"):
            # Some NeMo versions return Hypothesis objects rather than strings
            pred = pred.text
        if not isinstance(gt, str):
            gt = str(gt)
        if not isinstance(pred, str):
            pred = str(pred)
        gt = gt.strip()
        pred = pred.strip()
        if not gt or not pred:
            wer_score = 1.0
            cer_score = 1.0
        else:
            wer_score = hf_wer.compute(predictions=[pred], references=[gt])
            cer_score = hf_cer.compute(predictions=[pred], references=[gt])
        results.append([pred, wer_score, cer_score])
        pred_texts.append(pred)
        gt_texts.append(gt)
    # Calculate aggregate WER and CER over the stripped texts
    if pred_texts and gt_texts:
        avg_wer = hf_wer.compute(predictions=pred_texts, references=gt_texts)
        avg_cer = hf_cer.compute(predictions=pred_texts, references=gt_texts)
    else:
        avg_wer = None
        avg_cer = None
    return results, avg_wer, avg_cer
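
# Note: the aggregate above is a corpus-level score (edit operations and
# reference lengths are pooled across all files), which can differ from the
# unweighted mean of the per-file WER/CER values in the results table.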

with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown("""
    # EgypTalk-ASR-v2
    Upload an audio file and its ground-truth text. The app transcribes the audio and evaluates WER/CER using EgypTalk-ASR-v2.
    """)
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Audio File")
            gt_input = gr.Textbox(label="Ground Truth Text")
        model_choice = gr.Radio(MODEL_NAMES, value=MODEL_NAMES[0], label="Model")
        transcribe_btn = gr.Button("Transcribe & Evaluate")
        with gr.Row():
            pred_output = gr.Textbox(label="Transcription")
            wer_output = gr.Number(label="WER")
            cer_output = gr.Number(label="CER")
        transcribe_btn.click(
            transcribe_and_score,
            inputs=[audio_input, gt_input, model_choice],
            outputs=[pred_output, wer_output, cer_output],
        )
| with gr.Tab("Batch Test"): | |
| gr.Markdown("Upload multiple audio files and a text file with ground truth (one line per audio, same order). Batch size is limited by GPU/CPU memory.") | |
| audio_files = gr.Files(label="Audio Files (wav)") | |
| gt_file = gr.File(label="Ground Truths (txt, one per line)") | |
| model_choice_batch = gr.Radio(MODEL_NAMES, value=MODEL_NAMES[0], label="Model") | |
| batch_btn = gr.Button("Batch Transcribe & Evaluate") | |
| preds_output = gr.Dataframe(headers=["Prediction", "WER", "CER"], label="Results") | |
| avg_wer_output = gr.Number(label="Average WER") | |
| avg_cer_output = gr.Number(label="Average CER") | |
| def batch_wrapper(audio_files, gt_file, model_name): | |
| if not audio_files or not gt_file: | |
| return [], None, None | |
| with open(gt_file, 'r', encoding='utf-8') as f: | |
| gts = [line.strip() for line in f if line.strip()] | |
| audio_files_sorted = sorted(audio_files, key=lambda x: os.path.basename(x)) | |
| results, avg_wer, avg_cer = batch_transcribe_and_score(audio_files_sorted, gts, model_name) | |
| return results, avg_wer, avg_cer | |
| batch_btn.click(batch_wrapper, inputs=[audio_files, gt_file, model_choice_batch], outputs=[preds_output, avg_wer_output, avg_cer_output]) | |

demo.launch(share=True)  # note: share=True is ignored when hosted on Spaces
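
# --- Usage sketch (commented out so it never runs inside the Space) ---
# A minimal client-side call, assuming the gradio_client package and a
# hypothetical Space id "user/egyptalk-asr-demo"; Gradio names the endpoint
# after the wired function by default, so "/transcribe_and_score" is assumed.
#
# from gradio_client import Client, handle_file
#
# client = Client("user/egyptalk-asr-demo")  # hypothetical Space id
# pred, wer, cer = client.predict(
#     handle_file("sample.wav"),        # audio file path
#     "expected transcript here",       # ground-truth text
#     "NAMAA-Space/EgypTalk-ASR-v2",    # model choice
#     api_name="/transcribe_and_score",
# )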