Spaces:
Sleeping
Sleeping
Merge branch 'evaluation' of github.com:project-kxkg/project-t into evaluation
Browse filesllm and other scores fix
Former-commit-id: 878c36401ab6a5cfef8be2bd24f58dab749c08e2
- evaluation/alignment.py +39 -21
- evaluation/evaluation.py +57 -0
- src/srt_util/srt.py +13 -4
evaluation/alignment.py
CHANGED
@@ -3,69 +3,87 @@ import numpy as np
|
|
3 |
sys.path.append('../src')
|
4 |
from srt_util.srt import SrtScript
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
def procedure(anchor, subsec, S_arr, subidx):
|
7 |
cache_idx = 0
|
8 |
-
while subidx != cache_idx:
|
9 |
cache_idx = subidx
|
10 |
-
if
|
|
|
11 |
break
|
12 |
sub = subsec[subidx]
|
13 |
if anchor.end < sub.start:
|
14 |
continue
|
|
|
15 |
if (anchor.start <= sub.start) and (sub.end <= anchor.end) or anchor.end - sub.start > sub.end - anchor.start:
|
16 |
-
S_arr[-1] += sub
|
17 |
subidx += 1
|
18 |
-
return subidx - 1
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def alignment(pred_path, gt_path):
|
21 |
pred = SrtScript.parse_from_srt_file(pred_path).segments
|
22 |
gt = SrtScript.parse_from_srt_file(gt_path).segments
|
23 |
pred_arr, gt_arr = [], []
|
24 |
-
idx_p, idx_t = 0, 0
|
25 |
|
26 |
while idx_p < len(pred) or idx_t < len(gt):
|
|
|
27 |
ps = pred[idx_p] if idx_p < len(pred) else None
|
28 |
gs = gt[idx_t] if idx_t < len(gt) else None
|
29 |
-
|
30 |
if not ps:
|
31 |
-
|
|
|
32 |
pred_arr.append('')
|
33 |
idx_t += 1
|
34 |
continue
|
35 |
|
36 |
if not gs:
|
37 |
-
|
|
|
38 |
gt_arr.append('')
|
39 |
idx_p += 1
|
40 |
continue
|
41 |
|
42 |
ps_dur = ps.end - ps.start
|
43 |
gs_dur = gs.end - gs.start
|
44 |
-
|
|
|
45 |
if ps_dur <= gs_dur:
|
|
|
46 |
if ps.end < gs.start:
|
47 |
-
pred_arr.append(ps
|
48 |
-
gt_arr.append('')
|
49 |
-
idx_t -= 1
|
50 |
else:
|
51 |
-
gt_arr.append(gs
|
52 |
if gs.end >= ps.start:
|
53 |
-
pred_arr.append(ps
|
54 |
idx_p = procedure(gs, pred, pred_arr, idx_p + 1)
|
55 |
-
else:
|
56 |
pred_arr.append('')
|
57 |
idx_p -= 1
|
58 |
else:
|
|
|
59 |
if gs.end < ps.start:
|
60 |
-
gt_arr.append(gs
|
61 |
-
pred_arr.append('')
|
62 |
-
idx_p -= 1
|
63 |
else:
|
64 |
-
pred_arr.append(ps
|
65 |
if ps.end >= gs.start:
|
66 |
-
gt_arr.append(gs
|
67 |
idx_t = procedure(ps, gt, gt_arr, idx_t + 1)
|
68 |
-
else:
|
69 |
gt_arr.append('')
|
70 |
idx_t -= 1
|
71 |
|
|
|
3 |
sys.path.append('../src')
|
4 |
from srt_util.srt import SrtScript
|
5 |
|
6 |
+
|
7 |
+
# Helper method
|
8 |
+
# Align sub anchor segment pair via greedy approach
|
9 |
+
# Input: anchor segment, SRT segments, output array of sub, index of current sub
|
10 |
+
# Output: updated index of sub
|
11 |
def procedure(anchor, subsec, S_arr, subidx):
|
12 |
cache_idx = 0
|
13 |
+
while subidx != cache_idx: # Terminate when alignment stablizes
|
14 |
cache_idx = subidx
|
15 |
+
# if sub segment runs out during the loop, terminate
|
16 |
+
if subidx >= len(subsec):
|
17 |
break
|
18 |
sub = subsec[subidx]
|
19 |
if anchor.end < sub.start:
|
20 |
continue
|
21 |
+
# If next sub has a heavier overlap compartment, add to current alignment
|
22 |
if (anchor.start <= sub.start) and (sub.end <= anchor.end) or anchor.end - sub.start > sub.end - anchor.start:
|
23 |
+
S_arr[-1] += sub#.source_text
|
24 |
subidx += 1
|
|
|
25 |
|
26 |
+
return subidx - 1 # Reset last invalid update from loop
|
27 |
+
|
28 |
+
|
29 |
+
# Input: path1, path2
|
30 |
+
# Output: aligned array of SRTsegment corresponding to path1 path2
|
31 |
+
# Note: Modify comment with .source_text to get output array with string only
|
32 |
def alignment(pred_path, gt_path):
|
33 |
pred = SrtScript.parse_from_srt_file(pred_path).segments
|
34 |
gt = SrtScript.parse_from_srt_file(gt_path).segments
|
35 |
pred_arr, gt_arr = [], []
|
36 |
+
idx_p, idx_t = 0, 0 # idx_p: current index of pred segment, idx_t for ground truth
|
37 |
|
38 |
while idx_p < len(pred) or idx_t < len(gt):
|
39 |
+
# Check if one srt file runs out while reading
|
40 |
ps = pred[idx_p] if idx_p < len(pred) else None
|
41 |
gs = gt[idx_t] if idx_t < len(gt) else None
|
42 |
+
|
43 |
if not ps:
|
44 |
+
# If ps runs out, align gs segment with filler one by one
|
45 |
+
gt_arr.append(gs)#.source_text
|
46 |
pred_arr.append('')
|
47 |
idx_t += 1
|
48 |
continue
|
49 |
|
50 |
if not gs:
|
51 |
+
# If gs runs out, align ps segment with filler one by one
|
52 |
+
pred_arr.append(ps)#.source_text
|
53 |
gt_arr.append('')
|
54 |
idx_p += 1
|
55 |
continue
|
56 |
|
57 |
ps_dur = ps.end - ps.start
|
58 |
gs_dur = gs.end - gs.start
|
59 |
+
|
60 |
+
# Check for duration to decide anchor and sub
|
61 |
if ps_dur <= gs_dur:
|
62 |
+
# Detect segment with no overlap
|
63 |
if ps.end < gs.start:
|
64 |
+
pred_arr.append(ps)#.source_text
|
65 |
+
gt_arr.append('') # append filler
|
66 |
+
idx_t -= 1 # reset ground truth index
|
67 |
else:
|
68 |
+
gt_arr.append(gs)#.source_text
|
69 |
if gs.end >= ps.start:
|
70 |
+
pred_arr.append(ps)#.source_text
|
71 |
idx_p = procedure(gs, pred, pred_arr, idx_p + 1)
|
72 |
+
else: # filler pairing
|
73 |
pred_arr.append('')
|
74 |
idx_p -= 1
|
75 |
else:
|
76 |
+
# same overlap checking procedure
|
77 |
if gs.end < ps.start:
|
78 |
+
gt_arr.append(gs)#.source_text
|
79 |
+
pred_arr.append('') # filler
|
80 |
+
idx_p -= 1 # reset
|
81 |
else:
|
82 |
+
pred_arr.append(ps)#.source_text
|
83 |
if ps.end >= gs.start:
|
84 |
+
gt_arr.append(gs)#.source_text
|
85 |
idx_t = procedure(ps, gt, gt_arr, idx_t + 1)
|
86 |
+
else: # filler pairing
|
87 |
gt_arr.append('')
|
88 |
idx_t -= 1
|
89 |
|
evaluation/evaluation.py
CHANGED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pandas as pd
|
3 |
+
from evaluation.alignment import alignment
|
4 |
+
from evaluation.scores.multi_scores import multi_scores
|
5 |
+
from src.srt_util.srt import SrtScript
|
6 |
+
|
7 |
+
class Evaluator:
|
8 |
+
def __init__(self, src_path, pred_path, gt_path, eval_path, conclusion_path):
|
9 |
+
self.src_path = src_path
|
10 |
+
self.pred_path = pred_path
|
11 |
+
self.gt_path = gt_path
|
12 |
+
self.eval_path = eval_path
|
13 |
+
self.conclusion_path = conclusion_path
|
14 |
+
|
15 |
+
def eval(self):
|
16 |
+
# Align two SRT files
|
17 |
+
aligned_srt = alignment(self.pred_path, self.gt_path)
|
18 |
+
|
19 |
+
# Parse src
|
20 |
+
src_s = [s.source_text for s in SrtScript.parse_from_srt_file(self.src_path).segments]
|
21 |
+
|
22 |
+
# Get sentence scores
|
23 |
+
scorer = multi_scores()
|
24 |
+
result_data = []
|
25 |
+
for ((prd_s, gt_s), src_s) in zip(aligned_srt, src_s):
|
26 |
+
scores_dict = scorer.get(src_s, prd_s, gt_s)
|
27 |
+
scores_dict['Prediction'] = prd_s
|
28 |
+
scores_dict['Ground Truth'] = gt_s
|
29 |
+
result_data.append(scores_dict)
|
30 |
+
|
31 |
+
eval_df = pd.DataFrame(result_data)
|
32 |
+
eval_df.to_csv(self.output_path, index=False, columns=['Prediction', 'Ground Truth', 'llm', 'bleu', 'comet'])
|
33 |
+
|
34 |
+
# Get average scores
|
35 |
+
avg_llm = eval_df['llm'].mean()
|
36 |
+
avg_bleu = eval_df['bleu'].mean()
|
37 |
+
avg_comet = eval_df['comet'].mean()
|
38 |
+
|
39 |
+
conclusion_data = {
|
40 |
+
'Metric': ['Avg LLM', 'Avg BLEU', 'Avg COMET'],
|
41 |
+
'Score': [avg_llm, avg_bleu, avg_comet]
|
42 |
+
}
|
43 |
+
conclusion_df = pd.DataFrame(conclusion_data)
|
44 |
+
conclusion_df.to_csv(self.conclusion_path, index=False)
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
parser = argparse.ArgumentParser(description='Evaluate SRT files.')
|
48 |
+
parser.add_argument('-src', default='test/short_src', help='Path to source SRT file')
|
49 |
+
parser.add_argument('-pred', default='test/short_pred', help='Path to predicted SRT file')
|
50 |
+
parser.add_argument('-gt', default='test/short_gt', help='Path to ground truth SRT file')
|
51 |
+
parser.add_argument('-eval', default='eval.csv', help='Path to output CSV file')
|
52 |
+
parser.add_argument('-conclusion', default='conclusion.csv', help='Path to conclusion CSV file')
|
53 |
+
args = parser.parse_args()
|
54 |
+
|
55 |
+
evaluator = Evaluator(args.src, args.pred, args.gt, args.eval, args.conclusion)
|
56 |
+
evaluator.eval()
|
57 |
+
|
src/srt_util/srt.py
CHANGED
@@ -50,7 +50,10 @@ class SrtSegment(object):
|
|
50 |
self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
|
51 |
end_list = self.end_time_str.split(',')[0].split(':')
|
52 |
self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
|
53 |
-
|
|
|
|
|
|
|
54 |
|
55 |
def merge_seg(self, seg):
|
56 |
"""
|
@@ -105,10 +108,16 @@ class SrtScript(object):
|
|
105 |
def parse_from_srt_file(cls, path: str):
|
106 |
with open(path, 'r', encoding="utf-8") as f:
|
107 |
script_lines = [line.rstrip() for line in f.readlines()]
|
108 |
-
|
|
|
|
|
109 |
segments = []
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
112 |
|
113 |
return cls(segments)
|
114 |
|
|
|
50 |
self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
|
51 |
end_list = self.end_time_str.split(',')[0].split(':')
|
52 |
self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
|
53 |
+
if len(args[0]) < 5:
|
54 |
+
self.translation = ""
|
55 |
+
else:
|
56 |
+
self.translation = args[0][3]
|
57 |
|
58 |
def merge_seg(self, seg):
|
59 |
"""
|
|
|
108 |
def parse_from_srt_file(cls, path: str):
|
109 |
with open(path, 'r', encoding="utf-8") as f:
|
110 |
script_lines = [line.rstrip() for line in f.readlines()]
|
111 |
+
bilingual = False
|
112 |
+
if script_lines[2] != '' and script_lines[3] != '':
|
113 |
+
bilingual = True
|
114 |
segments = []
|
115 |
+
if bilingual:
|
116 |
+
for i in range(0, len(script_lines), 5):
|
117 |
+
segments.append(list(script_lines[i:i + 5]))
|
118 |
+
else:
|
119 |
+
for i in range(0, len(script_lines), 4):
|
120 |
+
segments.append(list(script_lines[i:i + 4]))
|
121 |
|
122 |
return cls(segments)
|
123 |
|