yichenl5 commited on
Commit
d518333
·
2 Parent(s): d0677a4 d26fe45

Merge pull request #67 from project-kxkg/oop-refactor

Browse files

Release V1: oop refactor

Former-commit-id: ff9e71f58ad50e037e546b0ff2371a136eef26d8

.gitignore CHANGED
@@ -10,4 +10,8 @@ test.py
10
  test.srt
11
  test.txt
12
  log_*.csv
13
- log.csv
 
 
 
 
 
10
  test.srt
11
  test.txt
12
  log_*.csv
13
+ log.csv
14
+ .chroma
15
+ *.ini
16
+ local_dump/
17
+ .pytest_cache/
configs/local_launch.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # launch config for local environment
2
+ local_dump: ./local_dump
3
+ # dictionary_path: ./domain_dict
4
+ environ: local
configs/task_config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # configuration for each task
2
+ source_lang: EN
3
+ target_lang: ZH
4
+ field: General
5
+
6
+ # ASR config
7
+ ASR:
8
+ ASR_model: whisper
9
+ whisper_config:
10
+ whisper_model: tiny
11
+ method: stable
12
+
13
+ # pre-process module config
14
+ pre_process:
15
+ sentence_form: True
16
+ spell_check: False
17
+ term_correct: True
18
+
19
+ # Translation module config
20
+ translation:
21
+ model: gpt-4
22
+ chunk_size: 1000
23
+
24
+ # post-process module config
25
+ post_process:
26
+ check_len_and_split: True
27
+ remove_trans_punctuation: True
28
+
29
+ # output type that user receive
30
+ output_type:
31
+ subtitle: srt
32
+ video: True
33
+ bilingual: True
34
+
35
+
dict_util.py CHANGED
@@ -52,4 +52,27 @@ with open("../test.csv", "w", encoding='utf-8') as w:
52
  export_csv_dict(term_dict_sc2,w)
53
 
54
  ## for load pickle, just:
55
- # pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  export_csv_dict(term_dict_sc2,w)
53
 
54
  ## for load pickle, just:
55
+ # pickle.load(f)
56
+
57
+
58
def form_dict(src_dict: list, tgt_dict: list) -> dict:
    """Build a term lookup table from two parallel CSV row lists.

    Every variant in row i of src_dict maps to the full candidate row
    tgt_dict[i]; later duplicates of a variant overwrite earlier ones.
    Raises IndexError if src_dict has more rows than tgt_dict.
    """
    return {
        variant: tgt_dict[row_idx]
        for row_idx, variants in enumerate(src_dict)
        for variant in variants
    }


class term_dict(dict):
    """Domain glossary mapping a source-language term to the list of its
    target-language translation candidates, loaded from
    <path>/<src_lang>.csv and <path>/<tgt_lang>.csv (row-aligned files)."""

    def __init__(self, path, src_lang, tgt_lang) -> None:
        with open(f"{path}/{src_lang}.csv", 'r', encoding="utf-8") as file:
            src_rows = list(csv.reader(file, delimiter=","))
        with open(f"{path}/{tgt_lang}.csv", 'r', encoding="utf-8") as file:
            tgt_rows = list(csv.reader(file, delimiter=","))
        super().__init__(form_dict(src_rows, tgt_rows))

    def get(self, key: str) -> str:
        """Return one randomly chosen translation candidate for *key*.

        NOTE(review): unlike dict.get, this raises KeyError when *key*
        is absent instead of returning a default — confirm callers rely
        on that.
        """
        candidates = self[key]
        return candidates[randint(0, len(candidates) - 1)]
78
+
domain_dict/SC2/EN.csv ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ barracks
2
+ zerg
3
+ protoss
4
+ terran
5
+ engineering bay,engin bay
6
+ forge
7
+ blink
8
+ evolution chamber
9
+ cybernetics core,cybercore
10
+ enhanced shockwaves
11
+ gravitic boosters
12
+ armory
13
+ robotics bay,robo bay
14
+ twilight council,twilight
15
+ fusion core
16
+ fleet beacon
17
+ factory
18
+ ghost academy
19
+ infestation pit
20
+ robotics facility,robo
21
+ stargate
22
+ starport
23
+ archon
24
+ smart servos
25
+ gateway
26
+ warpgate
27
+ immortal
28
+ zealot
29
+ nydus network
30
+ nydus worm
31
+ hydralisk,hydra
32
+ grooved spines
33
+ muscular augments
34
+ hydralisk den,hydra den
35
+ planetary fortress
36
+ battle cruiser
37
+ weapon refit
38
+ brood lord
39
+ broodling
40
+ greater spire
41
+ anabolic synthesis
42
+ cyclone
43
+ bunker
domain_dict/SC2/ZH.csv ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 兵营
2
+ 虫族
3
+ 神族
4
+ 人族
5
+ 工程站,BE
6
+ BF,锻炉
7
+ 闪现
8
+ 进化腔
9
+ BY,赛博核心
10
+ EMP范围
11
+ ob速度
12
+ 军械库
13
+ 机械研究所,VB
14
+ 光影议会,VC
15
+ 聚变芯体
16
+ 舰队航标
17
+ 重工厂
18
+ 幽灵军校
19
+ 感染深渊
20
+ VR,机械台
21
+ 神族VS,星门
22
+ 星港,人族VS
23
+ 白球
24
+ 变形加速
25
+ 传送门
26
+ 折跃门
27
+ 不朽
28
+ 叉叉
29
+ 虫道网络
30
+ 坑道虫
31
+ 刺蛇
32
+ 刺蛇射程
33
+ 刺蛇速度
34
+ 刺蛇塔
35
+ 大地堡,行星要塞
36
+ 大和
37
+ 大和炮
38
+ 大龙
39
+ 巢虫
40
+ 大龙塔
41
+ 大牛速度
42
+ 导弹车
43
+ 地堡
entries/__init_lib_path.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import os


def add_path(custom_path):
    """Prepend *custom_path* to sys.path unless it is already listed."""
    already_present = custom_path in sys.path
    if not already_present:
        sys.path.insert(0, custom_path)


# Make the repository root (parent of this entries/ directory) importable
# so entry scripts can do `from src... import ...`.
this_dir = os.path.dirname(__file__)

lib_path = os.path.join(this_dir, '..')
add_path(lib_path)
entries/app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import __init_lib_path
2
+ import gradio as gr
3
+ from src.task import Task
4
+ import logging
5
+ from yaml import Loader, Dumper, load, dump
6
+ import os
7
+ from pathlib import Path
8
+ from datetime import datetime
9
+ import shutil
10
+ from uuid import uuid4
11
+
12
+ launch_config = "./configs/local_launch.yaml"
13
+ task_config = './configs/task_config.yaml'
14
+
15
def init(output_type, src_lang, tgt_lang, domain):
    """Prepare a fresh task workspace from the web-UI selections.

    :param output_type: list of checked output options ("Video File",
        "Bilingual", ".ass output")
    :param src_lang: source language code (e.g. "EN")
    :param tgt_lang: target language code (e.g. "ZH")
    :param domain: term-dictionary domain (e.g. "General", "SC2")
    :return: (task_id, task_dir, task_cfg) — a uuid string, the per-task
        working directory (Path), and the merged task config dict
    """
    # Read configs with context managers; the original `load(open(...))`
    # leaked the file handles.
    with open(launch_config) as f:
        launch_cfg = load(f, Loader=Loader)
    with open(task_config) as f:
        task_cfg = load(f, Loader=Loader)

    # Overwrite config defaults with the user's selections.
    task_cfg["source_lang"] = src_lang
    task_cfg["target_lang"] = tgt_lang
    task_cfg["field"] = domain

    task_cfg["output_type"]["video"] = "Video File" in output_type
    task_cfg["output_type"]["bilingual"] = "Bilingual" in output_type
    task_cfg["output_type"]["subtitle"] = "ass" if ".ass output" in output_type else "srt"

    # Initialize the local dump directory (first launch only).
    local_dir = Path(launch_cfg['local_dump'])
    if not local_dir.exists():
        local_dir.mkdir(parents=False, exist_ok=False)

    # Get a unique task id.
    task_id = str(uuid4())

    # Create the local dir for the task; exist_ok=False so a uuid
    # collision fails loudly instead of reusing a directory.
    task_dir = local_dir.joinpath(f"task_{task_id}")
    task_dir.mkdir(parents=False, exist_ok=False)
    task_dir.joinpath("results").mkdir(parents=False, exist_ok=False)

    # Route all logging into a per-task log file.
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt, handlers=[
        logging.FileHandler(
            "{}/{}_{}.log".format(task_dir, f"task_{task_id}", datetime.now().strftime("%m%d%Y_%H%M%S")),
            'w', encoding='utf-8')])
    return task_id, task_dir, task_cfg
