JiaenLiu committed on
Commit
5cce091
·
2 Parent(s): 8abf414 ee800e9

Merge branch 'evaluation' of github.com:project-kxkg/project-t into evaluation

Browse files

llm and other scores fix


Former-commit-id: 878c36401ab6a5cfef8be2bd24f58dab749c08e2

evaluation/alignment.py CHANGED
@@ -3,69 +3,87 @@ import numpy as np
3
  sys.path.append('../src')
4
  from srt_util.srt import SrtScript
5
 
 
 
 
 
 
6
  def procedure(anchor, subsec, S_arr, subidx):
7
  cache_idx = 0
8
- while subidx != cache_idx:
9
  cache_idx = subidx
10
- if subidx >= len(subsec):
 
11
  break
12
  sub = subsec[subidx]
13
  if anchor.end < sub.start:
14
  continue
 
15
  if (anchor.start <= sub.start) and (sub.end <= anchor.end) or anchor.end - sub.start > sub.end - anchor.start:
16
- S_arr[-1] += sub.source_text
17
  subidx += 1
18
- return subidx - 1
19
 
 
 
 
 
 
 
20
  def alignment(pred_path, gt_path):
21
  pred = SrtScript.parse_from_srt_file(pred_path).segments
22
  gt = SrtScript.parse_from_srt_file(gt_path).segments
23
  pred_arr, gt_arr = [], []
24
- idx_p, idx_t = 0, 0
25
 
26
  while idx_p < len(pred) or idx_t < len(gt):
 
27
  ps = pred[idx_p] if idx_p < len(pred) else None
28
  gs = gt[idx_t] if idx_t < len(gt) else None
29
-
30
  if not ps:
31
- gt_arr.append(gs.source_text)
 
32
  pred_arr.append('')
33
  idx_t += 1
34
  continue
35
 
36
  if not gs:
37
- pred_arr.append(ps.source_text)
 
38
  gt_arr.append('')
39
  idx_p += 1
40
  continue
41
 
42
  ps_dur = ps.end - ps.start
43
  gs_dur = gs.end - gs.start
44
-
 
45
  if ps_dur <= gs_dur:
 
46
  if ps.end < gs.start:
47
- pred_arr.append(ps.source_text)
48
- gt_arr.append('')
49
- idx_t -= 1
50
  else:
51
- gt_arr.append(gs.source_text)
52
  if gs.end >= ps.start:
53
- pred_arr.append(ps.source_text)
54
  idx_p = procedure(gs, pred, pred_arr, idx_p + 1)
55
- else:
56
  pred_arr.append('')
57
  idx_p -= 1
58
  else:
 
59
  if gs.end < ps.start:
60
- gt_arr.append(gs.source_text)
61
- pred_arr.append('')
62
- idx_p -= 1
63
  else:
64
- pred_arr.append(ps.source_text)
65
  if ps.end >= gs.start:
66
- gt_arr.append(gs.source_text)
67
  idx_t = procedure(ps, gt, gt_arr, idx_t + 1)
68
- else:
69
  gt_arr.append('')
70
  idx_t -= 1
71
 
 
3
  sys.path.append('../src')
4
  from srt_util.srt import SrtScript
5
 
6
+
7
def procedure(anchor, subsec, S_arr, subidx):
    """Greedily merge consecutive sub segments into the current alignment slot.

    Helper for the alignment routine. Starting at ``subsec[subidx]``, each
    sub segment is added onto ``S_arr[-1]`` (via ``+=``) for as long as it
    qualifies. The scan stops at the first segment that does not qualify,
    or when ``subsec`` is exhausted.

    A sub segment qualifies when it does not start after ``anchor.end`` and
    either lies entirely within ``[anchor.start, anchor.end]`` or satisfies
    ``anchor.end - sub.start > sub.end - anchor.start``.

    Args:
        anchor: Segment (``start``/``end`` attributes) the subs are aligned to.
        subsec: Sequence of candidate sub segments (``start``/``end``).
        S_arr: Output array of the sub side; its last element receives the
            merged segments, so it must be non-empty if anything qualifies.
        subidx: Index in ``subsec`` of the first candidate to examine.

    Returns:
        Index of the last sub segment that was consumed; this is the
        incoming ``subidx - 1`` when nothing was merged.
    """
    # An explicit bound replaces the old "index stopped changing" sentinel,
    # which silently skipped the whole scan when called with subidx == 0.
    while subidx < len(subsec):
        sub = subsec[subidx]

        # Sub starts after the anchor has ended: nothing further to merge.
        if anchor.end < sub.start:
            break

        contained = anchor.start <= sub.start and sub.end <= anchor.end
        leans_on_anchor = anchor.end - sub.start > sub.end - anchor.start
        if not (contained or leans_on_anchor):
            break

        # Segment objects are merged as-is; use sub.source_text here instead
        # to collect plain strings.
        S_arr[-1] += sub
        subidx += 1

    # subidx points at the first segment that was NOT consumed.
    return subidx - 1
