import argparse
import sys
from typing import Dict, Tuple

from visualize import generate_html_report
from utils import read_srt_text, preprocess_chinese_text
import jiwer

def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Calculate Character Error Rate (CER) for Chinese SRT files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "-r",
        "--reference",
        required=True,
        help="Path to the reference (golden) SRT file",
    )

    parser.add_argument(
        "-i", "--input", required=True, help="Path to the input (hypothesis) SRT file"
    )

    parser.add_argument("-o", "--output", help="Path to save the results (optional)")

    parser.add_argument("--html", help="Path to save the HTML visualization (optional)")

    return parser.parse_args()

def calculate_cer_both_versions(
    reference_path: str, hypothesis_path: str
) -> Tuple[Dict, Dict]:
    """
    Calculate CER and related metrics between reference and hypothesis SRT files,
    both with and without punctuation.

    Args:
        reference_path (str): Path to the reference SRT file
        hypothesis_path (str): Path to the hypothesis SRT file

    Returns:
        Tuple[Dict, Dict]: Two dictionaries containing metrics (with and without punctuation)
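
    Example (illustrative file names, not executed as a doctest):
        no_punct, with_punct = calculate_cer_both_versions("reference.srt", "hypothesis.srt")
        print(no_punct["wer"])  # effectively the CER, since the tokens are characters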
    """
    # Read files
    reference_text = read_srt_text(reference_path)
    hypothesis_text = read_srt_text(hypothesis_path)

    # Calculate metrics without punctuation
    reference_chars_no_punct = preprocess_chinese_text(
        reference_text, include_punctuation=False
    )
    hypothesis_chars_no_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=False
    )

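    # NOTE: jiwer computes word-level measures over whitespace-separated tokens.
    # Assuming preprocess_chinese_text returns characters separated by spaces,
    # the "wer" value it reports is effectively a character error rate (CER).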
    metrics_no_punct = jiwer.compute_measures(
        reference_chars_no_punct, hypothesis_chars_no_punct
    )

    # Calculate metrics with punctuation
    reference_chars_with_punct = preprocess_chinese_text(
        reference_text, include_punctuation=True
    )
    hypothesis_chars_with_punct = preprocess_chinese_text(
        hypothesis_text, include_punctuation=True
    )

    metrics_with_punct = jiwer.compute_measures(
        reference_chars_with_punct, hypothesis_chars_with_punct
    )

    # Add character counts
    metrics_no_punct["total_ref_chars"] = len(reference_chars_no_punct.replace(" ", ""))
    metrics_no_punct["total_hyp_chars"] = len(
        hypothesis_chars_no_punct.replace(" ", "")
    )
    metrics_with_punct["total_ref_chars"] = len(
        reference_chars_with_punct.replace(" ", "")
    )
    metrics_with_punct["total_hyp_chars"] = len(
        hypothesis_chars_with_punct.replace(" ", "")
    )

    return metrics_no_punct, metrics_with_punct


def format_metrics(metrics: dict, version: str) -> str:
    """
    Format metrics into a string.

    Args:
        metrics (dict): Dictionary of metric values
        version (str): String indicating which version of metrics these are

    Returns:
        str: Formatted metrics string
    """
    output = []
    output.append(f"\n=== {version} ===")
    output.append(f"Character Error Rate (CER): {metrics['wer']:.3f}")
    output.append(f"Total Reference Characters: {metrics['total_ref_chars']}")
    output.append(f"Total Hypothesis Characters: {metrics['total_hyp_chars']}")

    output.append("\nDetailed Statistics:")
    output.append(f"Correct Characters: {metrics['hits']}")
    output.append(f"Substitutions: {metrics['substitutions']}")
    output.append(f"Deletions: {metrics['deletions']}")
    output.append(f"Insertions: {metrics['insertions']}")

    # Calculate and print percentage stats
    total_errors = (
        metrics["substitutions"] + metrics["deletions"] + metrics["insertions"]
    )
    # Avoid ZeroDivisionError if the reference contains no characters
    total_chars = max(metrics["total_ref_chars"], 1)

    output.append("\nError Analysis:")
    output.append(f"Total Errors: {total_errors}")
    output.append(f"Substitution Rate: {metrics['substitutions'] / total_chars:.3f}")
    output.append(f"Deletion Rate: {metrics['deletions'] / total_chars:.3f}")
    output.append(f"Insertion Rate: {metrics['insertions'] / total_chars:.3f}")

    return "\n".join(output)



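# Example invocation (the module filename below is illustrative, not taken from this file):
#   python cer.py -r reference.srt -i hypothesis.srt -o results.txt --html report.html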
if __name__ == "__main__":
    args = parse_arguments()

    try:
        # Read the original texts
        reference_text = read_srt_text(args.reference)
        hypothesis_text = read_srt_text(args.input)

        # Calculate metrics
        metrics_no_punct, metrics_with_punct = calculate_cer_both_versions(
            args.reference, args.input
        )

        # Generate and save HTML report if requested
        if args.html:
            html_content = generate_html_report(
                reference_text, hypothesis_text, metrics_no_punct, metrics_with_punct
            )
            with open(args.html, "w", encoding="utf-8") as f:
                f.write(html_content)
            print(f"\nHTML visualization has been saved to: {args.html}")

        # Original metrics output
        output_text = []
        output_text.append(
            format_metrics(metrics_no_punct, "Metrics Without Punctuation")
        )
        output_text.append(
            format_metrics(metrics_with_punct, "Metrics With Punctuation")
        )
        output_text.append("\n=== Comparison ===")
        output_text.append(f"CER without punctuation: {metrics_no_punct['wer']:.3f}")
        output_text.append(f"CER with punctuation: {metrics_with_punct['wer']:.3f}")
        output_text.append(
            f"Difference: {abs(metrics_with_punct['wer'] - metrics_no_punct['wer']):.3f}"
        )

        final_output = "\n".join(output_text)
        print(final_output)

        if args.output:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(final_output)
            print(f"\nResults have been saved to: {args.output}")

    except FileNotFoundError as e:
        print(f"Error: Could not find one of the input files - {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Error occurred: {e}", file=sys.stderr)
        sys.exit(1)