59
+
60
def process_input(video_file, youtube_link, src_lang, tgt_lang, domain, output_type):
    """Gradio callback: build a Task from a YouTube link (preferred) or an
    uploaded video file, run it, and return task.result.

    Returns None when neither input was provided.
    """
    task_id, task_dir, task_cfg = init(output_type, src_lang, tgt_lang, domain)
    if youtube_link:
        task = Task.fromYoutubeLink(youtube_link, task_id, task_dir, task_cfg)
    elif video_file is not None:
        task = Task.fromVideoFile(video_file, task_id, task_dir, task_cfg)
    else:
        return None
    task.run()
    return task.result
72
+
73
+ demo = gr.Interface(fn=process_input,
74
+ inputs=[
75
+ gr.components.Video(label="Upload a video"),
76
+ gr.components.Textbox(label="Or enter a YouTube URL"),
77
+ gr.components.Dropdown(choices=["EN", "ZH"], label="Select Source Language"),
78
+ gr.components.Dropdown(choices=["ZH", "EN"], label="Select Target Language"),
79
+ gr.components.Dropdown(choices=["General", "SC2"], label="Select Domain"),
80
+ gr.CheckboxGroup(["Video File", "Bilingual", ".ass output"], label="Output Settings", info="What do you want?"),
81
+ ],
82
+ outputs=[
83
+ gr.components.Video(label="Processed Video")
84
+ ],
85
+ title="ViDove: video translation toolkit demo",
86
+ description="Upload a video or enter a YouTube URL."
87
+ )
88
+
89
+ if __name__ == "__main__":
90
+ demo.launch()
entries/run.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import __init_lib_path
2
+ import logging
3
+ from yaml import Loader, Dumper, load, dump
4
+ from src.task import Task
5
+ import openai
6
+ import argparse
7
+ import os
8
+ from pathlib import Path
9
+ from datetime import datetime
10
+ import shutil
11
+ from uuid import uuid4
12
+
13
+ """
14
+ Main entry for terminal environment.
15
+ Use it for debug and development purpose.
16
+ Usage: python3 entries/run.py [-h] [--link LINK] [--video_file VIDEO_FILE] [--audio_file AUDIO_FILE] [--srt_file SRT_FILE] [--continue CONTINUE]
17
+ [--launch_cfg LAUNCH_CFG] [--task_cfg TASK_CFG]
18
+ """
19
+
20
def parse_args():
    """Build and parse the command-line arguments for the terminal entry."""
    parser = argparse.ArgumentParser()
    # All plain-string inputs share the same defaults, so declare them
    # in a table instead of repeating the call five times.
    string_inputs = [
        ("--link", "youtube video link here"),
        ("--video_file", "local video path here"),
        ("--audio_file", "local audio path here"),
        ("--srt_file", "srt file input path here"),
        ("--continue", "task_id that need to continue"),  # need implement
    ]
    for flag, help_text in string_inputs:
        parser.add_argument(flag, help=help_text, default=None, type=str, required=False)
    parser.add_argument("--launch_cfg", help="launch config path",
                        default='./configs/local_launch.yaml', type=str, required=False)
    parser.add_argument("--task_cfg", help="task config path",
                        default='./configs/task_config.yaml', type=str, required=False)
    return parser.parse_args()
32
+
33
if __name__ == "__main__":
    # Read args and configs; context managers close the config files
    # (the original `load(open(...))` leaked the handles).
    args = parse_args()
    with open(args.launch_cfg) as f:
        launch_cfg = load(f, Loader=Loader)
    with open(args.task_cfg) as f:
        task_cfg = load(f, Loader=Loader)

    # Initialize the local dump directory (first run only).
    local_dir = Path(launch_cfg['local_dump'])
    if not local_dir.exists():
        local_dir.mkdir(parents=False, exist_ok=False)

    # Get a unique task id.
    task_id = str(uuid4())

    # Create the local dir for the task; exist_ok=False so a uuid
    # collision fails loudly instead of reusing a directory.
    task_dir = local_dir.joinpath(f"task_{task_id}")
    task_dir.mkdir(parents=False, exist_ok=False)
    task_dir.joinpath("results").mkdir(parents=False, exist_ok=False)

    # Route all logging into a per-task log file.
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt, handlers=[
        logging.FileHandler(
            "{}/{}_{}.log".format(task_dir, f"task_{task_id}", datetime.now().strftime("%m%d%Y_%H%M%S")),
            'w', encoding='utf-8')])

    # Task creation. Fixes vs. the original: narrow `except Exception`
    # instead of bare `except:`; one error message that names the actual
    # input kind (the original said "youtube link" in every branch); an
    # explicit error when no input flag was given (previously `task` was
    # undefined and `task.run()` raised NameError). The task dir is
    # removed on failure so empty task_* directories are not left behind.
    try:
        if args.link is not None:
            task = Task.fromYoutubeLink(args.link, task_id, task_dir, task_cfg)
        elif args.video_file is not None:
            task = Task.fromVideoFile(args.video_file, task_id, task_dir, task_cfg)
        elif args.audio_file is not None:
            # NOTE(review): audio input also goes through fromVideoFile —
            # confirm Task has no dedicated audio constructor.
            task = Task.fromVideoFile(args.audio_file, task_id, task_dir, task_cfg)
        else:
            raise ValueError("no input provided: pass --link, --video_file or --audio_file")
    except Exception as e:
        shutil.rmtree(task_dir)
        raise RuntimeError(f"failed to create task from input: {e}") from e

    # Run the task pipeline.
    task.run()
entries/web_backend.py ADDED
File without changes
requirement.txt CHANGED
@@ -5,6 +5,7 @@ attrs==22.2.0
5
  certifi==2022.12.7
6
  charset-normalizer==3.1.0
7
  ffmpeg-python==0.2.0
 
8
  filelock==3.10.0
9
  frozenlist==1.3.3
10
  future==0.18.3
@@ -23,12 +24,13 @@ openai-whisper @ git+https://github.com/openai/whisper.git@6dea21fd7f7253bfe450f
23
  panda==0.3.1
24
  pandas==1.5.3
25
  python-dateutil==2.8.2
26
- pytube==12.1.2
27
  pytube3==9.6.4
28
  pytz==2022.7.1
29
  regex==2022.10.31
30
  requests==2.28.2
31
  six==1.16.0
 
32
  sympy==1.11.1
33
  tiktoken==0.3.1
34
  torch==2.0.0
 
5
  certifi==2022.12.7
6
  charset-normalizer==3.1.0
7
  ffmpeg-python==0.2.0
8
+ Flask==2.3.3
9
  filelock==3.10.0
10
  frozenlist==1.3.3
11
  future==0.18.3
 
24
  panda==0.3.1
25
  pandas==1.5.3
26
  python-dateutil==2.8.2
27
+ pytube==15.0.0
28
  pytube3==9.6.4
29
  pytz==2022.7.1
30
  regex==2022.10.31
31
  requests==2.28.2
32
  six==1.16.0
33
+ stable-ts==2.9.0
34
  sympy==1.11.1
35
  tiktoken==0.3.1
36
  torch==2.0.0
src/Pigeon.py CHANGED
@@ -317,7 +317,7 @@ class Pigeon(object):
317
  logging.info("--------------------Start Preprocessing SRT class--------------------")
318
  self.srt.write_srt_file_src(self.srt_path)
