"""Gradio leaderboard of baseline WER for the GenSEC post-ASR correction dataset.

Loads ``GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction``, computes the word
error rate (WER) of the first hypothesis against the reference transcription
for every ``source`` in the ``train`` and ``test`` splits, and shows the
result in a refreshable table.
"""

import logging
from collections import defaultdict

import gradio as gr
import jiwer
import numpy as np
import pandas as pd
from datasets import load_dataset

logger = logging.getLogger(__name__)


# Load the dataset
def load_data():
    """Load and return the post-ASR text-correction dataset."""
    dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")
    return dataset


# Calculate WER for a group of examples
def calculate_wer(examples) -> float:
    """Return the WER of a group of examples.

    The hypothesis of each example is the text of ``hypothesis_concatenated``
    before its first ``'.'`` (presumably the top-ranked ASR hypothesis --
    confirm against the dataset card); the reference is ``transcription``.
    Both are stripped of surrounding whitespace.

    Args:
        examples: A sized collection of mappings that each provide the
            ``hypothesis_concatenated`` and ``transcription`` string fields.

    Returns:
        The value of ``jiwer.wer`` over all examples, or ``0.0`` when
        ``examples`` is empty.
    """
    if not examples:
        return 0.0

    hypotheses = [
        ex["hypothesis_concatenated"].split(".")[0].strip() for ex in examples
    ]
    transcriptions = [ex["transcription"].strip() for ex in examples]

    # jiwer expects the reference first, then the hypothesis.
    return jiwer.wer(transcriptions, hypotheses)


def _group_by_source(split):
    """Bucket the examples of one split by their ``source`` value in one pass."""
    groups = defaultdict(list)
    for ex in split:
        groups[ex["source"]].append(ex)
    return groups


# Get WER metrics by source and split
def get_wer_metrics(dataset):
    """Compute per-source and overall WER for the train and test splits.

    Args:
        dataset: A mapping with ``"train"`` and ``"test"`` entries, each an
            iterable of example mappings carrying a ``source`` field plus the
            fields read by ``calculate_wer``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``Source``, ``Train Count``,
        ``Train WER``, ``Test Count`` and ``Test WER``: one row per source
        (sorted), followed by an ``OVERALL`` row. A WER is ``NaN`` for a
        source that has no examples in that split.
    """
    # Each split is traversed exactly once; rescanning the whole split for
    # every source would repeat the row decoding for each of them.
    train_groups = _group_by_source(dataset["train"])
    test_groups = _group_by_source(dataset["test"])

    # Get unique sources
    all_sources = sorted(train_groups.keys() | test_groups.keys())

    results = []

    # Calculate WER for each source in both splits
    for source in all_sources:
        train_examples = train_groups.get(source, [])
        train_count = len(train_examples)
        train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan

        test_examples = test_groups.get(source, [])
        test_count = len(test_examples)
        test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan

        results.append({
            "Source": source,
            "Train Count": train_count,
            "Train WER": train_wer,
            "Test Count": test_count,
            "Test WER": test_wer,
        })

    # Add overall metrics. WER is computed from error and word totals over
    # all sentence pairs, so the grouped order does not affect the result.
    all_train = [ex for group in train_groups.values() for ex in group]
    all_test = [ex for group in test_groups.values() for ex in group]

    results.append({
        "Source": "OVERALL",
        "Train Count": len(all_train),
        "Train WER": calculate_wer(all_train),
        "Test Count": len(all_test),
        "Test WER": calculate_wer(all_test),
    })

    return pd.DataFrame(results)


def _format_wer(value) -> str:
    """Render one WER value with four decimals, or ``"N/A"`` when missing."""
    return "N/A" if pd.isna(value) else f"{value:.4f}"


# Format the dataframe for display
def format_dataframe(df):
    """Return a display copy of ``df`` with the WER columns as strings.

    The input frame is left untouched so its numeric WER values stay usable
    by the caller.

    Args:
        df: A frame containing the ``Train WER`` and ``Test WER`` columns.

    Returns:
        A new ``pandas.DataFrame`` whose WER columns hold four-decimal
        strings, with ``"N/A"`` in place of missing values.
    """
    formatted = df.copy()
    for column in ("Train WER", "Test WER"):
        formatted[column] = formatted[column].apply(_format_wer)
    return formatted


# Main function to create the leaderboard
def create_leaderboard():
    """Build the leaderboard table shown in the UI.

    Returns:
        The formatted metrics frame, or a one-column ``Error`` frame holding
        the exception message when loading or scoring fails.
    """
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        # UI boundary: keep the app alive and show the message in the table,
        # but record the traceback so the failure can be diagnosed.
        logger.exception("Failed to build the WER leaderboard")
        return pd.DataFrame({"Error": [str(e)]})


# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard")
    gr.Markdown("Word Error Rate (WER) metrics for GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset")

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    with gr.Row():
        # The table is populated once while the interface is being built.
        leaderboard = gr.DataFrame(create_leaderboard())

    refresh_btn.click(create_leaderboard, outputs=leaderboard)

if __name__ == "__main__":
    demo.launch()