# Spaces: Sleeping  (hosting-page status banner captured with the file; not program code)
import html
import logging
from typing import List, Optional, Tuple, Dict

import gradio as gr
import jiwer
import pandas as pd
# Logging setup: INFO and above, timestamped records, emitted through a
# single stream handler. force=True replaces any handlers that were already
# attached to the root logger before this module was imported.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    force=True,
)
logger = logging.getLogger(__name__)
def calculate_wer_metrics(
    hypothesis: str,
    reference: str,
    normalize: bool = True,
    words_to_filter: Optional[List[str]] = None
) -> Dict:
    """
    Calculate WER metrics between hypothesis and reference texts.

    Args:
        hypothesis (str): The hypothesis text
        reference (str): The reference text
        normalize (bool): Whether to normalize texts before comparison
            (contraction expansion, lowercasing, whitespace and punctuation
            removal). When False the texts are passed to jiwer unchanged and
            ``words_to_filter`` is ignored.
        words_to_filter (List[str], optional): Words to drop from both texts
            before comparison. Matching is case-insensitive and is applied to
            the already-normalized words; blank entries are ignored.

    Returns:
        dict: The measures dictionary returned by ``jiwer.compute_measures``.

    Raises:
        ValueError: If inputs are invalid or result in empty text after processing
    """
    logger.info(
        "Calculating WER metrics with inputs - Hypothesis: %s, Reference: %s",
        hypothesis,
        reference,
    )

    # Validate inputs
    if not hypothesis.strip() or not reference.strip():
        raise ValueError("Both hypothesis and reference texts must contain non-empty strings")

    if not normalize:
        return jiwer.compute_measures(
            truth=reference,
            hypothesis=hypothesis
        )

    # Basic normalization; the last step yields a list of sentences, each a
    # list of words (List[List[str]]).
    basic_transform = jiwer.Compose([
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.RemovePunctuation(),
        jiwer.Strip(),
        jiwer.ReduceToListOfListOfWords()
    ])

    # Build the lookup once instead of re-lowercasing the list for every word.
    blocked = {
        word.strip().lower()
        for word in (words_to_filter or [])
        if word and word.strip()
    }

    if blocked:
        def filter_words_transform(sentences: List[List[str]]) -> List[List[str]]:
            # The input is a list of word lists, so the filter has to be
            # applied inside each sentence rather than to the outer list.
            filtered = [
                [word for word in sentence if word.lower() not in blocked]
                for sentence in sentences
            ]
            if not any(filtered):
                raise ValueError("Text is empty after filtering words")
            return filtered

        transformation = jiwer.Compose([
            basic_transform,
            filter_words_transform
        ])
    else:
        transformation = basic_transform

    # Pre-check the transformed text so that empty results produce a clear
    # error before the measures are computed.
    try:
        transformed_ref = transformation(reference)
        transformed_hyp = transformation(hypothesis)
        # [[]] is truthy, so test whether any sentence still contains words.
        if not any(transformed_ref) or not any(transformed_hyp):
            raise ValueError("Text is empty after normalization")
        logger.debug(f"Transformed reference: {transformed_ref}")
        logger.debug(f"Transformed hypothesis: {transformed_hyp}")
    except Exception as e:
        logger.error(f"Transformation error: {str(e)}")
        raise ValueError(f"Error during text transformation: {str(e)}") from e

    return jiwer.compute_measures(
        truth=reference,
        hypothesis=hypothesis,
        truth_transform=transformation,
        hypothesis_transform=transformation
    )