319
  self.srt.form_whole_sentence()
320
- self.srt.spell_check_term()
321
  self.srt.correct_with_force_term()
322
  processed_srt_file_en = str(Path(self.srt_path).with_suffix('')) + '_processed.srt'
323
  self.srt.write_srt_file_src(processed_srt_file_en)
 
317
  logging.info("--------------------Start Preprocessing SRT class--------------------")
318
  self.srt.write_srt_file_src(self.srt_path)
319
  self.srt.form_whole_sentence()
320
+ # self.srt.spell_check_term()
321
  self.srt.correct_with_force_term()
322
  processed_srt_file_en = str(Path(self.srt_path).with_suffix('')) + '_processed.srt'
323
  self.srt.write_srt_file_src(processed_srt_file_en)
src/preprocess/audio_extract.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import os
3
+ import subprocess
4
+
5
+
6
def extract_audio(local_video_path: str, save_dir_path: str = "./downloads/audio") -> str:
    """Extract the audio track of a local video to an mp3 file via ffmpeg.

    :param local_video_path: path of the input video file
    :param save_dir_path: directory the extracted audio is written to
    :return: path of the extracted audio file
    :raises NotImplementedError: on Windows (filename handling untested)
    :raises subprocess.CalledProcessError: when ffmpeg fails
    """
    if os.name == 'nt':
        # The original constructed NotImplementedError but forgot to
        # `raise` it, silently falling through on Windows.
        raise NotImplementedError("Filename extraction on Windows not yet implemented")

    os.makedirs(save_dir_path, exist_ok=True)
    out_file_name = os.path.basename(local_video_path)
    # The original used str.join (`save_dir_path.join("/")...`), which
    # interleaves characters and produced a garbage path; os.path.join
    # is the intended operation.
    # NOTE(review): the output keeps the input's extension (e.g. .mp4)
    # even though the content is mp3 — confirm downstream expects that.
    audio_path_out = os.path.join(save_dir_path, out_file_name)
    subprocess.run(
        ['ffmpeg', '-i', local_video_path, '-f', 'mp3', '-ab', '192000', '-vn', audio_path_out],
        check=True)  # check=True: fail loudly instead of returning a path to a missing file
    return audio_path_out
src/preprocess/video_download.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pytube import YouTube
2
+ import logging
3
+
4
def download_youtube_to_local_file(youtube_url: str, local_dir_path: str = "./downloads") -> str:
    """Download the highest-bitrate audio-only mp4 stream of a YouTube video.

    :param youtube_url: link of the video to download
    :param local_dir_path: base download directory; audio lands in <dir>/audio
    :return: local path of the downloaded audio file
    :raises RuntimeError: when the video has no audio-only mp4 stream
    """
    yt = YouTube(youtube_url)
    try:
        audio = yt.streams.filter(only_audio=True, file_extension='mp4').order_by('abr').desc().first()
        # video = yt.streams.filter(file_extension='mp4').order_by('resolution').asc().first()
        if audio is None:
            logging.error(f"Audio stream not found in {youtube_url}")
            # The original did `raise f"..."`, which raises a str and is
            # itself a TypeError; raise a real exception instead.
            raise RuntimeError(f"Audio stream not found in {youtube_url}")
        # The original used str.join (`local_dir_path.join("/audio")`),
        # which produced a garbage directory path.
        saved_audio = audio.download(output_path=f"{local_dir_path}/audio")
        logging.info(f"Audio download successful: {saved_audio}")
        return saved_audio
    except Exception:
        # Log with traceback instead of bare print, then propagate.
        logging.exception(f"YouTube download failed for {youtube_url}")
        raise
20
+
src/srt_util/srt.py CHANGED
@@ -7,10 +7,59 @@ from datetime import timedelta
7
  import logging
8
  import openai
9
  from tqdm import tqdm
10
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  class SrtSegment(object):
13
- def __init__(self, *args) -> None:
 
 
 
14
  if isinstance(args[0], dict):
15
  segment = args[0]
16
  self.start = segment['start']
@@ -54,6 +103,7 @@ class SrtSegment(object):
54
  self.translation = ""
55
  else:
56
  self.translation = args[0][3]
 
57
 
58
  def merge_seg(self, seg):
59
  """
@@ -83,12 +133,14 @@ class SrtSegment(object):
83
 
84
  def remove_trans_punc(self) -> None:
85
  """
86
- remove CN punctuations in translation text
87
  :return: None
88
  """
89
- punc_cn = ",。!?"
90
- translator = str.maketrans(punc_cn, ' ' * len(punc_cn))
91
- self.translation = self.translation.translate(translator)
 
 
92
 
93
  def __str__(self) -> str:
94
  return f'{self.duration}\n{self.source_text}\n\n'
@@ -101,11 +153,25 @@ class SrtSegment(object):
101
 
102
 
103
  class SrtScript(object):
104
- def __init__(self, segments) -> None:
105
- self.segments = [SrtSegment(seg) for seg in segments]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  @classmethod
108
- def parse_from_srt_file(cls, path: str):
109
  with open(path, 'r', encoding="utf-8") as f:
110
  script_lines = [line.rstrip() for line in f.readlines()]
111
  bilingual = False
@@ -119,7 +185,7 @@ class SrtScript(object):
119
  for i in range(0, len(script_lines), 4):
120
  segments.append(list(script_lines[i:i + 4]))
121
 
122
- return cls(segments)
123
 
124
  def merge_segs(self, idx_list) -> SrtSegment:
125
  """
@@ -147,9 +213,10 @@ class SrtScript(object):
147
  logging.info("Forming whole sentences...")
148
  merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
149
  sentence = []
 
150
  # Get each entire sentence of distinct segments, fill indices to merge_list
151
  for i, seg in enumerate(self.segments):
152
- if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
153
  sentence.append(i)
154
  merge_list.append(sentence)
155
  sentence = []
@@ -184,19 +251,20 @@ class SrtScript(object):
184
  src_text += '\n\n'
185
 
186
  def inner_func(target, input_str):
 
