WER / app.py
piyushmaharana's picture
working demo for wer
92868fb
import gradio as gr
import jiwer
import pandas as pd
import logging
from typing import List, Optional, Tuple, Dict
# Root logging: INFO and above, sent through a single StreamHandler (stderr).
# force=True removes and closes any handlers already attached to the root
# logger, so this configuration wins even if something configured logging first.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
    force=True,
)
# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)
def calculate_wer_metrics(
    hypothesis: str,
    reference: str,
    normalize: bool = True,
    words_to_filter: Optional[List[str]] = None
) -> Dict:
    """
    Calculate WER metrics between hypothesis and reference texts.

    Args:
        hypothesis (str): The hypothesis text
        reference (str): The reference text
        normalize (bool): Whether to normalize texts before comparison
            (expand common English contractions, lowercase, remove
            punctuation, collapse and strip whitespace)
        words_to_filter (List[str], optional): Words to remove from both texts
            before comparison. Matching is case-insensitive on whole tokens and
            is applied whether or not ``normalize`` is set. Without
            normalization, tokens keep their punctuation, so "um," does not
            match "um". Empty entries are ignored.

    Returns:
        dict: The measures dictionary returned by ``jiwer.compute_measures``.
        This app reads its 'wer', 'mer', 'wil', 'wip', 'substitutions',
        'deletions', 'insertions' and 'hits' entries.

    Raises:
        ValueError: If inputs are invalid or result in empty text after processing
    """
    logger.info(
        "Calculating WER metrics with inputs - Hypothesis: %s, Reference: %s",
        hypothesis,
        reference,
    )
    # Validate inputs
    if not hypothesis.strip() or not reference.strip():
        raise ValueError("Both hypothesis and reference texts must contain non-empty strings")

    # Lowercased lookup set, built once rather than once per word.
    filter_set = {w.lower() for w in (words_to_filter or []) if w}

    if normalize:
        base_transform = jiwer.Compose([
            jiwer.ExpandCommonEnglishContractions(),
            jiwer.ToLowerCase(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.RemovePunctuation(),
            jiwer.Strip(),
            jiwer.ReduceToListOfListOfWords()
        ])
    elif filter_set:
        # Whitespace-only tokenization, so filtering still works when
        # normalization is switched off.
        base_transform = jiwer.Compose([
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.ReduceToListOfListOfWords()
        ])
    else:
        # Nothing to preprocess: let jiwer apply its default transforms.
        return jiwer.compute_measures(
            truth=reference,
            hypothesis=hypothesis
        )

    if filter_set:
        def filter_words_transform(sentences: List[List[str]]) -> List[List[str]]:
            # ReduceToListOfListOfWords yields a list of word lists (one per
            # sentence), so filter inside each inner list. jiwer requires the
            # list-of-lists shape to be preserved.
            filtered = [
                [word for word in words if word.lower() not in filter_set]
                for words in sentences
            ]
            if not any(filtered):
                raise ValueError("Text is empty after filtering words")
            return filtered

        transformation = jiwer.Compose([
            base_transform,
            filter_words_transform
        ])
    else:
        transformation = base_transform

    # Pre-check the transformed text so failures get a clear error message.
    try:
        transformed_ref = transformation(reference)
        transformed_hyp = transformation(hypothesis)
    except Exception as e:
        logger.error("Transformation error: %s", e)
        raise ValueError(f"Error during text transformation: {str(e)}") from e
    # The output is a list of word lists, and `[[]]` is truthy, so inspect the
    # inner lists to detect text that was reduced to nothing.
    if not any(transformed_ref) or not any(transformed_hyp):
        raise ValueError("Text is empty after normalization")
    logger.debug("Transformed reference: %s", transformed_ref)
    logger.debug("Transformed hypothesis: %s", transformed_hyp)

    # NOTE(review): jiwer 3.x deprecates compute_measures/truth= in favour of
    # process_words/reference=; confirm against the pinned jiwer version.
    return jiwer.compute_measures(
        truth=reference,
        hypothesis=hypothesis,
        truth_transform=transformation,
        hypothesis_transform=transformation
    )
def process_inputs(
    reference: str,
    hypothesis: str,
    normalize: bool,
    words_to_filter: str
) -> Tuple[str, str, str, str]:
    """
    Process inputs and calculate WER metrics.

    Args:
        reference (str): Reference text
        hypothesis (str): Hypothesis text
        normalize (bool): Whether to normalize text
        words_to_filter (str): Comma-separated words to filter; blank entries
            are ignored

    Returns:
        Tuple[str, str, str, str]: In order, HTML for the main metrics table,
        HTML for the error-analysis table, the metrics explanation HTML, and
        an HTML-escaped error message.

        - On success the error message is empty.
        - On failure only the error message is set.
        - If either text is missing, a prompt is returned in the first slot
          and the rest are empty.
    """
    import html  # only needed to escape error text for the HTML output panel

    if not reference or not hypothesis:
        return "Please provide both reference and hypothesis texts.", "", "", ""
    try:
        # "um, , uh," -> ["um", "uh"]; None when nothing usable was entered.
        filter_words = [
            word.strip()
            for word in (words_to_filter or "").split(",")
            if word.strip()
        ] or None
        measures = calculate_wer_metrics(
            hypothesis=hypothesis,
            reference=reference,
            normalize=normalize,
            words_to_filter=filter_words
        )
        # Format main metrics
        metrics_df = pd.DataFrame({
            'Metric': ['WER', 'MER', 'WIL', 'WIP'],
            'Value': [
                f"{measures['wer']:.3f}",
                f"{measures['mer']:.3f}",
                f"{measures['wil']:.3f}",
                f"{measures['wip']:.3f}"
            ]
        })
        # Format error analysis
        error_df = pd.DataFrame({
            'Metric': ['Substitutions', 'Deletions', 'Insertions', 'Hits'],
            'Count': [
                measures['substitutions'],
                measures['deletions'],
                measures['insertions'],
                measures['hits']
            ]
        })
        metrics_html = metrics_df.to_html(index=False)
        error_html = error_df.to_html(index=False)
        # The values above are fractions, not percentages, and WER is unbounded
        # above; the explanation states this.
        explanation = """
<h3>Metrics Explanation:</h3>
<ul>
<li><b>WER (Word Error Rate)</b>: (S + D + I) divided by the number of reference words, where S, D, I and H are the substitution, deletion, insertion and hit counts. Lower is better; it can exceed 1 when the hypothesis contains many insertions.</li>
<li><b>MER (Match Error Rate)</b>: the fraction of alignment operations that are errors, (S + D + I) / (H + S + D + I). Between 0 and 1; lower is better.</li>
<li><b>WIL (Word Information Lost)</b>: 1 - WIP. Between 0 and 1; lower is better.</li>
<li><b>WIP (Word Information Preserved)</b>: (H / number of reference words) &times; (H / number of hypothesis words). Between 0 and 1; higher is better.</li>
</ul>
<p>Values are shown as fractions (e.g. 0.250 means 25%).</p>
"""
        return metrics_html, error_html, explanation, ""
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the handler.
        error_msg = f"Error calculating WER: {str(e)}"
        logger.exception(error_msg)
        # gr.HTML renders markup, so escape the exception text before display.
        return "", "", "", html.escape(error_msg)
def load_example() -> Tuple[str, str]:
    """Return a sample ``(reference, hypothesis)`` pair for the demo.

    The hypothesis differs from the reference by one changed word
    ("jumped" for "jumps") and one missing word ("the").
    """
    sample_reference = "the quick brown fox jumps over the lazy dog"
    sample_hypothesis = "the quick brown fox jumped over lazy dog"
    return sample_reference, sample_hypothesis
def create_interface() -> gr.Blocks:
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` page with:

    - reference and hypothesis text boxes,
    - a normalization checkbox,
    - a comma-separated word-filter box,
    - "Load Example" and "Calculate WER" buttons,
    - four HTML output panels.

    The interface is returned without being launched.

    Returns:
        gr.Blocks: The assembled interface.
    """
    with gr.Blocks(title="WER Evaluation Tool") as interface:
        gr.Markdown("# Word Error Rate (WER) Evaluation Tool")
        gr.Markdown(
            "This tool helps you evaluate the Word Error Rate (WER) between a reference "
            "text and a hypothesis text. WER is commonly used in speech recognition and "
            "machine translation evaluation."
        )
        # Input texts, shown side by side.
        with gr.Row():
            with gr.Column():
                reference = gr.Textbox(
                    label="Reference Text",
                    placeholder="Enter the reference text here...",
                    lines=5
                )
            with gr.Column():
                hypothesis = gr.Textbox(
                    label="Hypothesis Text",
                    placeholder="Enter the hypothesis text here...",
                    lines=5
                )
        # Preprocessing options passed through to process_inputs.
        with gr.Row():
            normalize = gr.Checkbox(
                label="Normalize text (lowercase, remove punctuation)",
                value=True
            )
            words_to_filter = gr.Textbox(
                label="Words to filter (comma-separated)",
                placeholder="e.g., um, uh, ah"
            )
        with gr.Row():
            example_btn = gr.Button("Load Example")
            calculate_btn = gr.Button("Calculate WER", variant="primary")
        # Result panels; filled in order from process_inputs' four return values.
        with gr.Row():
            metrics_output = gr.HTML(label="Main Metrics")
            error_output = gr.HTML(label="Error Analysis")
        # NOTE(review): confirm these two panels are meant to sit below the
        # results row (full width) rather than inside it.
        explanation_output = gr.HTML()
        error_msg_output = gr.HTML()
        # Event handlers
        # load_example returns (reference, hypothesis), matching this output order.
        example_btn.click(
            load_example,
            outputs=[reference, hypothesis]
        )
        # Inputs are passed positionally and must match the parameter order
        # process_inputs(reference, hypothesis, normalize, words_to_filter).
        calculate_btn.click(
            process_inputs,
            inputs=[reference, hypothesis, normalize, words_to_filter],
            outputs=[metrics_output, error_output, explanation_output, error_msg_output]
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    logger.info("Application started")
    app = create_interface()
    app.launch()