import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np
from functools import lru_cache
import traceback
import re
from collections import Counter

# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    try:
        # Load only the test dataset by specifying the split
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        # Try loading with explicit file path if the default loading fails
        try:
            # Fall back to the parquet file directly; split= keeps the return type a Dataset
            dataset = load_dataset(
                "parquet",
                data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
                split="train",
            )
            return dataset
        except Exception as e2:
            print(f"Error loading with explicit path: {str(e2)}")
            raise
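
# Usage note (illustrative; requires network access to the Hugging Face Hub):
#   ds = load_data()        # cached by lru_cache, so leaderboard refreshes reuse the download
#   print(len(ds), ds.features)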

# Preprocess text for better WER calculation
def preprocess_text(text):
    if not text or not isinstance(text, str):
        return ""
    # Convert to lowercase
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
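
# Worked example (illustrative, not executed by the app):
#   preprocess_text("Hello, WORLD!")  -> "hello world"
#   preprocess_text(None)             -> ""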

# Simple language model scoring - count n-grams
def score_hypothesis(hypothesis, n=4):
    """Score a hypothesis using simple n-gram statistics"""
    if not hypothesis:
        return 0
    
    words = hypothesis.split()
    if len(words) < n:
        return len(words)  # Just return word count for very short texts
    
    # Count n-grams
    ngrams = []
    for i in range(len(words) - n + 1):
        ngram = ' '.join(words[i:i+n])
        ngrams.append(ngram)
    
    # More unique n-grams might indicate better fluency
    unique_ngrams = len(set(ngrams))
    total_ngrams = len(ngrams)
    
    # Score is a combination of length and n-gram variety
    score = len(words) + unique_ngrams/max(1, total_ngrams) * 5
    return score
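
# Worked example (illustrative): "the cat sat on the mat" has 6 words and three
# 4-grams, all unique, so score = 6 + (3/3) * 5 = 11.0; an input shorter than
# n words simply scores its word count.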

# N-best LM ranking approach
def get_best_hypothesis_lm(hypotheses):
    """Choose the best hypothesis using a simple language model approach"""
    if not hypotheses:
        return ""
    
    # Convert to list if it's not already
    if isinstance(hypotheses, str):
        return hypotheses
    
    # Ensure we have a list of strings
    hypothesis_list = []
    for h in hypotheses:
        if isinstance(h, str):
            hypothesis_list.append(preprocess_text(h))
    
    if not hypothesis_list:
        return ""
    
    # Score each hypothesis and choose the best one
    scores = [(score_hypothesis(h), h) for h in hypothesis_list]
    best_hypothesis = max(scores, key=lambda x: x[0])[1]
    return best_hypothesis
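
# Illustrative usage: longer, more n-gram-diverse candidates win under
# score_hypothesis, e.g.
#   get_best_hypothesis_lm(["the cat sat", "the cat sat on the mat"])
#   -> "the cat sat on the mat"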

# N-best correction approach
def correct_hypotheses(hypotheses):
    """Simple n-best correction by voting on words"""
    if not hypotheses:
        return ""
    
    # Convert to list if it's not already
    if isinstance(hypotheses, str):
        return hypotheses
    
    # Ensure we have a list of strings
    hypothesis_list = []
    for h in hypotheses:
        if isinstance(h, str):
            hypothesis_list.append(preprocess_text(h))
    
    if not hypothesis_list:
        return ""
    
    # Split hypotheses into words
    word_lists = [h.split() for h in hypothesis_list]
    
    # Find the most common length
    lengths = [len(words) for words in word_lists]
    if not lengths:
        return ""
    
    most_common_length = Counter(lengths).most_common(1)[0][0]
    
    # Only consider hypotheses with the most common length
    filtered_word_lists = [words for words in word_lists if len(words) == most_common_length]
    
    if not filtered_word_lists:
        # Fall back to the longest hypothesis if filtering removed everything
        return max(hypothesis_list, key=len)
    
    # Vote on each word position
    corrected_words = []
    for i in range(most_common_length):
        position_words = [words[i] for words in filtered_word_lists]
        most_common_word = Counter(position_words).most_common(1)[0][0]
        corrected_words.append(most_common_word)
    
    # Join the corrected words
    return ' '.join(corrected_words)
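
# Worked example (illustrative): given
#   ["the cat sat on the mat", "the cat sat on a mat", "the bat sat on the mat"]
# all candidates have 6 words, so position-wise voting returns
# "the cat sat on the mat".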

# Word-level WER using editdistance when available, with a jiwer fallback
def calculate_simple_wer(reference, hypothesis):
    """Calculate WER using a simple word-based approach"""
    if not reference or not hypothesis:
        return 1.0  # Maximum error if either is empty
        
    # Split into words
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    
    # Use editdistance package instead of jiwer internals
    try:
        import editdistance
        distance = editdistance.eval(ref_words, hyp_words)
    except ImportError:
        # Fallback to simple jiwer calculation
        try:
            # Try using the standard jiwer implementation
            wer_value = jiwer.wer(reference, hypothesis)
            return wer_value
        except Exception:
            # If all else fails, return 1.0 (maximum error)
            print("Error calculating WER - fallback to maximum error")
            return 1.0
    
    # WER calculation
    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))
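
# Worked example (illustrative): reference "the cat sat on the mat" vs hypothesis
# "the cat sit on mat" needs 2 word edits (one substitution, one deletion) over
# 6 reference words, so WER = 2 / 6 ≈ 0.3333.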

# Calculate WER for a group of examples with multiple methods
def calculate_wer_methods(examples):
    if not examples:
        # No examples: report NaN ("no data") rather than a misleading 0.0 WER
        return np.nan, np.nan, np.nan
    
    try:
        # Check if examples is a Dataset or a list
        is_dataset = hasattr(examples, 'features')
        
        # Get the first example for inspection (indexing works for both Dataset and list)
        if len(examples) > 0:
            example = examples[0]
        else:
            print("No examples found")
            return np.nan, np.nan, np.nan
            
        print("\n===== EXAMPLE DATA INSPECTION =====")
        print(f"Keys in example: {example.keys()}")
        
        # Try different possible field names
        possible_reference_fields = ["transcription", "reference", "ground_truth", "target"]
        possible_hypothesis_fields = ["input1", "hypothesis", "asr_output", "source_text"]
        
        for field in possible_reference_fields:
            if field in example:
                print(f"Reference field '{field}' found with value: {str(example[field])[:100]}...")
        
        for field in possible_hypothesis_fields:
            if field in example:
                print(f"Hypothesis field '{field}' found with value: {str(example[field])[:100]}...")
        
        # Process each example in the dataset
        wer_values_no_lm = []
        wer_values_lm_ranking = []
        wer_values_n_best_correction = []
        
        valid_count = 0
        skipped_count = 0
        
        # Determine how to iterate based on type
        items_to_process = examples
        if is_dataset:
            # Limit to first 200 examples for efficiency
            items_to_process = examples.select(range(min(200, len(examples))))
        else:
            items_to_process = examples[:200]  # First 200 examples
        
        for i, ex in enumerate(items_to_process):
            try:
                # Get reference transcription
                transcription = ex.get("transcription")
                if not transcription or not isinstance(transcription, str):
                    skipped_count += 1
                    continue
                
                # Process the reference
                reference = preprocess_text(transcription)
                if not reference:
                    skipped_count += 1
                    continue
                
                # Get 1-best hypothesis for baseline
                input1 = ex.get("input1")
                if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
                    if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                        input1 = ex["hypothesis"][0]
                    elif isinstance(ex["hypothesis"], str):
                        input1 = ex["hypothesis"]
                
                # Get n-best hypotheses for other methods
                n_best_hypotheses = ex.get("hypothesis", [])
                
                # Process and evaluate all methods; record the result-list lengths
                # first so we can tell afterwards whether this example produced any
                # WER value at all
                counts_before = (len(wer_values_no_lm),
                                 len(wer_values_lm_ranking),
                                 len(wer_values_n_best_correction))

                # Method 1: No LM (1-best ASR output)
                if input1 and isinstance(input1, str):
                    no_lm_hyp = preprocess_text(input1)
                    if no_lm_hyp:
                        wer_no_lm = calculate_simple_wer(reference, no_lm_hyp)
                        wer_values_no_lm.append(wer_no_lm)
                
                # Method 2: LM ranking (best of n-best)
                if n_best_hypotheses:
                    lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
                    if lm_best_hyp:
                        wer_lm = calculate_simple_wer(reference, lm_best_hyp)
                        wer_values_lm_ranking.append(wer_lm)
                
                # Method 3: N-best correction (voting among n-best)
                if n_best_hypotheses:
                    corrected_hyp = correct_hypotheses(n_best_hypotheses)
                    if corrected_hyp:
                        wer_corrected = calculate_simple_wer(reference, corrected_hyp)
                        wer_values_n_best_correction.append(wer_corrected)
                
                # Count as valid if at least one method produced a WER for this example
                counts_after = (len(wer_values_no_lm),
                                len(wer_values_lm_ranking),
                                len(wer_values_n_best_correction))
                if counts_after != counts_before:
                    valid_count += 1
                else:
                    skipped_count += 1
                
                # Print debug info for a few examples
                if i < 2:
                    print(f"\nExample {i} inspection:")
                    print(f"  Reference: '{reference}'")
                    
                    if input1 and isinstance(input1, str):
                        no_lm_hyp = preprocess_text(input1)
                        print(f"  No LM (1-best): '{no_lm_hyp}'")
                        if no_lm_hyp:
                            wer = calculate_simple_wer(reference, no_lm_hyp)
                            print(f"  No LM WER: {wer:.4f}")
                    
                    if n_best_hypotheses:
                        print(f"  N-best count: {len(n_best_hypotheses) if isinstance(n_best_hypotheses, list) else 'not a list'}")
                        lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
                        print(f"  LM ranking best: '{lm_best_hyp}'")
                        if lm_best_hyp:
                            wer = calculate_simple_wer(reference, lm_best_hyp)
                            print(f"  LM ranking WER: {wer:.4f}")
                        
                        corrected_hyp = correct_hypotheses(n_best_hypotheses)
                        print(f"  N-best correction: '{corrected_hyp}'")
                        if corrected_hyp:
                            wer = calculate_simple_wer(reference, corrected_hyp)
                            print(f"  N-best correction WER: {wer:.4f}")
                
            except Exception as ex_error:
                print(f"Error processing example {i}: {str(ex_error)}")
                skipped_count += 1
                continue
        
        # Calculate average WER for each method
        print(f"\nProcessing summary: Valid pairs: {valid_count}, Skipped: {skipped_count}")
        
        no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan
        lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan
        n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan
        
        print(f"Calculated WERs:")
        print(f"  No LM: {len(wer_values_no_lm)} pairs, avg WER: {no_lm_wer:.4f}")
        print(f"  LM Ranking: {len(wer_values_lm_ranking)} pairs, avg WER: {lm_ranking_wer:.4f}")
        print(f"  N-best Correction: {len(wer_values_n_best_correction)} pairs, avg WER: {n_best_correction_wer:.4f}")
        
        return no_lm_wer, lm_ranking_wer, n_best_correction_wer
    
    except Exception as e:
        print(f"Error in calculate_wer: {str(e)}")
        print(traceback.format_exc())
        return np.nan, np.nan, np.nan
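
# Illustrative sketch: the function also accepts a plain list of dicts that follows
# the dataset schema probed above (the keys here mirror that assumption), e.g.
#   sample = [{"transcription": "the cat sat on the mat",
#              "input1": "the cat sat on a mat",
#              "hypothesis": ["the cat sat on a mat",
#                             "the cat sat on the mat",
#                             "a cat sat on the mat"]}]
#   no_lm, lm_rank, corrected = calculate_wer_methods(sample)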

# Get WER metrics by source 
def get_wer_metrics(dataset):
    try:
        # Print dataset info
        print(f"\n===== DATASET INFO =====")
        print(f"Dataset size: {len(dataset)}")
        print(f"Dataset features: {dataset.features}")
        
        # Group examples by source
        examples_by_source = {}
        
        # Process all examples
        for i, ex in enumerate(dataset):
            try:
                source = ex.get("source", "unknown")
                # Skip all_et05_real as requested
                if source == "all_et05_real":
                    continue
                    
                if source not in examples_by_source:
                    examples_by_source[source] = []
                examples_by_source[source].append(ex)
            except Exception as e:
                print(f"Error processing example {i}: {str(e)}")
                continue
        
        # Get all unique sources
        all_sources = sorted(examples_by_source.keys())
        print(f"Found sources: {all_sources}")
        
        # Calculate metrics for each source
        source_results = {}
        for source in all_sources:
            try:
                examples = examples_by_source.get(source, [])
                count = len(examples)
                
                if count > 0:
                    print(f"\nCalculating WER for source {source} with {count} examples")
                    no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples)
                else:
                    no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan
                
                source_results[source] = {
                    "Count": count,
                    "No LM Baseline": no_lm_wer,
                    "N-best LM Ranking": lm_ranking_wer,
                    "N-best Correction": n_best_wer
                }
            except Exception as e:
                print(f"Error processing source {source}: {str(e)}")
                source_results[source] = {
                    "Count": 0,
                    "No LM Baseline": np.nan,
                    "N-best LM Ranking": np.nan,
                    "N-best Correction": np.nan
                }
        
        # Calculate overall metrics with a sample but excluding all_et05_real
        try:
            # Create a filtered dataset without all_et05_real
            filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
            total_count = len(filtered_dataset)
            print(f"\nCalculating overall WER with a sample of examples (excluding all_et05_real)")
            
            # Sample for calculation (calculate_wer_methods itself caps processing at 200 examples)
            sample_size = min(500, total_count)
            sample_dataset = filtered_dataset[:sample_size]
            no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset)
            
            source_results["OVERALL"] = {
                "Count": total_count,
                "No LM Baseline": no_lm_wer,
                "N-best LM Ranking": lm_ranking_wer,
                "N-best Correction": n_best_wer
            }
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            print(traceback.format_exc())
            source_results["OVERALL"] = {
                "Count": len(filtered_dataset),
                "No LM Baseline": np.nan,
                "N-best LM Ranking": np.nan,
                "N-best Correction": np.nan
            }
        
        # Create flat DataFrame with labels in the first column
        rows = []
        
        # First add row for number of examples
        example_row = {"Metric": "Number of Examples"}
        for source in all_sources + ["OVERALL"]:
            example_row[source] = source_results[source]["Count"]
        rows.append(example_row)
        
        # Then add rows for each WER method
        no_lm_row = {"Metric": "Word Error Rate (No LM)"}
        lm_ranking_row = {"Metric": "Word Error Rate (N-best LM Ranking)"}
        n_best_row = {"Metric": "Word Error Rate (N-best Correction)"}
        
        for source in all_sources + ["OVERALL"]:
            no_lm_row[source] = source_results[source]["No LM Baseline"]
            lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
            n_best_row[source] = source_results[source]["N-best Correction"]
        
        rows.append(no_lm_row)
        rows.append(lm_ranking_row)
        rows.append(n_best_row)
        
        # Create DataFrame from rows
        result_df = pd.DataFrame(rows)
        
        return result_df
    
    except Exception as e:
        print(f"Error in get_wer_metrics: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])

# Format the dataframe for display
def format_dataframe(df):
    try:
        # Work on a copy with object dtype so formatted strings can replace numeric cells
        df = df.copy().astype(object)
        
        # Find the rows containing WER values
        wer_row_indices = []
        for i, metric in enumerate(df["Metric"]):
            if "WER" in metric or "Error Rate" in metric:
                wer_row_indices.append(i)
        
        # Format WER values
        for idx in wer_row_indices:
            for col in df.columns:
                if col != "Metric":  # Skip the metric column
                    value = df.loc[idx, col]
                    if pd.notna(value):
                        df.loc[idx, col] = f"{value:.4f}"
                    else:
                        df.loc[idx, col] = "N/A"
        
        return df
    
    except Exception as e:
        print(f"Error in format_dataframe: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])

# Main function to create the leaderboard
def create_leaderboard():
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return pd.DataFrame([{"Error": error_msg}])

# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")
    
    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")
    
    with gr.Row():
        error_output = gr.Textbox(label="Debug Information", visible=True, lines=10)
    
    with gr.Row():
        try:
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            error_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            # .update() has no effect at build time; set the initial value directly
            error_output.value = error_msg
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": error_msg}]))
    
    def refresh_and_report():
        try:
            df = create_leaderboard()
            debug_info = "Leaderboard refreshed successfully. Check console for detailed debug information."
            return df, debug_info
        except Exception as e:
            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return pd.DataFrame([{"Error": error_msg}]), error_msg
    
    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])

if __name__ == "__main__":
    demo.launch()