187
  response = openai.ChatCompletion.create(
188
- # model=model,
189
  model="gpt-4",
190
  messages=[
191
  {"role": "system",
192
- "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
193
- {"role": "system", "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
194
- {"role": "user", "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target, input_str)}
195
  ],
196
  temperature=0.15
197
  )
198
  return response['choices'][0]['message']['content'].strip()
199
 
 
200
  lines = translate.split('\n\n')
201
  if len(lines) < (end_seg_id - start_seg_id + 1):
202
  count = 0
@@ -204,28 +272,27 @@ class SrtScript(object):
204
  while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
205
  count += 1
206
  print("Solving Unmatched Lines|iteration {}".format(count))
 
207
 
208
  flag = True
209
  while flag:
210
  flag = False
211
- # print("translate:")
212
- # print(translate)
213
  try:
214
- # print("target")
215
- # print(end_seg_id - start_seg_id + 1)
216
  translate = inner_func(end_seg_id - start_seg_id + 1, translate)
217
  except Exception as e:
218
  print("An error has occurred during solving unmatched lines:", e)
219
  print("Retrying...")
 
 
220
  flag = True
221
  lines = translate.split('\n')
222
- # print("result")
223
- # print(len(lines))
224
 
225
  if len(lines) < (end_seg_id - start_seg_id + 1):
226
  solved = False
227
  print("Failed Solving unmatched lines, Manually parse needed")
 
228
 
 
229
  if not os.path.exists("./logs"):
230
  os.mkdir("./logs")
231
  if video_link:
@@ -244,7 +311,7 @@ class SrtScript(object):
244
  log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
245
  log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
246
  len(self.segments)) + ',' + video_name + "\n")
247
- print(lines)
248
 
249
  for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
250
  # naive way to due with merge translation problem
@@ -262,24 +329,27 @@ class SrtScript(object):
262
 
263
  def split_seg(self, seg, text_threshold, time_threshold):
264
  # evenly split seg to 2 parts and add new seg into self.segments
265
-
266
  # ignore the initial comma to solve the recursion problem
 
 
 
267
  if len(seg.source_text) > 2:
268
- if seg.source_text[:2] == ', ':
269
  seg.source_text = seg.source_text[2:]
270
- if seg.translation[0] == ',':
271
  seg.translation = seg.translation[1:]
272
 
273
  source_text = seg.source_text
274
  translation = seg.translation
275
 
276
  # split the text based on commas
277
- src_commas = [m.start() for m in re.finditer(',', source_text)]
278
- trans_commas = [m.start() for m in re.finditer(',', translation)]
279
  if len(src_commas) != 0:
280
  src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
281
  len(src_commas) // 2 - 1]
282
  else:
 
283
  src_space = [m.start() for m in re.finditer(' ', source_text)]
284
  if len(src_space) > 0:
285
  src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
@@ -315,14 +385,14 @@ class SrtScript(object):
315
  seg1_dict['text'] = src_seg1
316
  seg1_dict['start'] = start_seg1
317
  seg1_dict['end'] = end_seg1
318
- seg1 = SrtSegment(seg1_dict)
319
  seg1.translation = trans_seg1
320
 
321
  seg2_dict = {}
322
  seg2_dict['text'] = src_seg2
323
  seg2_dict['start'] = start_seg2
324
  seg2_dict['end'] = end_seg2
325
- seg2 = SrtSegment(seg2_dict)
326
  seg2.translation = trans_seg2
327
 
328
  result_list = []
@@ -353,8 +423,6 @@ class SrtScript(object):
353
  self.segments = segments
354
  logging.info("check_len_and_split finished")
355
 
356
- pass
357
-
358
  def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
359
  # DEPRECATED
360
  # if sentence length >= text_threshold, split this segments to two
@@ -376,22 +444,24 @@ class SrtScript(object):
376
  def correct_with_force_term(self):
377
  ## force term correction
378
  logging.info("performing force term correction")
379
- # load term dictionary
380
- with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
381
- term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
382
 
383
- keywords = list(term_enzh_dict.keys())
384
- keywords.sort(key=lambda x: len(x), reverse=True)
385
-
386
- for word in keywords:
387
- for i, seg in enumerate(self.segments):
388
- if word in seg.source_text.lower():
389
- seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(term_enzh_dict.get(word)),
390
- seg.source_text, flags=re.IGNORECASE)
391
- logging.info(
392
- "replace term: " + word + " --> " + term_enzh_dict.get(word) + " in time stamp {}".format(
393
- i + 1))
394
- logging.info("source text becomes: " + seg.source_text)
 
 
 
 
 
395
 
396
  comp_dict = []
397
 
@@ -425,6 +495,12 @@ class SrtScript(object):
425
 
426
  def spell_check_term(self):
427
  logging.info("performing spell check")
 
 
 
 
 
 
428
  import enchant
429
  dict = enchant.Dict('en_US')
430
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
@@ -532,3 +608,27 @@ class SrtScript(object):
532
  f.write(f'{i + idx}\n')
533
  f.write(seg.get_bilingual_str())
534
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import logging
8
  import openai
9
  from tqdm import tqdm
10
+ import dict_util
11
+
12
# Punctuation tables for the supported language codes.
#   punc_str:     punctuation characters stripped from translated text
#                 (see SrtSegment.remove_trans_punc, which replaces each
#                 character of this string with a space)
#   comma:        the clause separator used when splitting long segments
#   sentence_end: characters that terminate a complete sentence
# NOTE(review): EN/ES/AR punc_str entries are space-separated, so the
# space itself is treated as a punctuation character there, while
# FR/DE/RU/ZH/JA are contiguous — confirm this asymmetry is intended.
punctuation_dict = {
    "EN": {
        "punc_str": ". , ? ! : ; - ( ) [ ] { }",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "ES": {
        "punc_str": ". , ? ! : ; - ( ) [ ] { } ¡ ¿",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";", "¡", "¿"]
    },
    "FR": {
        "punc_str": ".,?!:;«»—",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "DE": {
        "punc_str": ".,?!:;„“–",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "RU": {
        "punc_str": ".,?!:;-«»—",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "ZH": {
        "punc_str": "。,?!:;()",
        "comma": ",",
        "sentence_end": ["。", "!", "?"]
    },
    "JA": {
        "punc_str": "。、?!:;()",
        "comma": "、",
        "sentence_end": ["。", "!", "?"]
    },
    "AR": {
        "punc_str": ".,?!:;-()[]،؛ ؟ «»",
        "comma": "، ",
        "sentence_end": [".", "!", "?", ";", "؟"]
    },
}

# Root directory holding the per-domain term dictionaries
# (<dict_path>/<domain>/<LANG>.csv, see SrtScript.__init__).
dict_path = "./domain_dict"
57
 
58
  class SrtSegment(object):
59
+ def __init__(self, src_lang, tgt_lang, *args) -> None:
60
+ self.src_lang = src_lang
61
+ self.tgt_lang = tgt_lang
62
+
63
  if isinstance(args[0], dict):
64
  segment = args[0]
65
  self.start = segment['start']
 
103
  self.translation = ""
104
  else:
105
  self.translation = args[0][3]
106
+
107
 
108
  def merge_seg(self, seg):
109
  """
 
133
 
134
  def remove_trans_punc(self) -> None:
135
  """
136
+ remove punctuations in translation text
137
  :return: None
138
  """
139
+ punc_str = punctuation_dict[self.tgt_lang]["punc_str"]
140
+ for punc in punc_str:
141
+ self.translation = self.translation.replace(punc, ' ')
142
+ # translator = str.maketrans(punc, ' ' * len(punc))
143
+ # self.translation = self.translation.translate(translator)
144
 
145
  def __str__(self) -> str:
146
  return f'{self.duration}\n{self.source_text}\n\n'
 
153
 
154
 
155
  class SrtScript(object):
156
+ def __init__(self, src_lang, tgt_lang, segments, domain="General") -> None:
157
+ self.domain = domain
158
+ self.src_lang = src_lang
159
+ self.tgt_lang = tgt_lang
160
+ self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
161
+
162
+ if self.domain != "General":
163
+ if os.path.exists(f"{dict_path}/{self.domain}") and\
164
+ os.path.exists(f"{dict_path}/{self.domain}/{src_lang}.csv") and os.path.exists(f"{dict_path}/{self.domain}/{tgt_lang}.csv" ):
165
+ # TODO: load dictionary
166
+ self.dict = dict_util.term_dict(f"{dict_path}/{self.domain}", src_lang, tgt_lang)
167
+ ...
168
+ else:
169
+ logging.error(f"domain {self.domain} or related dictionary({src_lang} or {tgt_lang}) doesn't exist, fallback to general domain, this will disable correct_with_force_term and spell_check_term")
170
+ self.domain = "General"
171
+
172
 
173
  @classmethod
174
+ def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
175
  with open(path, 'r', encoding="utf-8") as f:
176
  script_lines = [line.rstrip() for line in f.readlines()]
177
  bilingual = False
 
185
  for i in range(0, len(script_lines), 4):
186
  segments.append(list(script_lines[i:i + 4]))
187
 
188
+ return cls(src_lang, tgt_lang, segments)
189
 
190
  def merge_segs(self, idx_list) -> SrtSegment:
191
  """
 
213
  logging.info("Forming whole sentences...")
214
  merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
215
  sentence = []
216
+ ending_puncs = punctuation_dict[self.src_lang]["sentence_end"]
217
  # Get each entire sentence of distinct segments, fill indices to merge_list
218
  for i, seg in enumerate(self.segments):
219
+ if seg.source_text[-1] in ending_puncs and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
220
  sentence.append(i)
221
  merge_list.append(sentence)
222
  sentence = []
 
251
  src_text += '\n\n'
252
 
253
  def inner_func(target, input_str):
254
+ # handling merge sentences issue.
255
  response = openai.ChatCompletion.create(
 
256
  model="gpt-4",
257
  messages=[
258
  {"role": "system",
259
+ "content": "Your task is to merge or split sentences into a specified number of lines as required. You need to ensure the meaning of the sentences as much as possible, but when necessary, a sentence can be divided into two lines for output"},
260
+ {"role": "system", "content": "Note: You only need to output the processed {} sentences. If you need to output a sequence number, please separate it with a colon.".format(self.tgt_lang)},
261
+ {"role": "user", "content": 'Please split or combine the following sentences into {} sentences:\n{}'.format(target, input_str)}
262
  ],
263
  temperature=0.15
264
  )
265
  return response['choices'][0]['message']['content'].strip()
266
 
267
+ # handling merge sentences issue.
268
  lines = translate.split('\n\n')
269
  if len(lines) < (end_seg_id - start_seg_id + 1):
270
  count = 0
 
272
  while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
273
  count += 1
274
  print("Solving Unmatched Lines|iteration {}".format(count))
275
+ logging.error("Solving Unmatched Lines|iteration {}".format(count))
276
 
277
  flag = True
278
  while flag:
279
  flag = False
 
 
280
  try:
 
 
281
  translate = inner_func(end_seg_id - start_seg_id + 1, translate)
282
  except Exception as e:
283
  print("An error has occurred during solving unmatched lines:", e)
284
  print("Retrying...")
285
+ logging.error("An error has occurred during solving unmatched lines:", e)
286
+ logging.error("Retrying...")
287
  flag = True
288
  lines = translate.split('\n')
 
 
289
 
290
  if len(lines) < (end_seg_id - start_seg_id + 1):
291
  solved = False
292
  print("Failed Solving unmatched lines, Manually parse needed")
293
+ logging.error("Failed Solving unmatched lines, Manually parse needed")
294
 
295
+ # FIXME: put the error log in our log file
296
  if not os.path.exists("./logs"):
297
  os.mkdir("./logs")
298
  if video_link:
 
311
  log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
312
  log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
313
  len(self.segments)) + ',' + video_name + "\n")
314
+ # print(lines)
315
 
316
  for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
317
  # naive way to due with merge translation problem
 
329
 
330
  def split_seg(self, seg, text_threshold, time_threshold):
331
  # evenly split seg to 2 parts and add new seg into self.segments
 
332
  # ignore the initial comma to solve the recursion problem
333
+ src_comma_str = punctuation_dict[self.src_lang]["comma"]
334
+ tgt_comma_str = punctuation_dict[self.tgt_lang]["comma"]
335
+
336
  if len(seg.source_text) > 2:
337
+ if seg.source_text[:2] == src_comma_str:
338
  seg.source_text = seg.source_text[2:]
339
+ if seg.translation[0] == tgt_comma_str:
340
  seg.translation = seg.translation[1:]
341
 
342
  source_text = seg.source_text
343
  translation = seg.translation
344
 
345
  # split the text based on commas
346
+ src_commas = [m.start() for m in re.finditer(src_comma_str, source_text)]
347
+ trans_commas = [m.start() for m in re.finditer(tgt_comma_str, translation)]
348
  if len(src_commas) != 0:
349
  src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
350
  len(src_commas) // 2 - 1]
351
  else:
352
+ # split the text based on spaces
353
  src_space = [m.start() for m in re.finditer(' ', source_text)]
354
  if len(src_space) > 0:
355
  src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
 
385
  seg1_dict['text'] = src_seg1
386
  seg1_dict['start'] = start_seg1
387
  seg1_dict['end'] = end_seg1
388
+ seg1 = SrtSegment(self.src_lang, self.tgt_lang, seg1_dict)
389
  seg1.translation = trans_seg1
390
 
391
  seg2_dict = {}
392
  seg2_dict['text'] = src_seg2
393
  seg2_dict['start'] = start_seg2
394
  seg2_dict['end'] = end_seg2
395
+ seg2 = SrtSegment(self.src_lang, self.tgt_lang, seg2_dict)
396
  seg2.translation = trans_seg2
397
 
398
  result_list = []
 
423
  self.segments = segments
424
  logging.info("check_len_and_split finished")
425
 
 
 
426
  def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
427
  # DEPRECATED
428
  # if sentence length >= text_threshold, split this segments to two
 
444
  def correct_with_force_term(self):
445
  ## force term correction
446
  logging.info("performing force term correction")
 
 
 
447
 
448
+ # check domain
449
+ if self.domain == "General":
450
+ logging.info("General domain could not perform correct_with_force_term. skip this step.")
451
+ pass
452
+ else:
453
+ keywords = list(self.dict.keys())
454
+ keywords.sort(key=lambda x: len(x), reverse=True)
455
+
456
+ for word in keywords:
457
+ for i, seg in enumerate(self.segments):
458
+ if word in seg.source_text.lower():
459
+ seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(self.dict.get(word)),
460
+ seg.source_text, flags=re.IGNORECASE)
461
+ logging.info(
462
+ "replace term: " + word + " --> " + self.dict.get(word) + " in time stamp {}".format(
463
+ i + 1))
464
+ logging.info("source text becomes: " + seg.source_text)
465
 