def process_inputs(
    reference: str,
    hypothesis: str,
    normalize: bool,
    words_to_filter: str
) -> Tuple[str, str, str, str]:
    """
    Process inputs and calculate WER metrics.

    Args:
        reference (str): Reference text
        hypothesis (str): Hypothesis text
        normalize (bool): Whether to normalize text
        words_to_filter (str): Comma-separated words to filter; blank entries
            (e.g. from a trailing comma) are ignored

    Returns:
        Tuple[str, str, str, str]: HTML table of the main metrics, HTML table
        of the error counts, HTML explanation of the metrics, and an error
        message. On success the error message is empty; on failure the first
        three are empty. If either text is missing, a prompt is returned in
        the first position and the rest are empty.
    """
    if not reference or not hypothesis:
        return "Please provide both reference and hypothesis texts.", "", "", ""

    try:
        # Split the comma-separated list and drop blank tokens; an empty
        # result means "no filtering".
        filter_words = [
            word.strip()
            for word in (words_to_filter or "").split(",")
            if word.strip()
        ] or None

        measures = calculate_wer_metrics(
            hypothesis=hypothesis,
            reference=reference,
            normalize=normalize,
            words_to_filter=filter_words
        )

        # Format main metrics
        metrics_df = pd.DataFrame({
            'Metric': ['WER', 'MER', 'WIL', 'WIP'],
            'Value': [
                f"{measures['wer']:.3f}",
                f"{measures['mer']:.3f}",
                f"{measures['wil']:.3f}",
                f"{measures['wip']:.3f}"
            ]
        })

        # Format error analysis
        error_df = pd.DataFrame({
            'Metric': ['Substitutions', 'Deletions', 'Insertions', 'Hits'],
            'Count': [
                measures['substitutions'],
                measures['deletions'],
                measures['insertions'],
                measures['hits']
            ]
        })

        metrics_html = metrics_df.to_html(index=False)
        error_html = error_df.to_html(index=False)

        explanation = """
        <h3>Metrics Explanation:</h3>
        <ul>
        <li><b>WER (Word Error Rate)</b>: The percentage of words that were incorrectly predicted</li>
        <li><b>MER (Match Error Rate)</b>: The percentage of words that were incorrectly matched</li>
        <li><b>WIL (Word Information Lost)</b>: The percentage of word information that was lost</li>
        <li><b>WIP (Word Information Preserved)</b>: The percentage of word information that was preserved</li>
        </ul>
        """
        return metrics_html, error_html, explanation, ""
    except Exception as e:
        # Log with traceback; the message is rendered as HTML by the caller,
        # so escape the exception text before returning it.
        logger.exception("Error calculating WER: %s", e)
        error_msg = f"Error calculating WER: {html.escape(str(e))}"
        return "", "", "", error_msg
def load_example() -> Tuple[str, str]:
    """Return a sample (reference, hypothesis) text pair for the demo."""
    sample_reference = "the quick brown fox jumps over the lazy dog"
    sample_hypothesis = "the quick brown fox jumped over lazy dog"
    return sample_reference, sample_hypothesis
def create_interface() -> gr.Blocks:
    """Build the Gradio Blocks UI for the WER tool and wire up its buttons.

    Returns:
        gr.Blocks: The assembled interface; the caller is responsible for
        launching it.
    """
    intro_text = (
        "This tool helps you evaluate the Word Error Rate (WER) between a reference "
        "text and a hypothesis text. WER is commonly used in speech recognition and "
        "machine translation evaluation."
    )

    with gr.Blocks(title="WER Evaluation Tool") as interface:
        gr.Markdown("# Word Error Rate (WER) Evaluation Tool")
        gr.Markdown(intro_text)

        # Reference and hypothesis inputs, one column each.
        with gr.Row():
            with gr.Column():
                reference_box = gr.Textbox(
                    label="Reference Text",
                    placeholder="Enter the reference text here...",
                    lines=5,
                )
            with gr.Column():
                hypothesis_box = gr.Textbox(
                    label="Hypothesis Text",
                    placeholder="Enter the hypothesis text here...",
                    lines=5,
                )

        # Processing options.
        with gr.Row():
            normalize_box = gr.Checkbox(
                label="Normalize text (lowercase, remove punctuation)",
                value=True,
            )
            filter_box = gr.Textbox(
                label="Words to filter (comma-separated)",
                placeholder="e.g., um, uh, ah",
            )

        # Action buttons.
        with gr.Row():
            load_example_button = gr.Button("Load Example")
            calculate_button = gr.Button("Calculate WER", variant="primary")

        # Result tables.
        with gr.Row():
            metrics_view = gr.HTML(label="Main Metrics")
            errors_view = gr.HTML(label="Error Analysis")

        # NOTE(review): these two outputs are placed below the results row;
        # confirm against the intended layout.
        explanation_view = gr.HTML()
        error_message_view = gr.HTML()

        # Event handlers
        load_example_button.click(
            fn=load_example,
            outputs=[reference_box, hypothesis_box],
        )
        calculate_button.click(
            fn=process_inputs,
            inputs=[reference_box, hypothesis_box, normalize_box, filter_box],
            outputs=[metrics_view, errors_view, explanation_view, error_message_view],
        )

    return interface
# Script entry point: runs only when the file is executed directly, not when
# it is imported. Logs startup, builds the interface, and launches it.
if __name__ == "__main__":
    logger.info("Application started")
    app = create_interface()
    app.launch()