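"""Baseline WER leaderboard for the GenSEC-LLM SLT-Task1 Post-ASR Text Correction test set.
Compares three simple baselines per speech source: the 1-best ASR output (no LM),
n-gram reranking of the n-best hypotheses, and word-level voting over the n-best list.
"""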
import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np
from functools import lru_cache
import traceback
import re
import string
from collections import Counter
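# Note: the WER computation prefers the `editdistance` package when it is installed
# and falls back to jiwer.wer otherwise (see calculate_simple_wer below).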
# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
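    """Load the SLT-Task1 test split, falling back to the raw parquet file if needed."""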
try:
# Load only the test dataset by specifying the split
dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
return dataset
except Exception as e:
print(f"Error loading dataset: {str(e)}")
# Try loading with explicit file path if the default loading fails
try:
dataset = load_dataset("parquet",
data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet")
return dataset
except Exception as e2:
print(f"Error loading with explicit path: {str(e2)}")
raise
# Preprocess text for better WER calculation
def preprocess_text(text):
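    """Lowercase, strip punctuation, and collapse whitespace (e.g. 'Hello,  World!' -> 'hello world')."""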
if not text or not isinstance(text, str):
return ""
# Convert to lowercase
text = text.lower()
# Remove punctuation
text = re.sub(r'[^\w\s]', '', text)
# Remove extra whitespace
text = re.sub(r'\s+', ' ', text).strip()
return text
# Simple language model scoring - count n-grams
def score_hypothesis(hypothesis, n=4):
"""Score a hypothesis using simple n-gram statistics"""
if not hypothesis:
return 0
words = hypothesis.split()
if len(words) < n:
return len(words) # Just return word count for very short texts
# Count n-grams
ngrams = []
for i in range(len(words) - n + 1):
ngram = ' '.join(words[i:i+n])
ngrams.append(ngram)
# More unique n-grams might indicate better fluency
unique_ngrams = len(set(ngrams))
total_ngrams = len(ngrams)
# Score is a combination of length and n-gram variety
score = len(words) + unique_ngrams/max(1, total_ngrams) * 5
return score
# N-best LM ranking approach
def get_best_hypothesis_lm(hypotheses):
"""Choose the best hypothesis using a simple language model approach"""
if not hypotheses:
return ""
# Convert to list if it's not already
if isinstance(hypotheses, str):
return hypotheses
# Ensure we have a list of strings
hypothesis_list = []
for h in hypotheses:
if isinstance(h, str):
hypothesis_list.append(preprocess_text(h))
if not hypothesis_list:
return ""
# Score each hypothesis and choose the best one
scores = [(score_hypothesis(h), h) for h in hypothesis_list]
best_hypothesis = max(scores, key=lambda x: x[0])[1]
return best_hypothesis
# N-best correction approach
def correct_hypotheses(hypotheses):
"""Simple n-best correction by voting on words"""
if not hypotheses:
return ""
# Convert to list if it's not already
if isinstance(hypotheses, str):
return hypotheses
# Ensure we have a list of strings
hypothesis_list = []
for h in hypotheses:
if isinstance(h, str):
hypothesis_list.append(preprocess_text(h))
if not hypothesis_list:
return ""
# Split hypotheses into words
word_lists = [h.split() for h in hypothesis_list]
# Find the most common length
lengths = [len(words) for words in word_lists]
if not lengths:
return ""
most_common_length = Counter(lengths).most_common(1)[0][0]
# Only consider hypotheses with the most common length
filtered_word_lists = [words for words in word_lists if len(words) == most_common_length]
if not filtered_word_lists:
# Fall back to the longest hypothesis if filtering removed everything
return max(hypothesis_list, key=len)
# Vote on each word position
corrected_words = []
for i in range(most_common_length):
position_words = [words[i] for words in filtered_word_lists]
most_common_word = Counter(position_words).most_common(1)[0][0]
corrected_words.append(most_common_word)
# Join the corrected words
return ' '.join(corrected_words)
# Word-level WER via simple edit distance, avoiding dependence on jiwer internals
def calculate_simple_wer(reference, hypothesis):
"""Calculate WER using a simple word-based approach"""
if not reference or not hypothesis:
return 1.0 # Maximum error if either is empty
# Split into words
ref_words = reference.split()
hyp_words = hypothesis.split()
# Use editdistance package instead of jiwer internals
try:
import editdistance
distance = editdistance.eval(ref_words, hyp_words)
except ImportError:
# Fallback to simple jiwer calculation
try:
# Try using the standard jiwer implementation
wer_value = jiwer.wer(reference, hypothesis)
return wer_value
except Exception:
# If all else fails, return 1.0 (maximum error)
print("Error calculating WER - fallback to maximum error")
return 1.0
# WER calculation
if len(ref_words) == 0:
return 1.0
return float(distance) / float(len(ref_words))
# Calculate WER for a group of examples with multiple methods
def calculate_wer_methods(examples):
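    """Compute average WER for the three baselines (1-best, n-gram reranking, n-best voting)."""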
    if not examples:
        return np.nan, np.nan, np.nan
try:
        # Check if examples is a Dataset (has .features) or a plain list
        is_dataset = hasattr(examples, 'features')
        # Get the first example for inspection
        if len(examples) > 0:
            example = examples[0]
        else:
            print("No examples found")
            return np.nan, np.nan, np.nan
print("\n===== EXAMPLE DATA INSPECTION =====")
print(f"Keys in example: {example.keys()}")
# Try different possible field names
possible_reference_fields = ["transcription", "reference", "ground_truth", "target"]
possible_hypothesis_fields = ["input1", "hypothesis", "asr_output", "source_text"]
for field in possible_reference_fields:
if field in example:
print(f"Reference field '{field}' found with value: {str(example[field])[:100]}...")
for field in possible_hypothesis_fields:
if field in example:
print(f"Hypothesis field '{field}' found with value: {str(example[field])[:100]}...")
# Process each example in the dataset
wer_values_no_lm = []
wer_values_lm_ranking = []
wer_values_n_best_correction = []
valid_count = 0
skipped_count = 0
# Determine how to iterate based on type
items_to_process = examples
if is_dataset:
# Limit to first 200 examples for efficiency
items_to_process = examples.select(range(min(200, len(examples))))
else:
items_to_process = examples[:200] # First 200 examples
for i, ex in enumerate(items_to_process):
try:
# Get reference transcription
transcription = ex.get("transcription")
if not transcription or not isinstance(transcription, str):
skipped_count += 1
continue
# Process the reference
reference = preprocess_text(transcription)
if not reference:
skipped_count += 1
continue
# Get 1-best hypothesis for baseline
input1 = ex.get("input1")
if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
input1 = ex["hypothesis"][0]
elif isinstance(ex["hypothesis"], str):
input1 = ex["hypothesis"]
# Get n-best hypotheses for other methods
n_best_hypotheses = ex.get("hypothesis", [])
                # Process and evaluate all methods
                produced_wer = False
                # Method 1: No LM (1-best ASR output)
                if input1 and isinstance(input1, str):
                    no_lm_hyp = preprocess_text(input1)
                    if no_lm_hyp:
                        wer_no_lm = calculate_simple_wer(reference, no_lm_hyp)
                        wer_values_no_lm.append(wer_no_lm)
                        produced_wer = True
                # Method 2: LM ranking (best of n-best)
                if n_best_hypotheses:
                    lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
                    if lm_best_hyp:
                        wer_lm = calculate_simple_wer(reference, lm_best_hyp)
                        wer_values_lm_ranking.append(wer_lm)
                        produced_wer = True
                # Method 3: N-best correction (voting among n-best)
                if n_best_hypotheses:
                    corrected_hyp = correct_hypotheses(n_best_hypotheses)
                    if corrected_hyp:
                        wer_corrected = calculate_simple_wer(reference, corrected_hyp)
                        wer_values_n_best_correction.append(wer_corrected)
                        produced_wer = True
                # Count the example as valid if at least one method produced a score
                if produced_wer:
                    valid_count += 1
                else:
                    skipped_count += 1
# Print debug info for a few examples
if i < 2:
print(f"\nExample {i} inspection:")
print(f" Reference: '{reference}'")
if input1 and isinstance(input1, str):
no_lm_hyp = preprocess_text(input1)
print(f" No LM (1-best): '{no_lm_hyp}'")
if no_lm_hyp:
wer = calculate_simple_wer(reference, no_lm_hyp)
print(f" No LM WER: {wer:.4f}")
if n_best_hypotheses:
print(f" N-best count: {len(n_best_hypotheses) if isinstance(n_best_hypotheses, list) else 'not a list'}")
lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
print(f" LM ranking best: '{lm_best_hyp}'")
if lm_best_hyp:
wer = calculate_simple_wer(reference, lm_best_hyp)
print(f" LM ranking WER: {wer:.4f}")
corrected_hyp = correct_hypotheses(n_best_hypotheses)
print(f" N-best correction: '{corrected_hyp}'")
if corrected_hyp:
wer = calculate_simple_wer(reference, corrected_hyp)
print(f" N-best correction WER: {wer:.4f}")
except Exception as ex_error:
print(f"Error processing example {i}: {str(ex_error)}")
skipped_count += 1
continue
# Calculate average WER for each method
print(f"\nProcessing summary: Valid pairs: {valid_count}, Skipped: {skipped_count}")
no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan
lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan
n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan
print(f"Calculated WERs:")
print(f" No LM: {len(wer_values_no_lm)} pairs, avg WER: {no_lm_wer:.4f}")
print(f" LM Ranking: {len(wer_values_lm_ranking)} pairs, avg WER: {lm_ranking_wer:.4f}")
print(f" N-best Correction: {len(wer_values_n_best_correction)} pairs, avg WER: {n_best_correction_wer:.4f}")
return no_lm_wer, lm_ranking_wer, n_best_correction_wer
except Exception as e:
print(f"Error in calculate_wer: {str(e)}")
print(traceback.format_exc())
return np.nan, np.nan, np.nan
# Get WER metrics by source
def get_wer_metrics(dataset):
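    """Build a per-source WER summary table (plus an OVERALL column) as a pandas DataFrame."""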
try:
# Print dataset info
print(f"\n===== DATASET INFO =====")
print(f"Dataset size: {len(dataset)}")
print(f"Dataset features: {dataset.features}")
# Group examples by source
examples_by_source = {}
# Process all examples
for i, ex in enumerate(dataset):
try:
source = ex.get("source", "unknown")
# Skip all_et05_real as requested
if source == "all_et05_real":
continue
if source not in examples_by_source:
examples_by_source[source] = []
examples_by_source[source].append(ex)
except Exception as e:
print(f"Error processing example {i}: {str(e)}")
continue
# Get all unique sources
all_sources = sorted(examples_by_source.keys())
print(f"Found sources: {all_sources}")
# Calculate metrics for each source
source_results = {}
for source in all_sources:
try:
examples = examples_by_source.get(source, [])
count = len(examples)
if count > 0:
print(f"\nCalculating WER for source {source} with {count} examples")
no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples)
else:
no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan
source_results[source] = {
"Count": count,
"No LM Baseline": no_lm_wer,
"N-best LM Ranking": lm_ranking_wer,
"N-best Correction": n_best_wer
}
except Exception as e:
print(f"Error processing source {source}: {str(e)}")
source_results[source] = {
"Count": 0,
"No LM Baseline": np.nan,
"N-best LM Ranking": np.nan,
"N-best Correction": np.nan
}
# Calculate overall metrics with a sample but excluding all_et05_real
try:
# Create a filtered dataset without all_et05_real
filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
total_count = len(filtered_dataset)
print(f"\nCalculating overall WER with a sample of examples (excluding all_et05_real)")
# Sample for calculation
sample_size = min(500, total_count)
sample_dataset = filtered_dataset[:sample_size]
no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset)
source_results["OVERALL"] = {
"Count": total_count,
"No LM Baseline": no_lm_wer,
"N-best LM Ranking": lm_ranking_wer,
"N-best Correction": n_best_wer
}
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            print(traceback.format_exc())
            # filtered_dataset may not exist if the failure happened while building it,
            # so fall back to the per-source counts already collected
            source_results["OVERALL"] = {
                "Count": sum(r["Count"] for r in source_results.values()),
                "No LM Baseline": np.nan,
                "N-best LM Ranking": np.nan,
                "N-best Correction": np.nan
            }
# Create flat DataFrame with labels in the first column
rows = []
# First add row for number of examples
example_row = {"Metric": "Number of Examples"}
for source in all_sources + ["OVERALL"]:
example_row[source] = source_results[source]["Count"]
rows.append(example_row)
# Then add rows for each WER method
no_lm_row = {"Metric": "Word Error Rate (No LM)"}
lm_ranking_row = {"Metric": "Word Error Rate (N-best LM Ranking)"}
n_best_row = {"Metric": "Word Error Rate (N-best Correction)"}
for source in all_sources + ["OVERALL"]:
no_lm_row[source] = source_results[source]["No LM Baseline"]
lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
n_best_row[source] = source_results[source]["N-best Correction"]
rows.append(no_lm_row)
rows.append(lm_ranking_row)
rows.append(n_best_row)
# Create DataFrame from rows
result_df = pd.DataFrame(rows)
return result_df
except Exception as e:
print(f"Error in get_wer_metrics: {str(e)}")
print(traceback.format_exc())
return pd.DataFrame([{"Error": str(e)}])
# Format the dataframe for display
def format_dataframe(df):
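    """Format WER rows to four decimal places and show missing values as 'N/A'."""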
try:
        # Work on a copy so the caller's DataFrame is not modified in place
        df = df.copy()
# Find the rows containing WER values
wer_row_indices = []
for i, metric in enumerate(df["Metric"]):
if "WER" in metric or "Error Rate" in metric:
wer_row_indices.append(i)
# Format WER values
for idx in wer_row_indices:
for col in df.columns:
if col != "Metric": # Skip the metric column
value = df.loc[idx, col]
if pd.notna(value):
df.loc[idx, col] = f"{value:.4f}"
else:
df.loc[idx, col] = "N/A"
return df
except Exception as e:
print(f"Error in format_dataframe: {str(e)}")
print(traceback.format_exc())
return pd.DataFrame([{"Error": str(e)}])
# Main function to create the leaderboard
def create_leaderboard():
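    """Load the test data, compute the metrics, and return the formatted leaderboard table."""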
try:
dataset = load_data()
metrics_df = get_wer_metrics(dataset)
return format_dataframe(metrics_df)
except Exception as e:
error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return pd.DataFrame([{"Error": error_msg}])
# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")
with gr.Row():
refresh_btn = gr.Button("Refresh Leaderboard")
    # Build the initial table before creating the components so any startup error
    # can be shown in the debug textbox from the first render
    try:
        initial_df = create_leaderboard()
        initial_error = ""
    except Exception as e:
        initial_error = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(initial_error)
        initial_df = pd.DataFrame([{"Error": initial_error}])
    with gr.Row():
        error_output = gr.Textbox(label="Debug Information", visible=True, lines=10, value=initial_error)
    with gr.Row():
        leaderboard = gr.DataFrame(initial_df)
def refresh_and_report():
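        """Rebuild the leaderboard and return it with a short status message for the debug box."""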
try:
df = create_leaderboard()
debug_info = "Leaderboard refreshed successfully. Check console for detailed debug information."
return df, debug_info
except Exception as e:
error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return pd.DataFrame([{"Error": error_msg}]), error_msg
refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])
if __name__ == "__main__":
demo.launch()