466
  comp_dict = []
467
 
 
495
 
496
  def spell_check_term(self):
497
  logging.info("performing spell check")
498
+
499
+ # check domain
500
+ if self.domain == "General":
501
+ logging.info("General domain could not perform spell_check_term. skip this step.")
502
+ pass
503
+
504
  import enchant
505
  dict = enchant.Dict('en_US')
506
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
 
608
  f.write(f'{i + idx}\n')
609
  f.write(seg.get_bilingual_str())
610
  pass
611
+
612
def split_script(script_in, chunk_size=1000):
    """Split a blank-line-separated script into chunks of at most roughly
    ``chunk_size`` characters.

    Returns (script_arr, range_arr): the chunk texts (stripped) and, for
    each chunk, a 1-based (start, end) tuple describing which sentences
    it covers. The two lists always have equal length.
    """
    sentences = script_in.split('\n\n')
    script_arr = []
    range_arr = []
    start, end = 1, 0
    buffer = ""
    for sentence in sentences:
        fits = len(buffer) + len(sentence) + 1 <= chunk_size
        if fits:
            buffer += sentence + '\n\n'
        else:
            # flush the current chunk and start a new one with this sentence
            range_arr.append((start, end))
            script_arr.append(buffer.strip())
            start = end + 1
            buffer = sentence + '\n\n'
        end += 1
    if buffer.strip():
        script_arr.append(buffer.strip())
        range_arr.append((start, len(sentences) - 1))

    assert len(script_arr) == len(range_arr)
    return script_arr, range_arr
src/task.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import time
3
+
4
+ import openai
5
+ from pytube import YouTube
6
+ from os import getenv, getcwd
7
+ from pathlib import Path
8
+ from enum import Enum, auto
9
+ import logging
10
+ import subprocess
11
+ from src.srt_util.srt import SrtScript
12
+ from src.srt_util.srt2ass import srt2ass
13
+ from time import time, strftime, gmtime, sleep
14
+ from src.translators.translation import get_translation, prompt_selector
15
+
16
+ import torch
17
+ import stable_whisper
18
+ import shutil
19
+
20
+ """
21
+ Youtube link
22
+ - link
23
+ - model
24
+ - output type
25
+
26
+ Video file
27
+ - path
28
+ - model
29
+ - output type
30
+
31
+ Audio file
32
+ - path
33
+ - model
34
+ - output type
35
+
36
+ """
37
+ """
38
+ TaskID
39
+ Progress: Enum
40
+ Computing resrouce status
41
+ SRT_Script : SrtScript
42
+ - input module -> initialize (ASR module)
43
+ - Pre-process
44
+ - Translation (%)
45
+ - Post process (time stamp)
46
+ - Output module: SRT_Script --> output(.srt)
47
+ - (Optional) mp4
48
+ """
49
+
50
class TaskStatus(str, Enum):
    # Lifecycle states of a Task, listed in pipeline order.
    # Subclassing str keeps members JSON-serializable as plain strings
    # (the web status endpoint returns task.status directly).
    CREATED = 'CREATED'
    INITIALIZING_ASR = 'INITIALIZING_ASR'
    PRE_PROCESSING = 'PRE_PROCESSING'
    TRANSLATING = 'TRANSLATING'
    POST_PROCESSING = 'POST_PROCESSING'
    OUTPUT_MODULE = 'OUTPUT_MODULE'
