import argparse
from typing import Dict, Tuple

import jiwer

from utils import read_srt_text, preprocess_chinese_text
from visualize import generate_html_report


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Calculate Character Error Rate (CER) for Chinese SRT files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-r",
        "--reference",
        required=True,
        help="Path to the reference (golden) SRT file",
    )
    parser.add_argument(
        "-i", "--input", required=True, help="Path to the input (hypothesis) SRT file"
    )
    parser.add_argument("-o", "--output", help="Path to save the results (optional)")
    parser.add_argument(
        "--html", help="Path to save the HTML visualization (optional)"
    )
    return parser.parse_args()


def calculate_cer_both_versions(
    reference_path: str, hypothesis_path: str
) -> Tuple[Dict, Dict]:
    """
    Calculate CER and related metrics between reference and hypothesis SRT files,
    both with and without punctuation.

    Args:
        reference_path (str): Path to the reference SRT file
        hypothesis_path (str): Path to the hypothesis SRT file

    Returns:
        Tuple[Dict, Dict]: Two dictionaries containing metrics
            (without punctuation, with punctuation)
    """
    # Read files
    reference_text = read_srt_text(reference_path)
    hypothesis_text = read_srt_text(hypothesis_path)

    # Calculate metrics without punctuation
    reference_chars_no_punct = preprocess_chinese_text(
        reference_text, include_punctuation=False
    )
    hypothesis_chars_no_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=False
    )
    metrics_no_punct = jiwer.compute_measures(
        reference_chars_no_punct, hypothesis_chars_no_punct
    )

    # Calculate metrics with punctuation
    reference_chars_with_punct = preprocess_chinese_text(
        reference_text, include_punctuation=True
    )
    hypothesis_chars_with_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=True
    )
    metrics_with_punct = jiwer.compute_measures(
        reference_chars_with_punct, hypothesis_chars_with_punct
    )

    # Add character counts (spaces are token separators, not characters)
    metrics_no_punct["total_ref_chars"] = len(reference_chars_no_punct.replace(" ", ""))
    metrics_no_punct["total_hyp_chars"] = len(
        hypothesis_chars_no_punct.replace(" ", "")
    )
    metrics_with_punct["total_ref_chars"] = len(
        reference_chars_with_punct.replace(" ", "")
    )
    metrics_with_punct["total_hyp_chars"] = len(
        hypothesis_chars_with_punct.replace(" ", "")
    )

    return metrics_no_punct, metrics_with_punct


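# Note on jiwer: compute_measures() scores space-separated tokens, so feeding it
# the character-per-token strings that preprocess_chinese_text is assumed to
# produce makes the reported "wer" effectively a character error rate. A toy
# illustration (not part of this script) of the keys this module relies on:
#
#     >>> import jiwer
#     >>> m = jiwer.compute_measures("你 好 世 界", "你 好 地 界")
#     >>> m["wer"], m["hits"], m["substitutions"]
#     (0.25, 3, 1)

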
def format_metrics(metrics: dict, version: str) -> str:
    """
    Format metrics into a string.

    Args:
        metrics (dict): Dictionary of metric values
        version (str): String indicating which version of metrics these are

    Returns:
        str: Formatted metrics string
    """
    output = []
    output.append(f"\n=== {version} ===")
    output.append(f"Character Error Rate (CER): {metrics['wer']:.3f}")
    output.append(f"Total Reference Characters: {metrics['total_ref_chars']}")
    output.append(f"Total Hypothesis Characters: {metrics['total_hyp_chars']}")

    output.append("\nDetailed Statistics:")
    output.append(f"Correct Characters: {metrics['hits']}")
    output.append(f"Substitutions: {metrics['substitutions']}")
    output.append(f"Deletions: {metrics['deletions']}")
    output.append(f"Insertions: {metrics['insertions']}")

    # Calculate and print percentage stats
    total_errors = (
        metrics["substitutions"] + metrics["deletions"] + metrics["insertions"]
    )
    total_chars = metrics["total_ref_chars"]
    output.append("\nError Analysis:")
    output.append(f"Total Errors: {total_errors}")
    output.append(f"Substitution Rate: {metrics['substitutions']/total_chars:.3f}")
    output.append(f"Deletion Rate: {metrics['deletions']/total_chars:.3f}")
    output.append(f"Insertion Rate: {metrics['insertions']/total_chars:.3f}")

    return "\n".join(output)


if __name__ == "__main__":
    args = parse_arguments()

    try:
        # Read the original texts
        reference_text = read_srt_text(args.reference)
        hypothesis_text = read_srt_text(args.input)

        # Calculate metrics
        metrics_no_punct, metrics_with_punct = calculate_cer_both_versions(
            args.reference, args.input
        )

        # Generate and save the HTML report if requested
        if args.html:
            html_content = generate_html_report(
                reference_text, hypothesis_text, metrics_no_punct, metrics_with_punct
            )
            with open(args.html, "w", encoding="utf-8") as f:
                f.write(html_content)
            print(f"\nHTML visualization has been saved to: {args.html}")

        # Plain-text metrics output
        output_text = []
        output_text.append(
            format_metrics(metrics_no_punct, "Metrics Without Punctuation")
        )
        output_text.append(
            format_metrics(metrics_with_punct, "Metrics With Punctuation")
        )

        output_text.append("\n=== Comparison ===")
        output_text.append(f"CER without punctuation: {metrics_no_punct['wer']:.3f}")
        output_text.append(f"CER with punctuation: {metrics_with_punct['wer']:.3f}")
        output_text.append(
            f"Difference: {abs(metrics_with_punct['wer'] - metrics_no_punct['wer']):.3f}"
        )

        final_output = "\n".join(output_text)
        print(final_output)

        if args.output:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(final_output)
            print(f"\nResults have been saved to: {args.output}")

    except FileNotFoundError as e:
        print(f"Error: Could not find one of the input files - {str(e)}")
    except Exception as e:
        print(f"Error occurred: {str(e)}")
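
# ---------------------------------------------------------------------------
# Example invocation (the script filename shown here is illustrative):
#
#     python calculate_cer.py -r golden.srt -i hypothesis.srt -o cer.txt --html cer.html
#
# The helpers imported from utils are not defined in this file. Below is a
# minimal sketch of what this script assumes preprocess_chinese_text does;
# it is not the actual implementation, and the punctuation set is only an
# example:
#
#     import re
#     import string
#
#     def preprocess_chinese_text(text: str, include_punctuation: bool = False) -> str:
#         # Collapse whitespace left over from SRT cue lines
#         text = re.sub(r"\s+", "", text)
#         if not include_punctuation:
#             # Drop ASCII punctuation plus common full-width CJK marks (assumed set)
#             punct = string.punctuation + "，。！？、；：“”‘’（）《》"
#             text = "".join(ch for ch in text if ch not in punct)
#         # Space-separate every character so jiwer scores per character
#         return " ".join(text)
# ---------------------------------------------------------------------------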