27
+
28
+
29
+ # Input: path1, path2
30
+ # Output: aligned array of SRTsegment corresponding to path1 path2
31
+ # Note: uncomment the trailing `.source_text` on the append lines to get output arrays of strings only
32
  def alignment(pred_path, gt_path):
33
  pred = SrtScript.parse_from_srt_file(pred_path).segments
34
  gt = SrtScript.parse_from_srt_file(gt_path).segments
35
  pred_arr, gt_arr = [], []
36
+ idx_p, idx_t = 0, 0 # idx_p: current index of pred segment, idx_t for ground truth
37
 
38
  while idx_p < len(pred) or idx_t < len(gt):
39
+ # Check if one srt file runs out while reading
40
  ps = pred[idx_p] if idx_p < len(pred) else None
41
  gs = gt[idx_t] if idx_t < len(gt) else None
42
+
43
  if not ps:
44
+ # If ps runs out, align gs segment with filler one by one
45
+ gt_arr.append(gs)#.source_text
46
  pred_arr.append('')
47
  idx_t += 1
48
  continue
49
 
50
  if not gs:
51
+ # If gs runs out, align ps segment with filler one by one
52
+ pred_arr.append(ps)#.source_text
53
  gt_arr.append('')
54
  idx_p += 1
55
  continue
56
 
57
  ps_dur = ps.end - ps.start
58
  gs_dur = gs.end - gs.start
59
+
60
+ # Check for duration to decide anchor and sub
61
  if ps_dur <= gs_dur:
62
+ # Detect segment with no overlap
63
  if ps.end < gs.start:
64
+ pred_arr.append(ps)#.source_text
65
+ gt_arr.append('') # append filler
66
+ idx_t -= 1 # reset ground truth index
67
  else:
68
+ gt_arr.append(gs)#.source_text
69
  if gs.end >= ps.start:
70
+ pred_arr.append(ps)#.source_text
71
  idx_p = procedure(gs, pred, pred_arr, idx_p + 1)
72
+ else: # filler pairing
73
  pred_arr.append('')
74
  idx_p -= 1
75
  else:
76
+ # same overlap checking procedure
77
  if gs.end < ps.start:
78
+ gt_arr.append(gs)#.source_text
79
+ pred_arr.append('') # filler
80
+ idx_p -= 1 # reset
81
  else:
82
+ pred_arr.append(ps)#.source_text
83
  if ps.end >= gs.start:
84
+ gt_arr.append(gs)#.source_text
85
  idx_t = procedure(ps, gt, gt_arr, idx_t + 1)
86
+ else: # filler pairing
87
  gt_arr.append('')
88
  idx_t -= 1
89
 
