Commit 6639d45 · verified · Parent: b0a28f2

Update app.py

Files changed (1):
  app.py  +25 -101
app.py CHANGED

@@ -1,8 +1,6 @@
 import gradio as gr
-import nemo.collections.asr as nemo_asr
 from nemo.core import ModelPT
 import torch
-from evaluate import load as hf_load
 import os
 import spaces

@@ -19,9 +17,6 @@ MODEL_NAMES = [
 # Cache loaded models
 LOADED_MODELS = {}

-# Load WER and CER metrics from HuggingFace evaluate
-hf_wer = hf_load("wer")
-hf_cer = hf_load("cer")

 def get_model(model_name):
     if model_name not in LOADED_MODELS:

@@ -38,127 +33,56 @@ def get_model(model_name):
     return LOADED_MODELS[model_name]

 @spaces.GPU(duration=120)
-def transcribe_and_score(audio, ground_truth, model_name):
-    if audio is None or ground_truth is None or model_name not in MODEL_NAMES:
-        return "", None, None
-    model = get_model(model_name)
+def transcribe_and_score(audio):
+    if audio is None:
+        return ""
+    model = get_model(MODEL_NAMES[0])

     # Use the correct transcribe API
     predictions = model.transcribe([audio])
     pred = predictions[0] if isinstance(predictions, list) else predictions

-    # Ensure both are strings and not empty
-    if not isinstance(ground_truth, str):
-        ground_truth = str(ground_truth)
     if not isinstance(pred, str):
         pred = str(pred)

-    # Strip whitespace
-    ground_truth = ground_truth.strip()
-    pred = pred.strip()
-
-    # Debug output
-    print(f"[DEBUG] Model: {model_name}")
-    print(f"[DEBUG] Ground truth: '{ground_truth}' (length: {len(ground_truth)})")
-    print(f"[DEBUG] Prediction: '{pred}' (length: {len(pred)})")
-    print(f"[DEBUG] Are they equal? {ground_truth == pred}")
-    print(f"[DEBUG] Ground truth bytes: {repr(ground_truth)}")
-    print(f"[DEBUG] Prediction bytes: {repr(pred)}")
-
-    if not ground_truth or not pred:
-        print("[DEBUG] Empty ground truth or prediction, returning 1.0")
-        return pred, 1.0, 1.0
-
-    # Calculate WER and CER
-    wer_score = hf_wer.compute(predictions=[pred], references=[ground_truth])
-    cer_score = hf_cer.compute(predictions=[pred], references=[ground_truth])
-
-    print(f"[DEBUG] WER: {wer_score}, CER: {cer_score}")
-    return pred, wer_score, cer_score
+    return pred.strip()

 @spaces.GPU(duration=120)
-def batch_transcribe_and_score(audio_files, ground_truths, model_name):
-    if not audio_files or not ground_truths or model_name not in MODEL_NAMES:
-        return [], [], [], None, None
-    if len(audio_files) != len(ground_truths):
-        return [], [], [], None, None
-
-    model = get_model(model_name)
-
+def batch_transcribe(audio_files):
+    if not audio_files:
+        return []
+    model = get_model(MODEL_NAMES[0])
+
     # Use the correct transcribe API for batch
     predictions = model.transcribe(audio_files)
-
-    results = []
-    pred_texts = []
-
-    for i, (audio_file, gt) in enumerate(zip(audio_files, ground_truths)):
-        pred = predictions[i] if isinstance(predictions, list) else predictions
-
-        if not isinstance(gt, str):
-            gt = str(gt)
-        if not isinstance(pred, str):
-            pred = str(pred)
-
-        gt = gt.strip()
-        pred = pred.strip()
-
-        if not gt or not pred:
-            wer_score = 1.0
-            cer_score = 1.0
-        else:
-            wer_score = hf_wer.compute(predictions=[pred], references=[gt])
-            cer_score = hf_cer.compute(predictions=[pred], references=[gt])
-
-        results.append([pred, wer_score, cer_score])
-        pred_texts.append(pred)
-
-    # Calculate average WER and CER
-    if pred_texts and ground_truths:
-        avg_wer = hf_wer.compute(predictions=pred_texts, references=ground_truths)
-        avg_cer = hf_cer.compute(predictions=pred_texts, references=ground_truths)
+
+    if isinstance(predictions, list):
+        texts = [p if isinstance(p, str) else str(p) for p in predictions]
     else:
-        avg_wer = None
-        avg_cer = None
-
-    return results, avg_wer, avg_cer
+        texts = [str(predictions)]
+
+    # Return as rows for a single-column dataframe
+    return [[t.strip()] for t in texts]

 with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
     gr.Markdown("""
     # EgypTalk-ASR-v2
-    Upload an audio file and its ground truth text. This app transcribes and evaluates WER/CER using EgypTalk-ASR-v2.
+    Upload an audio file. This app transcribes audio using EgypTalk-ASR-v2.
     """)
     with gr.Tab("Single Test"):
         with gr.Row():
             audio_input = gr.Audio(type="filepath", label="Audio File")
-            gt_input = gr.Textbox(label="Ground Truth Text")
-        model_choice = gr.Radio(MODEL_NAMES, value=MODEL_NAMES[0], label="Model")
-        transcribe_btn = gr.Button("Transcribe & Evaluate")
-        with gr.Row():
-            pred_output = gr.Textbox(label="Transcription")
-            wer_output = gr.Number(label="WER")
-            cer_output = gr.Number(label="CER")
-        transcribe_btn.click(transcribe_and_score, inputs=[audio_input, gt_input, model_choice], outputs=[pred_output, wer_output, cer_output])
+        transcribe_btn = gr.Button("Transcribe")
+        pred_output = gr.Textbox(label="Transcription")
+        transcribe_btn.click(transcribe_and_score, inputs=[audio_input], outputs=[pred_output])

     with gr.Tab("Batch Test"):
-        gr.Markdown("Upload multiple audio files and a text file with ground truth (one line per audio, same order). Batch size is limited by GPU/CPU memory.")
+        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
         audio_files = gr.Files(label="Audio Files (wav)")
-        gt_file = gr.File(label="Ground Truths (txt, one per line)")
-        model_choice_batch = gr.Radio(MODEL_NAMES, value=MODEL_NAMES[0], label="Model")
-        batch_btn = gr.Button("Batch Transcribe & Evaluate")
-        preds_output = gr.Dataframe(headers=["Prediction", "WER", "CER"], label="Results")
-        avg_wer_output = gr.Number(label="Average WER")
-        avg_cer_output = gr.Number(label="Average CER")
-
-        def batch_wrapper(audio_files, gt_file, model_name):
-            if not audio_files or not gt_file:
-                return [], None, None
-            with open(gt_file, 'r', encoding='utf-8') as f:
-                gts = [line.strip() for line in f if line.strip()]
-            audio_files_sorted = sorted(audio_files, key=lambda x: os.path.basename(x))
-            results, avg_wer, avg_cer = batch_transcribe_and_score(audio_files_sorted, gts, model_name)
-            return results, avg_wer, avg_cer
+        batch_btn = gr.Button("Batch Transcribe")
+        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")

-        batch_btn.click(batch_wrapper, inputs=[audio_files, gt_file, model_choice_batch], outputs=[preds_output, avg_wer_output, avg_cer_output])
+        batch_btn.click(batch_transcribe, inputs=[audio_files], outputs=[preds_output])


 demo.launch(share=True)
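
For reference, a minimal offline sketch of the transcription path this commit leaves behind, without the Gradio UI. The model id and audio path are placeholders, and loading via ModelPT.from_pretrained is an assumption for illustration; the actual loading happens in the unchanged get_model() body, which this diff does not show.

    # Hypothetical standalone sketch mirroring the simplified app.py path.
    # "nvidia/some-asr-model" and "sample.wav" are placeholders, not values from the app;
    # get_model() in app.py may restore the checkpoint differently.
    from nemo.core import ModelPT

    model = ModelPT.from_pretrained("nvidia/some-asr-model")

    # Same post-processing as transcribe_and_score(): take the first result, coerce to str, strip.
    predictions = model.transcribe(["sample.wav"])
    pred = predictions[0] if isinstance(predictions, list) else predictions
    print(str(pred).strip())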