57
+
58
+
59
class Task:
    """
    One end-to-end subtitle-translation job: ASR -> preprocess ->
    translation -> postprocess -> output rendering.

    Use the fromYoutubeLink / fromAudioFile / fromVideoFile factories
    rather than constructing subclasses directly.
    """

    @property
    def status(self):
        # Thread-safe read: the web layer polls status while the worker
        # thread advances the pipeline.
        with self.__status_lock:
            return self.__status

    @status.setter
    def status(self, new_status):
        with self.__status_lock:
            self.__status = new_status

    def __init__(self, task_id, task_local_dir, task_cfg):
        """
        :param task_id: unique identifier (uuid string) for this task.
        :param task_local_dir: Path-like working directory for all task files.
        :param task_cfg: dict with ASR / translation / pre_process /
            post_process / output_type / language settings (see task_config.yaml).
        """
        self.__status_lock = threading.Lock()
        self.__status = TaskStatus.CREATED
        self.gpu_status = 0
        openai.api_key = getenv("OPENAI_API_KEY")
        self.task_id = task_id

        self.task_local_dir = task_local_dir
        self.ASR_setting = task_cfg["ASR"]
        self.translation_setting = task_cfg["translation"]
        self.translation_model = self.translation_setting["model"]

        self.output_type = task_cfg["output_type"]
        self.target_lang = task_cfg["target_lang"]
        self.source_lang = task_cfg["source_lang"]
        self.field = task_cfg["field"]
        self.pre_setting = task_cfg["pre_process"]
        self.post_setting = task_cfg["post_process"]

        self.audio_path = None
        # BUG FIX: output_render() reads self.video_path, but only the
        # subclasses used to define it; define a safe default in the base.
        self.video_path = None
        self.SRT_Script = None
        self.result = None
        # BUG FIX: was "self.s_t", but get_srt_class()/output_render() use
        # self.t_s for the pipeline start time.
        self.t_s = None
        self.t_e = None

        print(f"Task ID: {self.task_id}")
        logging.info(f"Task ID: {self.task_id}")
        logging.info(f"{self.source_lang} -> {self.target_lang} task in {self.field}")
        logging.info(f"Translation Model: {self.translation_model}")
        logging.info(f"subtitle_type: {self.output_type['subtitle']}")
        logging.info(f"video_ouput: {self.output_type['video']}")
        logging.info(f"bilingual_ouput: {self.output_type['bilingual']}")
        logging.info("Pre-process setting:")
        for key in self.pre_setting:
            logging.info(f"{key}: {self.pre_setting[key]}")
        logging.info("Post-process setting:")
        for key in self.post_setting:
            logging.info(f"{key}: {self.post_setting[key]}")

    @staticmethod
    def fromYoutubeLink(youtube_url, task_id, task_dir, task_cfg):
        """Create a task that downloads its media from a YouTube URL."""
        logging.info("Task Creation method: Youtube Link")
        return YoutubeTask(task_id, task_dir, task_cfg, youtube_url)

    @staticmethod
    def fromAudioFile(audio_path, task_id, task_dir, task_cfg):
        """Create a task from a local audio file (no video output possible)."""
        logging.info("Task Creation method: Audio File")
        return AudioTask(task_id, task_dir, task_cfg, audio_path)

    @staticmethod
    def fromVideoFile(video_path, task_id, task_dir, task_cfg):
        """Create a task from a local video file; audio is extracted later."""
        logging.info("Task Creation method: Video File")
        return VideoTask(task_id, task_dir, task_cfg, video_path)

    # Module 1 ASR: audio --> SRT_script
    def get_srt_class(self):
        """Transcribe self.audio_path with Whisper and build self.SRT_Script."""
        # TODO: setup ASR module like translator
        self.status = TaskStatus.INITIALIZING_ASR
        self.t_s = time()

        method = self.ASR_setting["whisper_config"]["method"]
        whisper_model = self.ASR_setting["whisper_config"]["whisper_model"]
        src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id}_{self.source_lang}.srt")
        if not Path.exists(src_srt_path):
            # extract script from audio
            logging.info("extract script from audio")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            if method == "api":
                # NOTE(review): the API path returns srt-formatted text, not a
                # dict with 'segments' — confirm downstream handling.
                with open(self.audio_path, 'rb') as audio_file:
                    transcript = openai.Audio.transcribe(model="whisper-1", file=audio_file, response_format="srt")
            elif method == "stable":
                model = stable_whisper.load_model(whisper_model, device)
                transcript = model.transcribe(str(self.audio_path), regroup=False,
                                              initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
                # Regroup word timestamps into sentence-like segments.
                (
                    transcript
                    .split_by_punctuation(['.', '。', '?'])
                    .merge_by_gap(.15, max_words=3)
                    .merge_by_punctuation([' '])
                    .split_by_punctuation(['.', '。', '?'])
                )
                transcript = transcript.to_dict()

            # after get the transcript, release the gpu resource
            torch.cuda.empty_cache()

        # NOTE(review): if the source .srt already exists, `transcript` is
        # undefined here — presumably tasks always start from fresh audio;
        # confirm and add an srt-loading branch if re-runs are expected.
        self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'], self.field)
        # save the srt script to local
        self.SRT_Script.write_srt_file_src(src_srt_path)

    # Module 2: SRT preprocess: perform preprocess steps
    def preprocess(self):
        """Run the configured preprocessing passes on the SRT script."""
        self.status = TaskStatus.PRE_PROCESSING
        logging.info("--------------------Start Preprocessing SRT class--------------------")
        if self.pre_setting["sentence_form"]:
            self.SRT_Script.form_whole_sentence()
        if self.pre_setting["spell_check"]:
            self.SRT_Script.spell_check_term()
        if self.pre_setting["term_correct"]:
            self.SRT_Script.correct_with_force_term()
        processed_srt_path_src = str(Path(self.task_local_dir) / f'{self.task_id}_processed.srt')
        self.SRT_Script.write_srt_file_src(processed_srt_path_src)

        if self.output_type["subtitle"] == "ass":
            logging.info("write English .srt file to .ass")
            assSub_src = srt2ass(processed_srt_path_src, "default", "No", "Modest")
            logging.info('ASS subtitle saved as: ' + assSub_src)
        self.script_input = self.SRT_Script.get_source_only()

    def update_translation_progress(self, new_progress):
        """Record translation progress while the task is in TRANSLATING state."""
        # BUG FIX: the original tested/assigned self.progress, an attribute
        # that was never defined anywhere (AttributeError when called).
        if self.status == TaskStatus.TRANSLATING:
            self.translation_progress = new_progress

    # Module 3: perform srt translation
    def translation(self):
        """Translate the SRT script with the configured LLM."""
        logging.info("---------------------Start Translation--------------------")
        prompt = prompt_selector(self.source_lang, self.target_lang, self.field)
        get_translation(self.SRT_Script, self.translation_model, self.task_id, prompt, self.translation_setting['chunk_size'])

    # Module 4: perform srt post process steps
    def postprocess(self):
        """Run the configured postprocessing passes on the translated script."""
        self.status = TaskStatus.POST_PROCESSING

        logging.info("---------------------Start Post-processing SRT class---------------------")
        if self.post_setting["check_len_and_split"]:
            self.SRT_Script.check_len_and_split()
        if self.post_setting["remove_trans_punctuation"]:
            self.SRT_Script.remove_trans_punctuation()
        logging.info("---------------------Post-processing SRT class finished---------------------")

    # Module 5: output module
    def output_render(self):
        """
        Write the requested outputs (srt/ass subtitle, optional bilingual
        subtitle, optional burned-in mp4) and return the final artifact path.
        """
        self.status = TaskStatus.OUTPUT_MODULE
        video_out = self.output_type["video"]
        subtitle_type = self.output_type["subtitle"]
        is_bilingual = self.output_type["bilingual"]

        results_dir = f"{self.task_local_dir}/results"
        # Robustness: make sure the results directory exists before writing.
        Path(results_dir).mkdir(parents=True, exist_ok=True)

        subtitle_path = f"{results_dir}/{self.task_id}_{self.target_lang}.srt"
        self.SRT_Script.write_srt_file_translate(subtitle_path)
        if is_bilingual:
            subtitle_path = f"{results_dir}/{self.task_id}_{self.source_lang}_{self.target_lang}.srt"
            self.SRT_Script.write_srt_file_bilingual(subtitle_path)

        if subtitle_type == "ass":
            logging.info("write .srt file to .ass")
            subtitle_path = srt2ass(subtitle_path, "default", "No", "Modest")
            logging.info('ASS subtitle saved as: ' + subtitle_path)

        final_res = subtitle_path

        # encode to .mp4 video file
        if video_out and self.video_path is not None:
            logging.info("encoding video file")
            logging.info(f'ffmpeg comand: \nffmpeg -i {self.video_path} -vf "subtitles={subtitle_path}" {results_dir}/{self.task_id}.mp4')
            subprocess.run(
                ["ffmpeg",
                 "-i", self.video_path,
                 "-vf", f"subtitles={subtitle_path}",
                 f"{results_dir}/{self.task_id}.mp4"])
            final_res = f"{results_dir}/{self.task_id}.mp4"

        self.t_e = time()
        logging.info(
            "Pipeline finished, time duration:{}".format(strftime("%H:%M:%S", gmtime(self.t_e - self.t_s))))
        return final_res

    def run_pipeline(self):
        """Execute the full ASR -> translate -> render pipeline in order."""
        self.get_srt_class()
        self.preprocess()
        self.translation()
        self.postprocess()
        self.result = self.output_render()
251
+
252
class YoutubeTask(Task):
    """Task whose media is downloaded from a YouTube URL."""

    def __init__(self, task_id, task_local_dir, task_cfg, youtube_url):
        super().__init__(task_id, task_local_dir, task_cfg)
        self.youtube_url = youtube_url

    def run(self):
        """Download video + audio, then hand off to the shared pipeline."""
        yt = YouTube(self.youtube_url)
        mp4_name = f"task_{self.task_id}.mp4"
        mp3_name = f"task_{self.task_id}.mp3"

        # Highest-resolution progressive (audio+video muxed) mp4 stream.
        video_stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
        if not video_stream:
            raise FileNotFoundError(f" Video stream not found for link {self.youtube_url}")
        video_stream.download(str(self.task_local_dir), filename=mp4_name)
        logging.info(f'Video Name: {video_stream.default_filename}')

        audio_stream = yt.streams.filter(only_audio=True).first()
        if audio_stream:
            audio_stream.download(str(self.task_local_dir), filename=mp3_name)
        else:
            # Fall back to ripping the audio track out of the downloaded video.
            logging.info(" download audio failed, using ffmpeg to extract audio")
            subprocess.run(
                ['ffmpeg', '-i', self.task_local_dir.joinpath(mp4_name), '-f', 'mp3',
                 '-ab', '192000', '-vn', self.task_local_dir.joinpath(mp3_name)])
            logging.info("audio extraction finished")

        self.video_path = self.task_local_dir.joinpath(mp4_name)
        self.audio_path = self.task_local_dir.joinpath(mp3_name)

        logging.info(f" Video File Dir: {self.video_path}")
        logging.info(f" Audio File Dir: {self.audio_path}")
        logging.info(" Data Prep Complete. Start pipeline")

        super().run_pipeline()
285
+
286
class AudioTask(Task):
    """Task created directly from an existing audio file; no video output."""

    def __init__(self, task_id, task_local_dir, task_cfg, audio_path):
        super().__init__(task_id, task_local_dir, task_cfg)
        # TODO: check audio format
        self.video_path = None
        self.audio_path = audio_path

    def run(self):
        """No media preparation needed; go straight to the shared pipeline."""
        logging.info(f"Video File Dir: {self.video_path}")
        logging.info(f"Audio File Dir: {self.audio_path}")
        logging.info("Data Prep Complete. Start pipeline")
        super().run_pipeline()
298
+
299
class VideoTask(Task):
    """Task created from a local video file; audio is extracted via ffmpeg."""

    def __init__(self, task_id, task_local_dir, task_cfg, video_path):
        super().__init__(task_id, task_local_dir, task_cfg)
        # TODO: check video format {.mp4}
        new_video_path = f"{task_local_dir}/task_{self.task_id}.mp4"
        print(new_video_path)
        logging.info(f"Copy video file to: {new_video_path}")
        shutil.copyfile(video_path, new_video_path)
        self.video_path = new_video_path

    def run(self):
        """Extract the audio track, then hand off to the shared pipeline."""
        logging.info("using ffmpeg to extract audio")
        extracted_audio = self.task_local_dir.joinpath(f"task_{self.task_id}.mp3")
        subprocess.run(
            ['ffmpeg', '-i', self.video_path, '-f', 'mp3',
             '-ab', '192000', '-vn', extracted_audio])
        logging.info("audio extraction finished")

        self.audio_path = extracted_audio
        logging.info(f" Video File Dir: {self.video_path}")
        logging.info(f" Audio File Dir: {self.audio_path}")
        logging.info("Data Prep Complete. Start pipeline")
        super().run_pipeline()
src/translators/LLM_task.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import openai
3
+
4
+
5
def LLM_task(model_name, input, task, temp=0.15):
    """
    Translate *input* with the requested LLM.

    :param model_name: Name of the chat model ("gpt-3.5-turbo" or "gpt-4").
    :param input: Sentence(s) to translate.
    :param task: System prompt describing the translation task.
    :param temp: Sampling temperature.
    """
    # Guard clause: only the OpenAI chat models are wired up so far.
    if model_name not in ("gpt-3.5-turbo", "gpt-4"):
        raise NotImplementedError
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[
            {"role": "system", "content": task},
            {"role": "user", "content": input},
        ],
        temperature=temp,
    )
    return response['choices'][0]['message']['content'].strip()