evaluation/evaluation.py CHANGED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pandas as pd
3
+ from evaluation.alignment import alignment
4
+ from evaluation.scores.multi_scores import multi_scores
5
+ from src.srt_util.srt import SrtScript
6
+
7
class Evaluator:
    """Score a predicted SRT file against a ground-truth SRT file.

    Per-sentence scores are written to ``eval_path`` and their averages to
    ``conclusion_path``, both as CSV.
    """

    def __init__(self, src_path, pred_path, gt_path, eval_path, conclusion_path):
        """Store the input and output paths.

        Args:
            src_path: Path to the source-language SRT file.
            pred_path: Path to the predicted SRT file.
            gt_path: Path to the ground-truth SRT file.
            eval_path: Destination CSV for per-sentence scores.
            conclusion_path: Destination CSV for the averaged scores.
        """
        self.src_path = src_path
        self.pred_path = pred_path
        self.gt_path = gt_path
        self.eval_path = eval_path
        self.conclusion_path = conclusion_path

    def eval(self):
        """Align, score and export the results.

        Writes two CSV files: one row per aligned sentence pair with the
        columns Prediction, Ground Truth, llm, bleu and comet, and one row
        per metric with its mean score.
        """
        # Align two SRT files.
        # NOTE(review): assumes alignment() yields (prediction, ground truth)
        # pairs -- its return statement is not visible here; confirm.
        aligned_srt = alignment(self.pred_path, self.gt_path)

        # Parse src
        src_texts = [
            seg.source_text
            for seg in SrtScript.parse_from_srt_file(self.src_path).segments
        ]

        # Get sentence scores. zip() stops at the shorter input, so any
        # surplus aligned pairs or source segments are not scored.
        scorer = multi_scores()
        result_data = []
        for (prd_s, gt_s), src_text in zip(aligned_srt, src_texts):
            scores_dict = scorer.get(src_text, prd_s, gt_s)
            scores_dict['Prediction'] = prd_s
            scores_dict['Ground Truth'] = gt_s
            result_data.append(scores_dict)

        eval_df = pd.DataFrame(result_data)
        # Write to the path given to the constructor (stored as eval_path).
        eval_df.to_csv(
            self.eval_path,
            index=False,
            columns=['Prediction', 'Ground Truth', 'llm', 'bleu', 'comet'],
        )

        # Get average scores
        avg_llm = eval_df['llm'].mean()
        avg_bleu = eval_df['bleu'].mean()
        avg_comet = eval_df['comet'].mean()

        conclusion_data = {
            'Metric': ['Avg LLM', 'Avg BLEU', 'Avg COMET'],
            'Score': [avg_llm, avg_bleu, avg_comet]
        }
        conclusion_df = pd.DataFrame(conclusion_data)
        conclusion_df.to_csv(self.conclusion_path, index=False)
45
+
46
if __name__ == "__main__":
    # Command-line entry point: collect the five file paths (each has a
    # default) and run a single evaluation pass.
    cli_options = (
        ('-src', 'test/short_src', 'Path to source SRT file'),
        ('-pred', 'test/short_pred', 'Path to predicted SRT file'),
        ('-gt', 'test/short_gt', 'Path to ground truth SRT file'),
        ('-eval', 'eval.csv', 'Path to output CSV file'),
        ('-conclusion', 'conclusion.csv', 'Path to conclusion CSV file'),
    )

    parser = argparse.ArgumentParser(description='Evaluate SRT files.')
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, default=default_value, help=help_text)
    args = parser.parse_args()

    Evaluator(args.src, args.pred, args.gt, args.eval, args.conclusion).eval()
57
+
src/srt_util/srt.py CHANGED
@@ -50,7 +50,10 @@ class SrtSegment(object):
50
  self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
51
  end_list = self.end_time_str.split(',')[0].split(':')
52
  self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
53
- self.translation = ""
 
 
 
54
 
55
  def merge_seg(self, seg):
56
  """
@@ -105,10 +108,16 @@ class SrtScript(object):
105
  def parse_from_srt_file(cls, path: str):
106
  with open(path, 'r', encoding="utf-8") as f:
107
  script_lines = [line.rstrip() for line in f.readlines()]
108
-
 
 
109
  segments = []
110
- for i in range(0, len(script_lines), 4):
111
- segments.append(list(script_lines[i:i + 4]))
 
 
 
 
112
 
113
  return cls(segments)
114
 
 
50
  self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
51
  end_list = self.end_time_str.split(',')[0].split(':')
52
  self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
53
+ if len(args[0]) < 5:
54
+ self.translation = ""
55
+ else:
56
+ self.translation = args[0][3]
57
 
58
  def merge_seg(self, seg):
59
  """
 
108
  def parse_from_srt_file(cls, path: str):
109
  with open(path, 'r', encoding="utf-8") as f:
110
  script_lines = [line.rstrip() for line in f.readlines()]
111
+ bilingual = False
112
+ if script_lines[2] != '' and script_lines[3] != '':
113
+ bilingual = True
114
  segments = []
115
+ if bilingual:
116
+ for i in range(0, len(script_lines), 5):
117
+ segments.append(list(script_lines[i:i + 5]))
118
+ else:
119
+ for i in range(0, len(script_lines), 4):
120
+ segments.append(list(script_lines[i:i + 4]))
121
 
122
  return cls(segments)
123