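"""Compute the Character Error Rate (CER) between a reference and a hypothesis
Chinese SRT file, with and without punctuation, using jiwer over
character-level tokens."""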
import argparse
from typing import Dict, Tuple

import jiwer

from utils import preprocess_chinese_text, read_srt_text
from visualize import generate_html_report


def parse_arguments():
    """Parse command-line arguments for the CER comparison script."""
    parser = argparse.ArgumentParser(
        description="Calculate Character Error Rate (CER) for Chinese SRT files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "-r",
        "--reference",
        required=True,
        help="Path to the reference (golden) SRT file",
    )
    parser.add_argument(
        "-i", "--input", required=True, help="Path to the input (hypothesis) SRT file"
    )
    parser.add_argument("-o", "--output", help="Path to save the results (optional)")
    parser.add_argument("--html", help="Path to save the HTML visualization (optional)")

    return parser.parse_args()


def calculate_cer_both_versions(
    reference_path: str, hypothesis_path: str
) -> Tuple[Dict, Dict]:
    """
    Calculate CER and related metrics between reference and hypothesis SRT files,
    both with and without punctuation.

    Args:
        reference_path (str): Path to the reference SRT file
        hypothesis_path (str): Path to the hypothesis SRT file

    Returns:
        Tuple[Dict, Dict]: Two dictionaries of metrics (without and with punctuation)
    """
    reference_text = read_srt_text(reference_path)
    hypothesis_text = read_srt_text(hypothesis_path)

    # Metrics without punctuation. preprocess_chinese_text is expected to return
    # space-separated characters, so jiwer's word-level measures act as
    # character-level measures here.
    reference_chars_no_punct = preprocess_chinese_text(
        reference_text, include_punctuation=False
    )
    hypothesis_chars_no_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=False
    )
    metrics_no_punct = jiwer.compute_measures(
        reference_chars_no_punct, hypothesis_chars_no_punct
    )

    # Metrics with punctuation treated as regular characters.
    reference_chars_with_punct = preprocess_chinese_text(
        reference_text, include_punctuation=True
    )
    hypothesis_chars_with_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=True
    )
    metrics_with_punct = jiwer.compute_measures(
        reference_chars_with_punct, hypothesis_chars_with_punct
    )

    # Character totals, excluding the separator spaces added by preprocessing.
    metrics_no_punct["total_ref_chars"] = len(reference_chars_no_punct.replace(" ", ""))
    metrics_no_punct["total_hyp_chars"] = len(
        hypothesis_chars_no_punct.replace(" ", "")
    )
    metrics_with_punct["total_ref_chars"] = len(
        reference_chars_with_punct.replace(" ", "")
    )
    metrics_with_punct["total_hyp_chars"] = len(
        hypothesis_chars_with_punct.replace(" ", "")
    )

    return metrics_no_punct, metrics_with_punct


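# Note: jiwer.compute_measures returns a dict that includes (among other keys)
# "wer", "hits", "substitutions", "deletions", and "insertions"; format_metrics
# below consumes these together with the character totals added in
# calculate_cer_both_versions.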
def format_metrics(metrics: dict, version: str) -> str:
    """
    Format metrics into a human-readable string.

    Args:
        metrics (dict): Dictionary of metric values
        version (str): Label indicating which version of metrics these are

    Returns:
        str: Formatted metrics string
    """
    output = []
    output.append(f"\n=== {version} ===")
    # jiwer reports the error rate under the "wer" key; computed over
    # space-separated characters, it is effectively the CER.
    output.append(f"Character Error Rate (CER): {metrics['wer']:.3f}")
    output.append(f"Total Reference Characters: {metrics['total_ref_chars']}")
    output.append(f"Total Hypothesis Characters: {metrics['total_hyp_chars']}")

    output.append("\nDetailed Statistics:")
    output.append(f"Correct Characters: {metrics['hits']}")
    output.append(f"Substitutions: {metrics['substitutions']}")
    output.append(f"Deletions: {metrics['deletions']}")
    output.append(f"Insertions: {metrics['insertions']}")

    total_errors = (
        metrics["substitutions"] + metrics["deletions"] + metrics["insertions"]
    )
    # Avoid division by zero if the reference contains no characters.
    total_chars = metrics["total_ref_chars"] or 1

    output.append("\nError Analysis:")
    output.append(f"Total Errors: {total_errors}")
    output.append(f"Substitution Rate: {metrics['substitutions']/total_chars:.3f}")
    output.append(f"Deletion Rate: {metrics['deletions']/total_chars:.3f}")
    output.append(f"Insertion Rate: {metrics['insertions']/total_chars:.3f}")

    return "\n".join(output)


if __name__ == "__main__": |
|
args = parse_arguments() |
|
|
|
try: |
|
|
|
reference_text = read_srt_text(args.reference) |
|
hypothesis_text = read_srt_text(args.input) |
|
|
|
|
|
metrics_no_punct, metrics_with_punct = calculate_cer_both_versions( |
|
args.reference, args.input |
|
) |
|
|
|
|
|
if args.html: |
|
html_content = generate_html_report( |
|
reference_text, hypothesis_text, metrics_no_punct, metrics_with_punct |
|
) |
|
with open(args.html, "w", encoding="utf-8") as f: |
|
f.write(html_content) |
|
print(f"\nHTML visualization has been saved to: {args.html}") |
|
|
|
|
|
output_text = [] |
|
output_text.append( |
|
format_metrics(metrics_no_punct, "Metrics Without Punctuation") |
|
) |
|
output_text.append( |
|
format_metrics(metrics_with_punct, "Metrics With Punctuation") |
|
) |
|
output_text.append("\n=== Comparison ===") |
|
output_text.append(f"CER without punctuation: {metrics_no_punct['wer']:.3f}") |
|
output_text.append(f"CER with punctuation: {metrics_with_punct['wer']:.3f}") |
|
output_text.append( |
|
f"Difference: {abs(metrics_with_punct['wer'] - metrics_no_punct['wer']):.3f}" |
|
) |
|
|
|
final_output = "\n".join(output_text) |
|
print(final_output) |
|
|
|
if args.output: |
|
with open(args.output, "w", encoding="utf-8") as f: |
|
f.write(final_output) |
|
print(f"\nResults have been saved to: {args.output}") |
|
|
|
except FileNotFoundError as e: |
|
print(f"Error: Could not find one of the input files - {str(e)}") |
|
except Exception as e: |
|
print(f"Error occurred: {str(e)}") |
|
|
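# Example invocation (file and script names here are illustrative):
#   python cer_srt.py -r reference.srt -i hypothesis.srt -o results.txt --html report.html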