src/translators/__init__.py ADDED
File without changes
src/translators/translation.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import getenv
2
+ import logging
3
+ from time import sleep
4
+ from tqdm import tqdm
5
+ from src.srt_util.srt import split_script
6
+ from .LLM_task import LLM_task
7
+
8
def get_translation(srt, model, video_name, prompt, chunk_size=1000):
    """Chunk the SRT source text and translate each chunk into *srt* in place."""
    script_arr, range_arr = split_script(srt.get_source_only(), chunk_size)
    translate(srt, script_arr, range_arr, model, video_name, task=prompt)
12
+
13
def check_translation(sentence, translation):
    """
    Detect the merged-sentence failure mode of LLM translation: the
    translated chunk must contain exactly as many blank-line-separated
    segments as the source chunk.

    :param sentence: Source chunk (segments separated by blank lines).
    :param translation: Model output for the same chunk.
    :return: True when the segment counts match, False otherwise.
    """
    # Idiom: return the comparison directly instead of if/else True/False.
    # (count+1 == count+1 is equivalent to comparing the raw counts.)
    return sentence.count('\n\n') == translation.count('\n\n')
24
+
25
+ # TODO{david}: prompts selector
26
def prompt_selector(src_lang, tgt_lang, domain):
    """
    Build the system prompt for a translation task.

    :param src_lang: Source language code (e.g. "EN").
    :param tgt_lang: Target language code (e.g. "ZH").
    :param domain: Subject domain of the video (e.g. "General").
    :return: Prompt string for the LLM.
    """
    language_map = {
        "EN": "English",
        "ZH": "Chinese",
    }
    # Robustness: the original raised KeyError for any language outside the
    # map; fall back to the raw code so new language pairs still work.
    src_lang = language_map.get(src_lang, src_lang)
    tgt_lang = language_map.get(tgt_lang, tgt_lang)
    prompt = f"""
    you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
    you will be provided with a segement in {src_lang} parsed by line, where your translation text should keep the original
    meaning and the number of lines.
    """
    return prompt
