# srt-eval/main.py
import argparse
from typing import Dict, Tuple

import jiwer

from utils import preprocess_chinese_text, read_srt_text
from visualize import generate_html_report


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Calculate Character Error Rate (CER) for Chinese SRT files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-r",
        "--reference",
        required=True,
        help="Path to the reference (golden) SRT file",
    )
    parser.add_argument(
        "-i", "--input", required=True, help="Path to the input (hypothesis) SRT file"
    )
    parser.add_argument("-o", "--output", help="Path to save the results (optional)")
    parser.add_argument("--html", help="Path to save the HTML visualization (optional)")
    return parser.parse_args()
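
# Example invocation (file names below are hypothetical):
#   python main.py -r golden.srt -i transcript.srt -o results.txt --html report.html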


def calculate_cer_both_versions(
    reference_path: str, hypothesis_path: str
) -> Tuple[Dict, Dict]:
    """
    Calculate CER and related metrics between reference and hypothesis SRT files,
    both with and without punctuation.

    Args:
        reference_path (str): Path to the reference SRT file
        hypothesis_path (str): Path to the hypothesis SRT file

    Returns:
        Tuple[Dict, Dict]: Two dictionaries of metrics (without punctuation first,
        then with punctuation)
    """
    # Read files
    reference_text = read_srt_text(reference_path)
    hypothesis_text = read_srt_text(hypothesis_path)

    # Calculate metrics without punctuation
    reference_chars_no_punct = preprocess_chinese_text(
        reference_text, include_punctuation=False
    )
    hypothesis_chars_no_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=False
    )
    metrics_no_punct = jiwer.compute_measures(
        reference_chars_no_punct, hypothesis_chars_no_punct
    )

    # Calculate metrics with punctuation
    reference_chars_with_punct = preprocess_chinese_text(
        reference_text, include_punctuation=True
    )
    hypothesis_chars_with_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=True
    )
    metrics_with_punct = jiwer.compute_measures(
        reference_chars_with_punct, hypothesis_chars_with_punct
    )

    # Add character counts (spaces are token separators, not characters)
    metrics_no_punct["total_ref_chars"] = len(reference_chars_no_punct.replace(" ", ""))
    metrics_no_punct["total_hyp_chars"] = len(hypothesis_chars_no_punct.replace(" ", ""))
    metrics_with_punct["total_ref_chars"] = len(
        reference_chars_with_punct.replace(" ", "")
    )
    metrics_with_punct["total_hyp_chars"] = len(
        hypothesis_chars_with_punct.replace(" ", "")
    )

    return metrics_no_punct, metrics_with_punct
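
# Why jiwer's "wer" is read as CER in this script: preprocess_chinese_text is
# assumed to return characters joined by single spaces, so jiwer treats each
# character as one "word" and its word error rate equals the character error
# rate. A minimal made-up example (not from real SRT files):
#
#   ref = "你 好 世 界"   # 4 reference characters
#   hyp = "你 好 地 界"   # one substitution: 世 -> 地
#   jiwer.compute_measures(ref, hyp)["wer"]  # -> 0.25 (1 error / 4 chars)
#
# Note: compute_measures is the pre-3.0 jiwer API; jiwer 3.x replaces it with
# jiwer.process_words / jiwer.process_characters.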


def format_metrics(metrics: dict, version: str) -> str:
    """
    Format metrics into a string.

    Args:
        metrics (dict): Dictionary of metric values
        version (str): String indicating which version of metrics these are

    Returns:
        str: Formatted metrics string
    """
    output = []
    output.append(f"\n=== {version} ===")
    # jiwer labels the measure "wer", but on space-separated characters it is the CER
    output.append(f"Character Error Rate (CER): {metrics['wer']:.3f}")
    output.append(f"Total Reference Characters: {metrics['total_ref_chars']}")
    output.append(f"Total Hypothesis Characters: {metrics['total_hyp_chars']}")
    output.append("\nDetailed Statistics:")
    output.append(f"Correct Characters: {metrics['hits']}")
    output.append(f"Substitutions: {metrics['substitutions']}")
    output.append(f"Deletions: {metrics['deletions']}")
    output.append(f"Insertions: {metrics['insertions']}")

    # Calculate percentage stats
    total_errors = (
        metrics["substitutions"] + metrics["deletions"] + metrics["insertions"]
    )
    # Guard against an empty reference so the rates below cannot divide by zero
    total_chars = max(metrics["total_ref_chars"], 1)
    output.append("\nError Analysis:")
    output.append(f"Total Errors: {total_errors}")
    output.append(f"Substitution Rate: {metrics['substitutions'] / total_chars:.3f}")
    output.append(f"Deletion Rate: {metrics['deletions'] / total_chars:.3f}")
    output.append(f"Insertion Rate: {metrics['insertions'] / total_chars:.3f}")
    return "\n".join(output)


if __name__ == "__main__":
    args = parse_arguments()
    try:
        # Read the original texts
        reference_text = read_srt_text(args.reference)
        hypothesis_text = read_srt_text(args.input)

        # Calculate metrics
        metrics_no_punct, metrics_with_punct = calculate_cer_both_versions(
            args.reference, args.input
        )

        # Generate and save HTML report if requested
        if args.html:
            html_content = generate_html_report(
                reference_text, hypothesis_text, metrics_no_punct, metrics_with_punct
            )
            with open(args.html, "w", encoding="utf-8") as f:
                f.write(html_content)
            print(f"\nHTML visualization has been saved to: {args.html}")

        # Original metrics output
        output_text = []
        output_text.append(
            format_metrics(metrics_no_punct, "Metrics Without Punctuation")
        )
        output_text.append(
            format_metrics(metrics_with_punct, "Metrics With Punctuation")
        )
        output_text.append("\n=== Comparison ===")
        output_text.append(f"CER without punctuation: {metrics_no_punct['wer']:.3f}")
        output_text.append(f"CER with punctuation: {metrics_with_punct['wer']:.3f}")
        output_text.append(
            f"Difference: {abs(metrics_with_punct['wer'] - metrics_no_punct['wer']):.3f}"
        )
        final_output = "\n".join(output_text)
        print(final_output)

        if args.output:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(final_output)
            print(f"\nResults have been saved to: {args.output}")
    except FileNotFoundError as e:
        print(f"Error: Could not find one of the input files - {e}")
    except Exception as e:
        print(f"Error occurred: {e}")