import traceback
from functools import lru_cache

import gradio as gr
import jiwer
import numpy as np
import pandas as pd
from datasets import load_dataset


# Cache the dataset loading to avoid reloading on refresh.
@lru_cache(maxsize=1)
def load_data():
    """Load the test split of the SLT-Task1 post-ASR correction dataset.

    Returns:
        A ``datasets`` dataset with the test examples.

    Raises:
        Exception: re-raised if both the hub load and the explicit
        parquet fallback fail.
    """
    try:
        # Load only the test dataset by specifying the split.
        return load_dataset(
            "GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test"
        )
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        # Try loading with explicit file path if the default loading fails.
        try:
            return load_dataset(
                "parquet",
                data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
            )
        except Exception as e2:
            print(f"Error loading with explicit path: {str(e2)}")
            raise


def calculate_wer(examples):
    """Compute the word error rate for a group of examples.

    Args:
        examples: iterable of dicts with ``transcription`` (reference)
            and ``input1`` (hypothesis) string fields.

    Returns:
        The WER as a float, ``0.0`` for an empty input, or ``np.nan``
        when no valid pairs exist or the computation fails.
    """
    if not examples:
        return 0.0
    try:
        # Collect (reference, hypothesis) pairs in a single pass,
        # skipping malformed examples instead of aborting the batch.
        valid_pairs = []
        for ex in examples:
            try:
                # Print a sample example to debug.
                if len(valid_pairs) == 0:
                    print(f"Sample example keys: {ex.keys()}")
                transcription = ex.get("transcription", "")
                input1 = ex.get("input1", "")
                # Only add valid pairs with non-empty strings.
                if (
                    transcription
                    and input1
                    and isinstance(transcription, str)
                    and isinstance(input1, str)
                ):
                    # Limit text length to avoid potential issues
                    # (pathologically long inputs slow jiwer down).
                    transcription = transcription.strip()[:1000]
                    input1 = input1.strip()[:1000]
                    valid_pairs.append((transcription, input1))
            except Exception as ex_error:
                # Skip problematic examples but continue processing.
                print(f"Error processing example: {str(ex_error)}")
                continue

        if not valid_pairs:
            print("No valid pairs found for WER calculation")
            return np.nan

        # Print sample pairs for debugging.
        print(f"Sample pair for WER calculation: {valid_pairs[0]}")
        print(f"Total valid pairs: {len(valid_pairs)}")

        # Unzip the pairs in one operation (valid_pairs is non-empty here).
        references, hypotheses = zip(*valid_pairs)

        try:
            wer = jiwer.wer(references, hypotheses)
            print(f"Calculated WER: {wer}")
            return wer
        except Exception as wer_error:
            print(f"Error calculating WER: {str(wer_error)}")
            return np.nan
    except Exception as e:
        print(f"Error in calculate_wer: {str(e)}")
        print(traceback.format_exc())
        return np.nan


def get_wer_metrics(dataset):
    """Build per-source and overall WER metrics for the dataset.

    Args:
        dataset: iterable of example dicts; each may carry a ``source``
            field used for grouping (missing sources become "unknown").

    Returns:
        A ``pd.DataFrame`` with ``Source``, ``Count`` and ``WER``
        columns, including a final "OVERALL" row; on total failure a
        one-row frame with an ``Error`` column.
    """
    try:
        # Group examples by source.
        examples_by_source = {}
        for ex in dataset:
            try:
                source = ex.get("source", "unknown")
                if source not in examples_by_source:
                    examples_by_source[source] = []
                examples_by_source[source].append(ex)
            except Exception as e:
                print(f"Error processing example: {str(e)}")
                continue

        # Calculate metrics for each source, in sorted order.
        results = []
        for source in sorted(examples_by_source.keys()):
            try:
                examples = examples_by_source.get(source, [])
                count = len(examples)
                if count > 0:
                    print(f"Calculating WER for source {source} with {count} examples")
                    wer = calculate_wer(examples)
                else:
                    wer = np.nan
                results.append({"Source": source, "Count": count, "WER": wer})
            except Exception as e:
                print(f"Error processing source {source}: {str(e)}")
                results.append({"Source": source, "Count": 0, "WER": np.nan})

        # Calculate overall metrics once across the whole dataset.
        try:
            total_count = len(dataset)
            print(f"Calculating overall WER for {total_count} examples")
            overall_wer = calculate_wer(dataset)
            results.append(
                {"Source": "OVERALL", "Count": total_count, "WER": overall_wer}
            )
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            results.append(
                {"Source": "OVERALL", "Count": len(dataset), "WER": np.nan}
            )

        return pd.DataFrame(results)
    except Exception as e:
        print(f"Error in get_wer_metrics: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])


def format_dataframe(df):
    """Format the WER column for display ("0.1234" / "N/A").

    Args:
        df: metrics frame from :func:`get_wer_metrics`.

    Returns:
        A copy with ``WER`` rendered as strings; on failure a one-row
        frame with an ``Error`` column.
    """
    try:
        df = df.copy()
        if "WER" in df.columns:
            mask = df["WER"].notna()
            # Cast to object first: writing strings into a float64 column
            # via .loc raises/warns in modern pandas.
            df["WER"] = df["WER"].astype(object)
            df.loc[mask, "WER"] = df.loc[mask, "WER"].map(lambda x: f"{x:.4f}")
            df.loc[~mask, "WER"] = "N/A"
        return df
    except Exception as e:
        print(f"Error in format_dataframe: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])


def create_leaderboard():
    """Load the data, compute metrics, and return the display frame."""
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return pd.DataFrame([{"Error": error_msg}])


# Create the Gradio interface.
with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown(
        "Word Error Rate (WER) metrics for test data in GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset"
    )

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    # Build the initial table before creating the components so any
    # failure message can be shown as the textbox's initial value
    # (calling .update() during the build is a no-op and was removed
    # in Gradio 4).
    try:
        initial_df = create_leaderboard()
        initial_msg = ""
    except Exception as e:
        initial_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(initial_msg)
        initial_df = pd.DataFrame([{"Error": initial_msg}])

    with gr.Row():
        error_output = gr.Textbox(
            label="Debug Information", visible=True, value=initial_msg
        )

    with gr.Row():
        leaderboard = gr.DataFrame(initial_df)

    def refresh_and_report():
        """Rebuild the leaderboard and report status for the debug box."""
        try:
            df = create_leaderboard()
            debug_info = "Leaderboard refreshed successfully."
            return df, debug_info
        except Exception as e:
            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return pd.DataFrame([{"Error": error_msg}]), error_msg

    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])

if __name__ == "__main__":
    demo.launch()