39
+
40
def translate(srt, script_arr, range_arr, model_name, video_name=None, attempts_count=5, task=None, temp=0.15):
    """
    Translate the given script chunks with the LLM and write each result
    back into the SRT object.

    If the merged-sentence check fails *attempts_count* times for a chunk,
    the chunk is split into single sentences and each is translated
    individually. Any exception from the LLM call triggers a 30s wait and
    a retry of the whole chunk.

    :param srt: Subtitle object; receives results via set_translation().
    :param script_arr: Chunk texts to translate.
    :param range_arr: (start, end) sentence indices for each chunk.
    :param model_name: Name of the translation model to be used.
    :param video_name: The name of the video (passed through for logging).
    :param attempts_count: Retry budget per chunk for unmatched sentence counts.
    :param task: System prompt; a Chinese default is used when None.
    :param temp: Model temperature.
    """
    # BUG FIX: the original tested the builtin `input` (never None); validate
    # the actual argument instead.
    if script_arr is None:
        raise Exception("Warning! No Input have passed to LLM!")
    if task is None:
        task = "你是一个翻译助理,你的任务是翻译视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
    logging.info(f"translation prompt: {task}")
    previous_length = 0
    for sentence, range_ in tqdm(zip(script_arr, range_arr)):
        # update the range based on previous length
        range_ = (range_[0] + previous_length, range_[1] + previous_length)
        # using chatgpt model
        print(f"now translating sentences {range_}")
        logging.info(f"now translating sentences {range_}")
        flag = True
        while flag:
            flag = False
            try:
                translation = LLM_task(model_name, sentence, task, temp)
                # Detect merge-sentence issue and retry a bounded number of times.
                # BUG FIX: use a per-chunk retry budget — the original consumed
                # attempts_count globally, so later chunks could get 0 retries.
                remaining = attempts_count
                while not check_translation(sentence, translation) and remaining > 0:
                    translation = LLM_task(model_name, sentence, task, temp)
                    remaining -= 1

                # If failure persists, translate sentence-by-sentence.
                if remaining == 0:
                    # BUG FIX: lazy %-args — the original passed range_ as a
                    # positional logging arg with no %s in the message.
                    logging.info("merge sentence issue found for range %s", range_)
                    # BUG FIX: the original re-translated the entire chunk for
                    # every single sentence instead of the sentence itself.
                    translation = "\n\n".join(
                        LLM_task(model_name, single_sentence, task, temp)
                        for single_sentence in sentence.split("\n\n"))
                    logging.info("solved by individually translation!")

            except Exception as e:
                # BUG FIX: same lazy-%s issue as above for the exception arg.
                logging.debug("An error has occurred during translation: %s", e)
                print("An error has occurred during translation:", e)
                print("Retrying... the script will continue after 30 seconds.")
                sleep(30)
                flag = True

        srt.set_translation(translation, range_, model_name, video_name)
src/web/api_specs.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openapi: 3.0.3
2
+ info:
3
+ title: Pigeon AI
4
+ description: Pigeon AI
5
+ version: 1.0.0
6
+ servers:
7
+ - url: 'https'
8
+ paths:
9
+ /api/task:
10
+ post:
11
+ summary: Create a task
12
+ operationId: createTask
13
+ requestBody:
14
+ content:
15
+ application/json:
16
+ schema:
17
+ $ref: '#/components/schemas/youtubeLink'
18
+ responses:
19
+ '200':
20
+ description: OK
21
+ content:
22
+ application/json:
23
+ schema:
24
+ $ref: '#/components/schemas/task'
25
+ /api/task/{taskId}/status:
26
+ get:
27
+ summary: Get task status
28
+ operationId: getTask
29
+ parameters:
30
+ - name: taskId
31
+ in: path
32
+ required: true
33
+ description: task id
34
+ schema:
35
+ type: string
36
+ responses:
37
+ '200':
38
+ description: OK
39
+ content:
40
+ application/json:
41
+ schema:
42
+ $ref: '#/components/schemas/taskStatus'
43
+ '404':
44
+ description: Not Found
45
+ content:
46
+ application/json:
47
+ schema:
48
+ $ref: '#/components/schemas/error'
49
+
50
+ components:
51
+ schemas:
52
+ youtubeLink:
53
+ type: object
54
+ properties:
55
+ youtubeLink:
56
+ type: string
57
+ description: youtube link
58
+ example: https://www.youtube.com/watch?v=5qap5aO4i9A
59
+ task:
60
+ type: object
61
+ properties:
62
+ taskId:
63
+ type: string
64
+ description: task id generated by uuid
65
+ example: 7a765280-1a72-47e4-8747-8a38cdbaca91
66
+ taskStatus:
67
+ type: object
68
+ properties:
69
+ status:
70
+ type: string
71
+ description: task status
72
+ example: PROCESSING
73
+ error:
74
+ type: object
75
+ properties:
76
+ error:
77
+ type: string
78
+ description: error message
79
+ example: 'Invalid youtube link'
src/web/web.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from flask import Flask, request, jsonify
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from src.task import Task
5
+ from uuid import uuid4
6
+
7
+ app = Flask(__name__)
8
+
9
+ # Global thread pool
10
+ executor = ThreadPoolExecutor(max_workers=4) # Adjust max_workers as per your requirement
11
+
12
+ # thread safe task map to store task status
13
+ task_map = {}
14
+
15
@app.route('/api/task', methods=['POST'])
def create_task_youtube():
    """Create a YouTube translation task and schedule it on the thread pool."""
    global task_map
    data = request.get_json()
    if not data or 'youtubeLink' not in data:
        return jsonify({'error': 'YouTube link not provided'}), 400
    youtube_link = data['youtubeLink']
    from pathlib import Path  # local import: keeps the module import block untouched
    launch_config = yaml.load(open("./configs/local_launch.yaml"), Loader=yaml.Loader)
    # BUG FIX: Task.fromYoutubeLink(youtube_url, task_id, task_dir, task_cfg)
    # takes four arguments; the original passed launch_config in the task_dir
    # slot and omitted the per-task configuration entirely.
    task_cfg = yaml.load(open("./configs/task_config.yaml"), Loader=yaml.Loader)
    task_id = str(uuid4())
    task_dir = Path(launch_config['local_dump']) / f"task_{task_id}"
    task_dir.mkdir(parents=True, exist_ok=True)
    task = Task.fromYoutubeLink(youtube_link, task_id, task_dir, task_cfg)
    task_map[task_id] = task
    # Submit task to thread pool
    executor.submit(task.run)

    return jsonify({'taskId': task.task_id})
30
+
31
@app.route('/api/task/<taskId>/status', methods=['GET'])
def get_task_status(taskId):
    """Return the current pipeline status for a previously created task."""
    global task_map
    task = task_map.get(taskId)
    if task is None:
        return jsonify({'error': 'Task not found'}), 404
    return jsonify({'status': task.status})
37
+
38
# Run the Flask development server when this module is executed directly.
# debug=True is for local development only; use a WSGI server in production.
if __name__ == '__main__':
    app.run(debug=True)
tests/test_remove_punc.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('./src')
3
+ from srt_util.srt import SrtScript, SrtSegment
4
+
5
+ zh_test1 = "再次,如果你对一些福利感兴趣,你也可以。"
6
+ zh_en_test1 = "GG。Classic在我今年解说的最奇葩的系列赛中获得了胜利。"
7
+
8
def form_srt_class(src_lang, tgt_lang, source_text="", translation="", duration="00:00:00,740 --> 00:00:08,779"):
    """Build a one-segment SrtScript fixture for the punctuation tests."""
    seg = [0, duration, source_text, translation, ""]
    return SrtScript(src_lang, tgt_lang, [seg])
11
+
12
def test_zh():
    # Chinese punctuation in the translation should be replaced by spaces.
    script = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_test1)
    script.remove_trans_punctuation()
    assert script.segments[0].translation == "再次 如果你对一些福利感兴趣 你也可以 "
16
+
17
def test_zh_en():
    # Mixed Chinese/English text: punctuation stripped, ASCII words kept intact.
    script = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_en_test1)
    script.remove_trans_punctuation()
    assert script.segments[0].translation == "GG Classic在我今年解说的最奇葩的系列赛中获得了胜利 "
21
+