import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np
from functools import lru_cache
import traceback
import re
from collections import Counter


# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    try:
        # Load only the test dataset by specifying the split
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        # Try loading with an explicit file path if the default loading fails
        try:
            dataset = load_dataset(
                "parquet",
                data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
                split="train",  # the parquet builder exposes the file as a single "train" split
            )
            return dataset
        except Exception as e2:
            print(f"Error loading with explicit path: {str(e2)}")
            raise


# Preprocess text for more consistent WER calculation
def preprocess_text(text):
    if not text or not isinstance(text, str):
        return ""
    # Convert to lowercase
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Collapse extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text


# Simple language-model-style scoring based on n-gram counts
def score_hypothesis(hypothesis, n=4):
    """Score a hypothesis using simple n-gram statistics."""
    if not hypothesis:
        return 0
    words = hypothesis.split()
    if len(words) < n:
        return len(words)  # Just return the word count for very short texts

    # Count n-grams
    ngrams = []
    for i in range(len(words) - n + 1):
        ngrams.append(' '.join(words[i:i + n]))

    # More unique n-grams might indicate better fluency
    unique_ngrams = len(set(ngrams))
    total_ngrams = len(ngrams)

    # Score is a combination of length and n-gram variety
    score = len(words) + unique_ngrams / max(1, total_ngrams) * 5
    return score


# N-best LM ranking approach
def get_best_hypothesis_lm(hypotheses):
    """Choose the best hypothesis using a simple language-model-style score."""
    if not hypotheses:
        return ""

    # A single string is already the only candidate
    if isinstance(hypotheses, str):
        return hypotheses

    # Ensure we have a list of preprocessed strings
    hypothesis_list = []
    for h in hypotheses:
        if isinstance(h, str):
            hypothesis_list.append(preprocess_text(h))

    if not hypothesis_list:
        return ""

    # Score each hypothesis and choose the best one
    scores = [(score_hypothesis(h), h) for h in hypothesis_list]
    best_hypothesis = max(scores, key=lambda x: x[0])[1]
    return best_hypothesis
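

# Illustrative example (comments only, not executed): given the n-best list
#   ["i want to book a flight", "i want to book a flight to", "book flight"]
# score_hypothesis() returns roughly 11, 12 and 2 (word count plus an
# n-gram-variety bonus), so get_best_hypothesis_lm() picks the second, longest
# candidate. This is a crude fluency proxy, not a trained language model.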


# N-best correction approach
def correct_hypotheses(hypotheses):
    """Simple n-best correction by voting on words."""
    if not hypotheses:
        return ""

    # A single string cannot be voted on; return it unchanged
    if isinstance(hypotheses, str):
        return hypotheses

    # Ensure we have a list of preprocessed strings
    hypothesis_list = []
    for h in hypotheses:
        if isinstance(h, str):
            hypothesis_list.append(preprocess_text(h))

    if not hypothesis_list:
        return ""

    # Split hypotheses into words
    word_lists = [h.split() for h in hypothesis_list]

    # Find the most common hypothesis length
    lengths = [len(words) for words in word_lists]
    if not lengths:
        return ""
    most_common_length = Counter(lengths).most_common(1)[0][0]

    # Only consider hypotheses with the most common length
    filtered_word_lists = [words for words in word_lists if len(words) == most_common_length]
    if not filtered_word_lists:
        # Fall back to the longest hypothesis if filtering removed everything
        return max(hypothesis_list, key=len)

    # Vote on each word position
    corrected_words = []
    for i in range(most_common_length):
        position_words = [words[i] for words in filtered_word_lists]
        most_common_word = Counter(position_words).most_common(1)[0][0]
        corrected_words.append(most_common_word)

    # Join the corrected words
    return ' '.join(corrected_words)


# Word-level WER that does not depend on jiwer internals
def calculate_simple_wer(reference, hypothesis):
    """Calculate WER using a simple word-based edit distance."""
    if not reference or not hypothesis:
        return 1.0  # Maximum error if either is empty

    # Split into words
    ref_words = reference.split()
    hyp_words = hypothesis.split()

    # Prefer the editdistance package; fall back to jiwer if it is not installed
    try:
        import editdistance
        distance = editdistance.eval(ref_words, hyp_words)
    except ImportError:
        try:
            return jiwer.wer(reference, hypothesis)
        except Exception:
            # If all else fails, return 1.0 (maximum error)
            print("Error calculating WER - falling back to maximum error")
            return 1.0

    # WER = word-level edit distance / number of reference words
    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))
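

# Illustrative examples (comments only, not executed):
# - correct_hypotheses(["he sat down", "he sat town", "she sat down"]) votes
#   per word position among the equal-length candidates and yields "he sat down".
# - calculate_simple_wer("the cat sat on the mat", "the cat sat on a mat")
#   is 1 / 6 ≈ 0.167: one substituted word over six reference words.
# The editdistance package is an optional dependency here; without it the
# function falls back to jiwer.wer().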


# Calculate WER for a group of examples with multiple methods
def calculate_wer_methods(examples):
    if not examples:
        return 0.0, 0.0, 0.0

    try:
        # Check if examples is a Dataset or a list
        is_dataset = hasattr(examples, 'features')

        # Get the first example for inspection
        if len(examples) > 0:
            example = examples[0]
        else:
            print("No examples found")
            return np.nan, np.nan, np.nan

        print("\n===== EXAMPLE DATA INSPECTION =====")
        print(f"Keys in example: {example.keys()}")

        # Try different possible field names
        possible_reference_fields = ["transcription", "reference", "ground_truth", "target"]
        possible_hypothesis_fields = ["input1", "hypothesis", "asr_output", "source_text"]

        for field in possible_reference_fields:
            if field in example:
                print(f"Reference field '{field}' found with value: {str(example[field])[:100]}...")

        for field in possible_hypothesis_fields:
            if field in example:
                print(f"Hypothesis field '{field}' found with value: {str(example[field])[:100]}...")

        # Accumulators for each method
        wer_values_no_lm = []
        wer_values_lm_ranking = []
        wer_values_n_best_correction = []
        valid_count = 0
        skipped_count = 0

        # Determine how to iterate based on type; limit to the first 200 examples for efficiency
        if is_dataset:
            items_to_process = examples.select(range(min(200, len(examples))))
        else:
            items_to_process = examples[:200]

        for i, ex in enumerate(items_to_process):
            try:
                # Get the reference transcription
                transcription = ex.get("transcription")
                if not transcription or not isinstance(transcription, str):
                    skipped_count += 1
                    continue

                # Process the reference
                reference = preprocess_text(transcription)
                if not reference:
                    skipped_count += 1
                    continue

                # Get the 1-best hypothesis for the baseline
                input1 = ex.get("input1")
                if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
                    if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                        input1 = ex["hypothesis"][0]
                    elif isinstance(ex["hypothesis"], str):
                        input1 = ex["hypothesis"]

                # Get the n-best hypotheses for the other methods
                n_best_hypotheses = ex.get("hypothesis", [])

                # Track list sizes so we can tell whether this example produced any WER value
                counts_before = (
                    len(wer_values_no_lm),
                    len(wer_values_lm_ranking),
                    len(wer_values_n_best_correction),
                )

                # Method 1: No LM (1-best ASR output)
                if input1 and isinstance(input1, str):
                    no_lm_hyp = preprocess_text(input1)
                    if no_lm_hyp:
                        wer_no_lm = calculate_simple_wer(reference, no_lm_hyp)
                        wer_values_no_lm.append(wer_no_lm)

                # Method 2: LM ranking (best of n-best)
                if n_best_hypotheses:
                    lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
                    if lm_best_hyp:
                        wer_lm = calculate_simple_wer(reference, lm_best_hyp)
                        wer_values_lm_ranking.append(wer_lm)

                # Method 3: N-best correction (voting among n-best)
                if n_best_hypotheses:
                    corrected_hyp = correct_hypotheses(n_best_hypotheses)
                    if corrected_hyp:
                        wer_corrected = calculate_simple_wer(reference, corrected_hyp)
                        wer_values_n_best_correction.append(wer_corrected)

                # Count as valid if at least one method produced a WER value
                counts_after = (
                    len(wer_values_no_lm),
                    len(wer_values_lm_ranking),
                    len(wer_values_n_best_correction),
                )
                if counts_after != counts_before:
                    valid_count += 1
                else:
                    skipped_count += 1

                # Print debug info for a few examples
                if i < 2:
                    print(f"\nExample {i} inspection:")
                    print(f"  Reference: '{reference}'")
                    if input1 and isinstance(input1, str):
                        no_lm_hyp = preprocess_text(input1)
                        print(f"  No LM (1-best): '{no_lm_hyp}'")
                        if no_lm_hyp:
                            wer = calculate_simple_wer(reference, no_lm_hyp)
                            print(f"  No LM WER: {wer:.4f}")
                    if n_best_hypotheses:
                        print(f"  N-best count: {len(n_best_hypotheses) if isinstance(n_best_hypotheses, list) else 'not a list'}")
                        lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
                        print(f"  LM ranking best: '{lm_best_hyp}'")
                        if lm_best_hyp:
                            wer = calculate_simple_wer(reference, lm_best_hyp)
                            print(f"  LM ranking WER: {wer:.4f}")
                        corrected_hyp = correct_hypotheses(n_best_hypotheses)
                        print(f"  N-best correction: '{corrected_hyp}'")
                        if corrected_hyp:
                            wer = calculate_simple_wer(reference, corrected_hyp)
                            print(f"  N-best correction WER: {wer:.4f}")

            except Exception as ex_error:
                print(f"Error processing example {i}: {str(ex_error)}")
                skipped_count += 1
                continue

        # Calculate the average WER for each method
        print(f"\nProcessing summary: Valid pairs: {valid_count}, Skipped: {skipped_count}")

        no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan
        lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan
        n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan

        print("Calculated WERs:")
        print(f"  No LM: {len(wer_values_no_lm)} pairs, avg WER: {no_lm_wer:.4f}")
        print(f"  LM Ranking: {len(wer_values_lm_ranking)} pairs, avg WER: {lm_ranking_wer:.4f}")
        print(f"  N-best Correction: {len(wer_values_n_best_correction)} pairs, avg WER: {n_best_correction_wer:.4f}")

        return no_lm_wer, lm_ranking_wer, n_best_correction_wer

    except Exception as e:
        print(f"Error in calculate_wer_methods: {str(e)}")
        print(traceback.format_exc())
        return np.nan, np.nan, np.nan
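

# Note: calculate_wer_methods() returns a (no_lm, lm_ranking, n_best_correction)
# tuple of mean WERs computed over at most the first 200 examples of the group;
# a method's entry is np.nan when it produced no valid reference/hypothesis pairs.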


# Get WER metrics by source
def get_wer_metrics(dataset):
    try:
        # Print dataset info
        print("\n===== DATASET INFO =====")
        print(f"Dataset size: {len(dataset)}")
        print(f"Dataset features: {dataset.features}")

        # Group examples by source
        examples_by_source = {}
        for i, ex in enumerate(dataset):
            try:
                source = ex.get("source", "unknown")
                # Skip all_et05_real as requested
                if source == "all_et05_real":
                    continue
                if source not in examples_by_source:
                    examples_by_source[source] = []
                examples_by_source[source].append(ex)
            except Exception as e:
                print(f"Error processing example {i}: {str(e)}")
                continue

        # Get all unique sources
        all_sources = sorted(examples_by_source.keys())
        print(f"Found sources: {all_sources}")

        # Calculate metrics for each source
        source_results = {}
        for source in all_sources:
            try:
                examples = examples_by_source.get(source, [])
                count = len(examples)
                if count > 0:
                    print(f"\nCalculating WER for source {source} with {count} examples")
                    no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples)
                else:
                    no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan

                source_results[source] = {
                    "Count": count,
                    "No LM Baseline": no_lm_wer,
                    "N-best LM Ranking": lm_ranking_wer,
                    "N-best Correction": n_best_wer,
                }
            except Exception as e:
                print(f"Error processing source {source}: {str(e)}")
                source_results[source] = {
                    "Count": 0,
                    "No LM Baseline": np.nan,
                    "N-best LM Ranking": np.nan,
                    "N-best Correction": np.nan,
                }

        # Calculate overall metrics on a sample, excluding all_et05_real
        filtered_dataset = []
        try:
            # Create a filtered dataset without all_et05_real
            filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
            total_count = len(filtered_dataset)
            print("\nCalculating overall WER with a sample of examples (excluding all_et05_real)")

            # Sample for calculation
            sample_size = min(500, total_count)
            sample_dataset = filtered_dataset[:sample_size]
            no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset)

            source_results["OVERALL"] = {
                "Count": total_count,
                "No LM Baseline": no_lm_wer,
                "N-best LM Ranking": lm_ranking_wer,
                "N-best Correction": n_best_wer,
            }
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            print(traceback.format_exc())
            source_results["OVERALL"] = {
                "Count": len(filtered_dataset),
                "No LM Baseline": np.nan,
                "N-best LM Ranking": np.nan,
                "N-best Correction": np.nan,
            }

        # Create a flat DataFrame with metric labels in the first column
        rows = []

        # First add a row for the number of examples
        example_row = {"Metric": "Number of Examples"}
        for source in all_sources + ["OVERALL"]:
            example_row[source] = source_results[source]["Count"]
        rows.append(example_row)

        # Then add a row for each WER method
        no_lm_row = {"Metric": "Word Error Rate (No LM)"}
        lm_ranking_row = {"Metric": "Word Error Rate (N-best LM Ranking)"}
        n_best_row = {"Metric": "Word Error Rate (N-best Correction)"}
        for source in all_sources + ["OVERALL"]:
            no_lm_row[source] = source_results[source]["No LM Baseline"]
            lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
            n_best_row[source] = source_results[source]["N-best Correction"]
        rows.append(no_lm_row)
        rows.append(lm_ranking_row)
        rows.append(n_best_row)

        # Create the DataFrame from the rows
        result_df = pd.DataFrame(rows)
        return result_df

    except Exception as e:
        print(f"Error in get_wer_metrics: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])


# Format the dataframe for display
def format_dataframe(df):
    try:
        # Work on a copy with object dtype so numeric cells can hold formatted strings
        df = df.copy().astype(object)

        # Find the rows containing WER values
        wer_row_indices = []
        for i, metric in enumerate(df["Metric"]):
            if "WER" in metric or "Error Rate" in metric:
                wer_row_indices.append(i)

        # Format WER values with four decimal places
        for idx in wer_row_indices:
            for col in df.columns:
                if col != "Metric":  # Skip the metric column
                    value = df.loc[idx, col]
                    if pd.notna(value):
                        df.loc[idx, col] = f"{value:.4f}"
                    else:
                        df.loc[idx, col] = "N/A"

        return df

    except Exception as e:
        print(f"Error in format_dataframe: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])


# Main function to create the leaderboard
def create_leaderboard():
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return pd.DataFrame([{"Error": error_msg}])
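

# Resulting table layout (illustrative): a "Metric" column plus one column per
# source and an "OVERALL" column, with rows "Number of Examples",
# "Word Error Rate (No LM)", "Word Error Rate (N-best LM Ranking)" and
# "Word Error Rate (N-best Correction)".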


# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    with gr.Row():
        error_output = gr.Textbox(label="Debug Information", visible=True, lines=10)

    with gr.Row():
        # create_leaderboard() already catches its own errors, so this guard is
        # only a last resort; component .update() calls have no effect at build
        # time, so any failure is surfaced through the DataFrame itself.
        try:
            initial_df = create_leaderboard()
        except Exception as e:
            error_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            initial_df = pd.DataFrame([{"Error": error_msg}])
        leaderboard = gr.DataFrame(initial_df)

    def refresh_and_report():
        try:
            df = create_leaderboard()
            debug_info = "Leaderboard refreshed successfully. Check console for detailed debug information."
            return df, debug_info
        except Exception as e:
            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return pd.DataFrame([{"Error": error_msg}]), error_msg

    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])


if __name__ == "__main__":
    demo.launch()
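
# Usage note (assumption: this script is the Space's app.py): running it locally
# with `python app.py` builds the leaderboard once at startup and serves the
# Gradio UI; the "Refresh Leaderboard" button recomputes the full table on demand.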