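"""Compute the Character Error Rate (CER) between a reference and a hypothesis
Chinese SRT file, both with and without punctuation.

Results are printed to stdout and can optionally be written to a text file
(-o/--output) and to an HTML visualization (--html).

Example invocation (the script name is illustrative):

    python cer.py -r golden.srt -i hypothesis.srt -o results.txt --html report.html
"""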
import argparse
from typing import Dict, Tuple

import jiwer

from visualize import generate_html_report

# read_srt_text is also re-exported from this module for callers that want it.
from utils import read_srt_text, preprocess_chinese_text


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Calculate Character Error Rate (CER) for Chinese SRT files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-r",
        "--reference",
        required=True,
        help="Path to the reference (golden) SRT file",
    )
    parser.add_argument(
        "-i", "--input", required=True, help="Path to the input (hypothesis) SRT file"
    )
    parser.add_argument("-o", "--output", help="Path to save the results (optional)")
    parser.add_argument("--html", help="Path to save the HTML visualization (optional)")
    return parser.parse_args()


def calculate_cer_both_versions(
    reference_path: str, hypothesis_path: str
) -> Tuple[Dict, Dict]:
    """
    Calculate CER and related metrics between reference and hypothesis SRT files,
    both with and without punctuation.

    Args:
        reference_path (str): Path to the reference SRT file
        hypothesis_path (str): Path to the hypothesis SRT file

    Returns:
        Tuple[Dict, Dict]: Two dictionaries containing metrics (with and without punctuation)
    """
    # Read files
    reference_text = read_srt_text(reference_path)
    hypothesis_text = read_srt_text(hypothesis_path)

    # Calculate metrics without punctuation
    reference_chars_no_punct = preprocess_chinese_text(
        reference_text, include_punctuation=False
    )
    hypothesis_chars_no_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=False
    )
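    # preprocess_chinese_text is assumed to return the text as space-separated
    # characters, so jiwer's word-level measures are effectively character-level
    # here: the reported "wer" is the CER, i.e.
    # (substitutions + deletions + insertions) / reference character count.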
    metrics_no_punct = jiwer.compute_measures(
        reference_chars_no_punct, hypothesis_chars_no_punct
    )

    # Calculate metrics with punctuation
    reference_chars_with_punct = preprocess_chinese_text(
        reference_text, include_punctuation=True
    )
    hypothesis_chars_with_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=True
    )
    metrics_with_punct = jiwer.compute_measures(
        reference_chars_with_punct, hypothesis_chars_with_punct
    )

    # Add character counts
    metrics_no_punct["total_ref_chars"] = len(reference_chars_no_punct.replace(" ", ""))
    metrics_no_punct["total_hyp_chars"] = len(
        hypothesis_chars_no_punct.replace(" ", "")
    )
    metrics_with_punct["total_ref_chars"] = len(
        reference_chars_with_punct.replace(" ", "")
    )
    metrics_with_punct["total_hyp_chars"] = len(
        hypothesis_chars_with_punct.replace(" ", "")
    )

    return metrics_no_punct, metrics_with_punct


def format_metrics(metrics: dict, version: str) -> str:
    """
    Format metrics into a string.

    Args:
        metrics (dict): Dictionary of metric values
        version (str): String indicating which version of metrics these are

    Returns:
        str: Formatted metrics string
    """
    output = []
    output.append(f"\n=== {version} ===")
    # jiwer reports the measure as "wer"; on space-separated characters it is the CER.
    output.append(f"Character Error Rate (CER): {metrics['wer']:.3f}")
    output.append(f"Total Reference Characters: {metrics['total_ref_chars']}")
    output.append(f"Total Hypothesis Characters: {metrics['total_hyp_chars']}")

    output.append("\nDetailed Statistics:")
    output.append(f"Correct Characters: {metrics['hits']}")
    output.append(f"Substitutions: {metrics['substitutions']}")
    output.append(f"Deletions: {metrics['deletions']}")
    output.append(f"Insertions: {metrics['insertions']}")

    # Calculate and print percentage stats
    total_errors = (
        metrics["substitutions"] + metrics["deletions"] + metrics["insertions"]
    )
    total_chars = metrics["total_ref_chars"]

    output.append("\nError Analysis:")
    output.append(f"Total Errors: {total_errors}")
    if total_chars > 0:  # Guard against division by zero on an empty reference
        output.append(f"Substitution Rate: {metrics['substitutions'] / total_chars:.3f}")
        output.append(f"Deletion Rate: {metrics['deletions'] / total_chars:.3f}")
        output.append(f"Insertion Rate: {metrics['insertions'] / total_chars:.3f}")

    return "\n".join(output)


if __name__ == "__main__":
    args = parse_arguments()

    try:
        # Read the original texts
        reference_text = read_srt_text(args.reference)
        hypothesis_text = read_srt_text(args.input)

        # Calculate metrics
        metrics_no_punct, metrics_with_punct = calculate_cer_both_versions(
            args.reference, args.input
        )

        # Generate and save HTML report if requested
        if args.html:
            html_content = generate_html_report(
                reference_text, hypothesis_text, metrics_no_punct, metrics_with_punct
            )
            with open(args.html, "w", encoding="utf-8") as f:
                f.write(html_content)
            print(f"\nHTML visualization has been saved to: {args.html}")

        # Original metrics output
        output_text = []
        output_text.append(
            format_metrics(metrics_no_punct, "Metrics Without Punctuation")
        )
        output_text.append(
            format_metrics(metrics_with_punct, "Metrics With Punctuation")
        )

        output_text.append("\n=== Comparison ===")
        output_text.append(f"CER without punctuation: {metrics_no_punct['wer']:.3f}")
        output_text.append(f"CER with punctuation: {metrics_with_punct['wer']:.3f}")
        output_text.append(
            f"Difference: {abs(metrics_with_punct['wer'] - metrics_no_punct['wer']):.3f}"
        )

        final_output = "\n".join(output_text)
        print(final_output)

        if args.output:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(final_output)
            print(f"\nResults have been saved to: {args.output}")

    except FileNotFoundError as e:
        print(f"Error: Could not find one of the input files - {str(e)}")
    except Exception as e:
        print(f"Error occurred: {str(e)}")