Spaces:
Sleeping
Sleeping
Merge pull request #67 from project-kxkg/oop-refactor
Browse filesRelease V1: oop refactor
Former-commit-id: ff9e71f58ad50e037e546b0ff2371a136eef26d8
- .gitignore +5 -1
- configs/local_launch.yaml +4 -0
- configs/task_config.yaml +35 -0
- dict_util.py +24 -1
- domain_dict/SC2/EN.csv +43 -0
- domain_dict/SC2/ZH.csv +43 -0
- entries/__init_lib_path.py +10 -0
- entries/app.py +90 -0
- entries/run.py +90 -0
- entries/web_backend.py +0 -0
- requirement.txt +3 -1
- src/Pigeon.py +1 -1
- src/preprocess/audio_extract.py +13 -0
- src/preprocess/video_download.py +20 -0
- src/srt_util/srt.py +146 -46
- src/task.py +320 -0
- src/translators/LLM_task.py +26 -0
- src/translators/__init__.py +0 -0
- src/translators/translation.py +99 -0
- src/web/api_specs.yaml +79 -0
- src/web/web.py +39 -0
- tests/test_remove_punc.py +21 -0
.gitignore
CHANGED
@@ -10,4 +10,8 @@ test.py
|
|
10 |
test.srt
|
11 |
test.txt
|
12 |
log_*.csv
|
13 |
-
log.csv
|
|
|
|
|
|
|
|
|
|
10 |
test.srt
|
11 |
test.txt
|
12 |
log_*.csv
|
13 |
+
log.csv
|
14 |
+
.chroma
|
15 |
+
*.ini
|
16 |
+
local_dump/
|
17 |
+
.pytest_cache/
|
configs/local_launch.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# launch config for local environment
|
2 |
+
local_dump: ./local_dump
|
3 |
+
# dictionary_path: ./domain_dict
|
4 |
+
environ: local
|
configs/task_config.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# configuration for each task
|
2 |
+
source_lang: EN
|
3 |
+
target_lang: ZH
|
4 |
+
field: General
|
5 |
+
|
6 |
+
# ASR config
|
7 |
+
ASR:
|
8 |
+
ASR_model: whisper
|
9 |
+
whisper_config:
|
10 |
+
whisper_model: tiny
|
11 |
+
method: stable
|
12 |
+
|
13 |
+
# pre-process module config
|
14 |
+
pre_process:
|
15 |
+
sentence_form: True
|
16 |
+
spell_check: False
|
17 |
+
term_correct: True
|
18 |
+
|
19 |
+
# Translation module config
|
20 |
+
translation:
|
21 |
+
model: gpt-4
|
22 |
+
chunk_size: 1000
|
23 |
+
|
24 |
+
# post-process module config
|
25 |
+
post_process:
|
26 |
+
check_len_and_split: True
|
27 |
+
remove_trans_punctuation: True
|
28 |
+
|
29 |
+
# output type that user receive
|
30 |
+
output_type:
|
31 |
+
subtitle: srt
|
32 |
+
video: True
|
33 |
+
bilingual: True
|
34 |
+
|
35 |
+
|
dict_util.py
CHANGED
@@ -52,4 +52,27 @@ with open("../test.csv", "w", encoding='utf-8') as w:
|
|
52 |
export_csv_dict(term_dict_sc2,w)
|
53 |
|
54 |
## for load pickle, just:
|
55 |
-
# pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
export_csv_dict(term_dict_sc2,w)
|
53 |
|
54 |
## for load pickle, just:
|
55 |
+
# pickle.load(f)
|
56 |
+
|
57 |
+
|
58 |
+
def form_dict(src_dict:list, tgt_dict:list) -> dict:
|
59 |
+
final_dict = {}
|
60 |
+
for idx, value in enumerate(src_dict):
|
61 |
+
for item in value:
|
62 |
+
final_dict.update({item:tgt_dict[idx]})
|
63 |
+
return final_dict
|
64 |
+
|
65 |
+
|
66 |
+
class term_dict(dict):
|
67 |
+
def __init__(self, path, src_lang, tgt_lang) -> None:
|
68 |
+
with open(f"{path}/{src_lang}.csv", 'r', encoding="utf-8") as file:
|
69 |
+
src_dict = list(csv.reader(file, delimiter=","))
|
70 |
+
with open(f"{path}/{tgt_lang}.csv", 'r', encoding="utf-8") as file:
|
71 |
+
tgt_dict = list(csv.reader(file, delimiter="," ))
|
72 |
+
super().__init__(form_dict(src_dict, tgt_dict))
|
73 |
+
|
74 |
+
|
75 |
+
def get(self, key:str) -> str:
|
76 |
+
word = self[key][randint(0,len(self[key])-1)]
|
77 |
+
return word
|
78 |
+
|
domain_dict/SC2/EN.csv
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
barracks
|
2 |
+
zerg
|
3 |
+
protoss
|
4 |
+
terran
|
5 |
+
engineering bay,engin bay
|
6 |
+
forge
|
7 |
+
blink
|
8 |
+
evolution chamber
|
9 |
+
cybernetics core,cybercore
|
10 |
+
enhanced shockwaves
|
11 |
+
gravitic boosters
|
12 |
+
armory
|
13 |
+
robotics bay,robo bay
|
14 |
+
twilight council,twilight
|
15 |
+
fusion core
|
16 |
+
fleet beacon
|
17 |
+
factory
|
18 |
+
ghost academy
|
19 |
+
infestation pit
|
20 |
+
robotics facility,robo
|
21 |
+
stargate
|
22 |
+
starport
|
23 |
+
archon
|
24 |
+
smart servos
|
25 |
+
gateway
|
26 |
+
warpgate
|
27 |
+
immortal
|
28 |
+
zealot
|
29 |
+
nydus network
|
30 |
+
nydus worm
|
31 |
+
hydralisk,hydra
|
32 |
+
grooved spines
|
33 |
+
muscular augments
|
34 |
+
hydralisk den,hydra den
|
35 |
+
planetary fortress
|
36 |
+
battle cruiser
|
37 |
+
weapon refit
|
38 |
+
brood lord
|
39 |
+
broodling
|
40 |
+
greater spire
|
41 |
+
anabolic synthesis
|
42 |
+
cyclone
|
43 |
+
bunker
|
domain_dict/SC2/ZH.csv
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
兵营
|
2 |
+
虫族
|
3 |
+
神族
|
4 |
+
人族
|
5 |
+
工程站,BE
|
6 |
+
BF,锻炉
|
7 |
+
闪现
|
8 |
+
进化腔
|
9 |
+
BY,赛博核心
|
10 |
+
EMP范围
|
11 |
+
ob速度
|
12 |
+
军械库
|
13 |
+
机械研究所,VB
|
14 |
+
光影议会,VC
|
15 |
+
聚变芯体
|
16 |
+
舰队航标
|
17 |
+
重工厂
|
18 |
+
幽灵军校
|
19 |
+
感染深渊
|
20 |
+
VR,机械台
|
21 |
+
神族VS,星门
|
22 |
+
星港,人族VS
|
23 |
+
白球
|
24 |
+
变形加速
|
25 |
+
传送门
|
26 |
+
折跃门
|
27 |
+
不朽
|
28 |
+
叉叉
|
29 |
+
虫道网络
|
30 |
+
坑道虫
|
31 |
+
刺蛇
|
32 |
+
刺蛇射程
|
33 |
+
刺蛇速度
|
34 |
+
刺蛇塔
|
35 |
+
大地堡,行星要塞
|
36 |
+
大和
|
37 |
+
大和炮
|
38 |
+
大龙
|
39 |
+
巢虫
|
40 |
+
大龙塔
|
41 |
+
大牛速度
|
42 |
+
导弹车
|
43 |
+
地堡
|
entries/__init_lib_path.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
def add_path(custom_path):
|
5 |
+
if custom_path not in sys.path: sys.path.insert(0, custom_path)
|
6 |
+
|
7 |
+
this_dir = os.path.dirname(__file__)
|
8 |
+
|
9 |
+
lib_path = os.path.join(this_dir, '..')
|
10 |
+
add_path(lib_path)
|
entries/app.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import __init_lib_path
|
2 |
+
import gradio as gr
|
3 |
+
from src.task import Task
|
4 |
+
import logging
|
5 |
+
from yaml import Loader, Dumper, load, dump
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
from datetime import datetime
|
9 |
+
import shutil
|
10 |
+
from uuid import uuid4
|
11 |
+
|
12 |
+
launch_config = "./configs/local_launch.yaml"
|
13 |
+
task_config = './configs/task_config.yaml'
|
14 |
+
|
15 |
+
def init(output_type, src_lang, tgt_lang, domain):
|
16 |
+
launch_cfg = load(open(launch_config), Loader=Loader)
|
17 |
+
task_cfg = load(open(task_config), Loader=Loader)
|
18 |
+
|
19 |
+
# overwrite config file
|
20 |
+
task_cfg["source_lang"] = src_lang
|
21 |
+
task_cfg["target_lang"] = tgt_lang
|
22 |
+
task_cfg["field"] = domain
|
23 |
+
|
24 |
+
if "Video File" in output_type:
|
25 |
+
task_cfg["output_type"]["video"] = True
|
26 |
+
else:
|
27 |
+
task_cfg["output_type"]["video"] = False
|
28 |
+
|
29 |
+
if "Bilingual" in output_type:
|
30 |
+
task_cfg["output_type"]["bilingual"] = True
|
31 |
+
else:
|
32 |
+
task_cfg["output_type"]["bilingual"] = False
|
33 |
+
|
34 |
+
if ".ass output" in output_type:
|
35 |
+
task_cfg["output_type"]["subtitle"] = "ass"
|
36 |
+
else:
|
37 |
+
task_cfg["output_type"]["subtitle"] = "srt"
|
38 |
+
|
39 |
+
# initialize dir
|
40 |
+
local_dir = Path(launch_cfg['local_dump'])
|
41 |
+
if not local_dir.exists():
|
42 |
+
local_dir.mkdir(parents=False, exist_ok=False)
|
43 |
+
|
44 |
+
# get task id
|
45 |
+
task_id = str(uuid4())
|
46 |
+
|
47 |
+
# create locak dir for the task
|
48 |
+
task_dir = local_dir.joinpath(f"task_{task_id}")
|
49 |
+
task_dir.mkdir(parents=False, exist_ok=False)
|
50 |
+
task_dir.joinpath("results").mkdir(parents=False, exist_ok=False)
|
51 |
+
|
52 |
+
# logging setting
|
53 |
+
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
54 |
+
logging.basicConfig(level=logging.INFO, format=logfmt, handlers=[
|
55 |
+
logging.FileHandler(
|
56 |
+
"{}/{}_{}.log".format(task_dir, f"task_{task_id}", datetime.now().strftime("%m%d%Y_%H%M%S")),
|
57 |
+
'w', encoding='utf-8')])
|
58 |
+
return task_id, task_dir, task_cfg
|
59 |
+
|
60 |
+
def process_input(video_file, youtube_link, src_lang, tgt_lang, domain, output_type):
|
61 |
+
task_id, task_dir, task_cfg = init(output_type, src_lang, tgt_lang, domain)
|
62 |
+
if youtube_link:
|
63 |
+
task = Task.fromYoutubeLink(youtube_link, task_id, task_dir, task_cfg)
|
64 |
+
task.run()
|
65 |
+
return task.result
|
66 |
+
elif video_file is not None:
|
67 |
+
task = Task.fromVideoFile(video_file, task_id, task_dir, task_cfg)
|
68 |
+
task.run()
|
69 |
+
return task.result
|
70 |
+
else:
|
71 |
+
return None
|
72 |
+
|
73 |
+
demo = gr.Interface(fn=process_input,
|
74 |
+
inputs=[
|
75 |
+
gr.components.Video(label="Upload a video"),
|
76 |
+
gr.components.Textbox(label="Or enter a YouTube URL"),
|
77 |
+
gr.components.Dropdown(choices=["EN", "ZH"], label="Select Source Language"),
|
78 |
+
gr.components.Dropdown(choices=["ZH", "EN"], label="Select Target Language"),
|
79 |
+
gr.components.Dropdown(choices=["General", "SC2"], label="Select Domain"),
|
80 |
+
gr.CheckboxGroup(["Video File", "Bilingual", ".ass output"], label="Output Settings", info="What do you want?"),
|
81 |
+
],
|
82 |
+
outputs=[
|
83 |
+
gr.components.Video(label="Processed Video")
|
84 |
+
],
|
85 |
+
title="ViDove: video translation toolkit demo",
|
86 |
+
description="Upload a video or enter a YouTube URL."
|
87 |
+
)
|
88 |
+
|
89 |
+
if __name__ == "__main__":
|
90 |
+
demo.launch()
|
entries/run.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import __init_lib_path
|
2 |
+
import logging
|
3 |
+
from yaml import Loader, Dumper, load, dump
|
4 |
+
from src.task import Task
|
5 |
+
import openai
|
6 |
+
import argparse
|
7 |
+
import os
|
8 |
+
from pathlib import Path
|
9 |
+
from datetime import datetime
|
10 |
+
import shutil
|
11 |
+
from uuid import uuid4
|
12 |
+
|
13 |
+
"""
|
14 |
+
Main entry for terminal environment.
|
15 |
+
Use it for debug and development purpose.
|
16 |
+
Usage: python3 entries/run.py [-h] [--link LINK] [--video_file VIDEO_FILE] [--audio_file AUDIO_FILE] [--srt_file SRT_FILE] [--continue CONTINUE]
|
17 |
+
[--launch_cfg LAUNCH_CFG] [--task_cfg TASK_CFG]
|
18 |
+
"""
|
19 |
+
|
20 |
+
def parse_args():
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument("--link", help="youtube video link here", default=None, type=str, required=False)
|
23 |
+
parser.add_argument("--video_file", help="local video path here", default=None, type=str, required=False)
|
24 |
+
parser.add_argument("--audio_file", help="local audio path here", default=None, type=str, required=False)
|
25 |
+
parser.add_argument("--srt_file", help="srt file input path here", default=None, type=str, required=False)
|
26 |
+
parser.add_argument("--continue", help="task_id that need to continue", default=None, type=str, required=False) # need implement
|
27 |
+
parser.add_argument("--launch_cfg", help="launch config path", default='./configs/local_launch.yaml', type=str, required=False)
|
28 |
+
parser.add_argument("--task_cfg", help="task config path", default='./configs/task_config.yaml', type=str, required=False)
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
return args
|
32 |
+
|
33 |
+
if __name__ == "__main__":
|
34 |
+
# read args and configs
|
35 |
+
args = parse_args()
|
36 |
+
launch_cfg = load(open(args.launch_cfg), Loader=Loader)
|
37 |
+
task_cfg = load(open(args.task_cfg), Loader=Loader)
|
38 |
+
|
39 |
+
# initialize dir
|
40 |
+
local_dir = Path(launch_cfg['local_dump'])
|
41 |
+
if not local_dir.exists():
|
42 |
+
local_dir.mkdir(parents=False, exist_ok=False)
|
43 |
+
|
44 |
+
# get task id
|
45 |
+
task_id = str(uuid4())
|
46 |
+
|
47 |
+
# create locak dir for the task
|
48 |
+
task_dir = local_dir.joinpath(f"task_{task_id}")
|
49 |
+
task_dir.mkdir(parents=False, exist_ok=False)
|
50 |
+
task_dir.joinpath("results").mkdir(parents=False, exist_ok=False)
|
51 |
+
|
52 |
+
# logging setting
|
53 |
+
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
54 |
+
logging.basicConfig(level=logging.INFO, format=logfmt, handlers=[
|
55 |
+
logging.FileHandler(
|
56 |
+
"{}/{}_{}.log".format(task_dir, f"task_{task_id}", datetime.now().strftime("%m%d%Y_%H%M%S")),
|
57 |
+
'w', encoding='utf-8')])
|
58 |
+
|
59 |
+
# Task create
|
60 |
+
if args.link is not None:
|
61 |
+
try:
|
62 |
+
task = Task.fromYoutubeLink(args.link, task_id, task_dir, task_cfg)
|
63 |
+
except:
|
64 |
+
shutil.rmtree(task_dir)
|
65 |
+
raise RuntimeError("failed to create task from youtube link")
|
66 |
+
elif args.video_file is not None:
|
67 |
+
try:
|
68 |
+
task = Task.fromVideoFile(args.video_file, task_id, task_dir, task_cfg)
|
69 |
+
except:
|
70 |
+
shutil.rmtree(task_dir)
|
71 |
+
raise RuntimeError("failed to create task from youtube link")
|
72 |
+
elif args.audio_file is not None:
|
73 |
+
try:
|
74 |
+
task = Task.fromVideoFile(args.audio_file, task_id, task_dir, task_cfg)
|
75 |
+
except:
|
76 |
+
shutil.rmtree(task_dir)
|
77 |
+
raise RuntimeError("failed to create task from youtube link")
|
78 |
+
|
79 |
+
# add task to the status queue
|
80 |
+
task.run()
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
entries/web_backend.py
ADDED
File without changes
|
requirement.txt
CHANGED
@@ -5,6 +5,7 @@ attrs==22.2.0
|
|
5 |
certifi==2022.12.7
|
6 |
charset-normalizer==3.1.0
|
7 |
ffmpeg-python==0.2.0
|
|
|
8 |
filelock==3.10.0
|
9 |
frozenlist==1.3.3
|
10 |
future==0.18.3
|
@@ -23,12 +24,13 @@ openai-whisper @ git+https://github.com/openai/whisper.git@6dea21fd7f7253bfe450f
|
|
23 |
panda==0.3.1
|
24 |
pandas==1.5.3
|
25 |
python-dateutil==2.8.2
|
26 |
-
pytube==
|
27 |
pytube3==9.6.4
|
28 |
pytz==2022.7.1
|
29 |
regex==2022.10.31
|
30 |
requests==2.28.2
|
31 |
six==1.16.0
|
|
|
32 |
sympy==1.11.1
|
33 |
tiktoken==0.3.1
|
34 |
torch==2.0.0
|
|
|
5 |
certifi==2022.12.7
|
6 |
charset-normalizer==3.1.0
|
7 |
ffmpeg-python==0.2.0
|
8 |
+
Flask==2.3.3
|
9 |
filelock==3.10.0
|
10 |
frozenlist==1.3.3
|
11 |
future==0.18.3
|
|
|
24 |
panda==0.3.1
|
25 |
pandas==1.5.3
|
26 |
python-dateutil==2.8.2
|
27 |
+
pytube==15.0.0
|
28 |
pytube3==9.6.4
|
29 |
pytz==2022.7.1
|
30 |
regex==2022.10.31
|
31 |
requests==2.28.2
|
32 |
six==1.16.0
|
33 |
+
stable-ts==2.9.0
|
34 |
sympy==1.11.1
|
35 |
tiktoken==0.3.1
|
36 |
torch==2.0.0
|
src/Pigeon.py
CHANGED
@@ -317,7 +317,7 @@ class Pigeon(object):
|
|
317 |
logging.info("--------------------Start Preprocessing SRT class--------------------")
|
318 |
self.srt.write_srt_file_src(self.srt_path)
|
319 |
self.srt.form_whole_sentence()
|
320 |
-
self.srt.spell_check_term()
|
321 |
self.srt.correct_with_force_term()
|
322 |
processed_srt_file_en = str(Path(self.srt_path).with_suffix('')) + '_processed.srt'
|
323 |
self.srt.write_srt_file_src(processed_srt_file_en)
|
|
|
317 |
logging.info("--------------------Start Preprocessing SRT class--------------------")
|
318 |
self.srt.write_srt_file_src(self.srt_path)
|
319 |
self.srt.form_whole_sentence()
|
320 |
+
# self.srt.spell_check_term()
|
321 |
self.srt.correct_with_force_term()
|
322 |
processed_srt_file_en = str(Path(self.srt_path).with_suffix('')) + '_processed.srt'
|
323 |
self.srt.write_srt_file_src(processed_srt_file_en)
|
src/preprocess/audio_extract.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
|
5 |
+
|
6 |
+
def extract_audio(local_video_path: str, save_dir_path: str = "./downloads/audio") -> str:
|
7 |
+
if os.name == 'nt':
|
8 |
+
NotImplementedError("Filename extraction on Windows not yet implemented")
|
9 |
+
|
10 |
+
out_file_name = os.path.basename(local_video_path)
|
11 |
+
audio_path_out = save_dir_path.join("/").join(out_file_name)
|
12 |
+
subprocess.run(['ffmpeg', '-i', local_video_path, '-f', 'mp3', '-ab', '192000', '-vn', audio_path_out])
|
13 |
+
return audio_path_out
|
src/preprocess/video_download.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pytube import YouTube
|
2 |
+
import logging
|
3 |
+
|
4 |
+
def download_youtube_to_local_file(youtube_url: str, local_dir_path: str = "./downloads") -> str:
|
5 |
+
yt = YouTube(youtube_url)
|
6 |
+
try:
|
7 |
+
audio = yt.streams.filter(only_audio=True, file_extension='mp4').order_by('abr').desc().first()
|
8 |
+
# video = yt.streams.filter(file_extension='mp4').order_by('resolution').asc().first()
|
9 |
+
if audio:
|
10 |
+
saved_audio = audio.download(output_path=local_dir_path.join("/audio"))
|
11 |
+
logging.info(f"Audio download successful: {saved_audio}")
|
12 |
+
return saved_audio
|
13 |
+
else:
|
14 |
+
logging.error(f"Audio stream not found in {youtube_url}")
|
15 |
+
raise f"Audio stream not found in {youtube_url}"
|
16 |
+
except Exception as e:
|
17 |
+
# print("Connection Error: ", end='')
|
18 |
+
print(e)
|
19 |
+
raise e
|
20 |
+
|
src/srt_util/srt.py
CHANGED
@@ -7,10 +7,59 @@ from datetime import timedelta
|
|
7 |
import logging
|
8 |
import openai
|
9 |
from tqdm import tqdm
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
class SrtSegment(object):
|
13 |
-
def __init__(self, *args) -> None:
|
|
|
|
|
|
|
14 |
if isinstance(args[0], dict):
|
15 |
segment = args[0]
|
16 |
self.start = segment['start']
|
@@ -54,6 +103,7 @@ class SrtSegment(object):
|
|
54 |
self.translation = ""
|
55 |
else:
|
56 |
self.translation = args[0][3]
|
|
|
57 |
|
58 |
def merge_seg(self, seg):
|
59 |
"""
|
@@ -83,12 +133,14 @@ class SrtSegment(object):
|
|
83 |
|
84 |
def remove_trans_punc(self) -> None:
|
85 |
"""
|
86 |
-
remove
|
87 |
:return: None
|
88 |
"""
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
92 |
|
93 |
def __str__(self) -> str:
|
94 |
return f'{self.duration}\n{self.source_text}\n\n'
|
@@ -101,11 +153,25 @@ class SrtSegment(object):
|
|
101 |
|
102 |
|
103 |
class SrtScript(object):
|
104 |
-
def __init__(self, segments) -> None:
|
105 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
@classmethod
|
108 |
-
def parse_from_srt_file(cls, path: str):
|
109 |
with open(path, 'r', encoding="utf-8") as f:
|
110 |
script_lines = [line.rstrip() for line in f.readlines()]
|
111 |
bilingual = False
|
@@ -119,7 +185,7 @@ class SrtScript(object):
|
|
119 |
for i in range(0, len(script_lines), 4):
|
120 |
segments.append(list(script_lines[i:i + 4]))
|
121 |
|
122 |
-
return cls(segments)
|
123 |
|
124 |
def merge_segs(self, idx_list) -> SrtSegment:
|
125 |
"""
|
@@ -147,9 +213,10 @@ class SrtScript(object):
|
|
147 |
logging.info("Forming whole sentences...")
|
148 |
merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
|
149 |
sentence = []
|
|
|
150 |
# Get each entire sentence of distinct segments, fill indices to merge_list
|
151 |
for i, seg in enumerate(self.segments):
|
152 |
-
if seg.source_text[-1] in
|
153 |
sentence.append(i)
|
154 |
merge_list.append(sentence)
|
155 |
sentence = []
|
@@ -184,19 +251,20 @@ class SrtScript(object):
|
|
184 |
src_text += '\n\n'
|
185 |
|
186 |
def inner_func(target, input_str):
|
|
|
187 |
response = openai.ChatCompletion.create(
|
188 |
-
# model=model,
|
189 |
model="gpt-4",
|
190 |
messages=[
|
191 |
{"role": "system",
|
192 |
-
"content": "
|
193 |
-
{"role": "system", "content": "
|
194 |
-
{"role": "user", "content": '
|
195 |
],
|
196 |
temperature=0.15
|
197 |
)
|
198 |
return response['choices'][0]['message']['content'].strip()
|
199 |
|
|
|
200 |
lines = translate.split('\n\n')
|
201 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
202 |
count = 0
|
@@ -204,28 +272,27 @@ class SrtScript(object):
|
|
204 |
while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
|
205 |
count += 1
|
206 |
print("Solving Unmatched Lines|iteration {}".format(count))
|
|
|
207 |
|
208 |
flag = True
|
209 |
while flag:
|
210 |
flag = False
|
211 |
-
# print("translate:")
|
212 |
-
# print(translate)
|
213 |
try:
|
214 |
-
# print("target")
|
215 |
-
# print(end_seg_id - start_seg_id + 1)
|
216 |
translate = inner_func(end_seg_id - start_seg_id + 1, translate)
|
217 |
except Exception as e:
|
218 |
print("An error has occurred during solving unmatched lines:", e)
|
219 |
print("Retrying...")
|
|
|
|
|
220 |
flag = True
|
221 |
lines = translate.split('\n')
|
222 |
-
# print("result")
|
223 |
-
# print(len(lines))
|
224 |
|
225 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
226 |
solved = False
|
227 |
print("Failed Solving unmatched lines, Manually parse needed")
|
|
|
228 |
|
|
|
229 |
if not os.path.exists("./logs"):
|
230 |
os.mkdir("./logs")
|
231 |
if video_link:
|
@@ -244,7 +311,7 @@ class SrtScript(object):
|
|
244 |
log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
|
245 |
log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
|
246 |
len(self.segments)) + ',' + video_name + "\n")
|
247 |
-
print(lines)
|
248 |
|
249 |
for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
|
250 |
# naive way to due with merge translation problem
|
@@ -262,24 +329,27 @@ class SrtScript(object):
|
|
262 |
|
263 |
def split_seg(self, seg, text_threshold, time_threshold):
|
264 |
# evenly split seg to 2 parts and add new seg into self.segments
|
265 |
-
|
266 |
# ignore the initial comma to solve the recursion problem
|
|
|
|
|
|
|
267 |
if len(seg.source_text) > 2:
|
268 |
-
if seg.source_text[:2] ==
|
269 |
seg.source_text = seg.source_text[2:]
|
270 |
-
if seg.translation[0] ==
|
271 |
seg.translation = seg.translation[1:]
|
272 |
|
273 |
source_text = seg.source_text
|
274 |
translation = seg.translation
|
275 |
|
276 |
# split the text based on commas
|
277 |
-
src_commas = [m.start() for m in re.finditer(
|
278 |
-
trans_commas = [m.start() for m in re.finditer(
|
279 |
if len(src_commas) != 0:
|
280 |
src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
|
281 |
len(src_commas) // 2 - 1]
|
282 |
else:
|
|
|
283 |
src_space = [m.start() for m in re.finditer(' ', source_text)]
|
284 |
if len(src_space) > 0:
|
285 |
src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
|
@@ -315,14 +385,14 @@ class SrtScript(object):
|
|
315 |
seg1_dict['text'] = src_seg1
|
316 |
seg1_dict['start'] = start_seg1
|
317 |
seg1_dict['end'] = end_seg1
|
318 |
-
seg1 = SrtSegment(seg1_dict)
|
319 |
seg1.translation = trans_seg1
|
320 |
|
321 |
seg2_dict = {}
|
322 |
seg2_dict['text'] = src_seg2
|
323 |
seg2_dict['start'] = start_seg2
|
324 |
seg2_dict['end'] = end_seg2
|
325 |
-
seg2 = SrtSegment(seg2_dict)
|
326 |
seg2.translation = trans_seg2
|
327 |
|
328 |
result_list = []
|
@@ -353,8 +423,6 @@ class SrtScript(object):
|
|
353 |
self.segments = segments
|
354 |
logging.info("check_len_and_split finished")
|
355 |
|
356 |
-
pass
|
357 |
-
|
358 |
def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
|
359 |
# DEPRECATED
|
360 |
# if sentence length >= text_threshold, split this segments to two
|
@@ -376,22 +444,24 @@ class SrtScript(object):
|
|
376 |
def correct_with_force_term(self):
|
377 |
## force term correction
|
378 |
logging.info("performing force term correction")
|
379 |
-
# load term dictionary
|
380 |
-
with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
|
381 |
-
term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
|
382 |
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
|
|
|
|
|
|
|
|
|
|
395 |
|
396 |
comp_dict = []
|
397 |
|
@@ -425,6 +495,12 @@ class SrtScript(object):
|
|
425 |
|
426 |
def spell_check_term(self):
|
427 |
logging.info("performing spell check")
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
import enchant
|
429 |
dict = enchant.Dict('en_US')
|
430 |
term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
|
@@ -532,3 +608,27 @@ class SrtScript(object):
|
|
532 |
f.write(f'{i + idx}\n')
|
533 |
f.write(seg.get_bilingual_str())
|
534 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import logging
|
8 |
import openai
|
9 |
from tqdm import tqdm
|
10 |
+
import dict_util
|
11 |
+
|
12 |
+
# punctuation dictionary for supported languages
|
13 |
+
punctuation_dict = {
|
14 |
+
"EN": {
|
15 |
+
"punc_str": ". , ? ! : ; - ( ) [ ] { }",
|
16 |
+
"comma": ", ",
|
17 |
+
"sentence_end": [".", "!", "?", ";"]
|
18 |
+
},
|
19 |
+
"ES": {
|
20 |
+
"punc_str": ". , ? ! : ; - ( ) [ ] { } ¡ ¿",
|
21 |
+
"comma": ", ",
|
22 |
+
"sentence_end": [".", "!", "?", ";", "¡", "¿"]
|
23 |
+
},
|
24 |
+
"FR": {
|
25 |
+
"punc_str": ".,?!:;«»—",
|
26 |
+
"comma": ", ",
|
27 |
+
"sentence_end": [".", "!", "?", ";"]
|
28 |
+
},
|
29 |
+
"DE": {
|
30 |
+
"punc_str": ".,?!:;„“–",
|
31 |
+
"comma": ", ",
|
32 |
+
"sentence_end": [".", "!", "?", ";"]
|
33 |
+
},
|
34 |
+
"RU": {
|
35 |
+
"punc_str": ".,?!:;-«»—",
|
36 |
+
"comma": ", ",
|
37 |
+
"sentence_end": [".", "!", "?", ";"]
|
38 |
+
},
|
39 |
+
"ZH": {
|
40 |
+
"punc_str": "。,?!:;()",
|
41 |
+
"comma": ",",
|
42 |
+
"sentence_end": ["。", "!", "?"]
|
43 |
+
},
|
44 |
+
"JA": {
|
45 |
+
"punc_str": "。、?!:;()",
|
46 |
+
"comma": "、",
|
47 |
+
"sentence_end": ["。", "!", "?"]
|
48 |
+
},
|
49 |
+
"AR": {
|
50 |
+
"punc_str": ".,?!:;-()[]،؛ ؟ «»",
|
51 |
+
"comma": "، ",
|
52 |
+
"sentence_end": [".", "!", "?", ";", "؟"]
|
53 |
+
},
|
54 |
+
}
|
55 |
+
|
56 |
+
dict_path = "./domain_dict"
|
57 |
|
58 |
class SrtSegment(object):
|
59 |
+
def __init__(self, src_lang, tgt_lang, *args) -> None:
|
60 |
+
self.src_lang = src_lang
|
61 |
+
self.tgt_lang = tgt_lang
|
62 |
+
|
63 |
if isinstance(args[0], dict):
|
64 |
segment = args[0]
|
65 |
self.start = segment['start']
|
|
|
103 |
self.translation = ""
|
104 |
else:
|
105 |
self.translation = args[0][3]
|
106 |
+
|
107 |
|
108 |
def merge_seg(self, seg):
|
109 |
"""
|
|
|
133 |
|
134 |
def remove_trans_punc(self) -> None:
|
135 |
"""
|
136 |
+
remove punctuations in translation text
|
137 |
:return: None
|
138 |
"""
|
139 |
+
punc_str = punctuation_dict[self.tgt_lang]["punc_str"]
|
140 |
+
for punc in punc_str:
|
141 |
+
self.translation = self.translation.replace(punc, ' ')
|
142 |
+
# translator = str.maketrans(punc, ' ' * len(punc))
|
143 |
+
# self.translation = self.translation.translate(translator)
|
144 |
|
145 |
def __str__(self) -> str:
|
146 |
return f'{self.duration}\n{self.source_text}\n\n'
|
|
|
153 |
|
154 |
|
155 |
class SrtScript(object):
|
156 |
+
def __init__(self, src_lang, tgt_lang, segments, domain="General") -> None:
|
157 |
+
self.domain = domain
|
158 |
+
self.src_lang = src_lang
|
159 |
+
self.tgt_lang = tgt_lang
|
160 |
+
self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
|
161 |
+
|
162 |
+
if self.domain != "General":
|
163 |
+
if os.path.exists(f"{dict_path}/{self.domain}") and\
|
164 |
+
os.path.exists(f"{dict_path}/{self.domain}/{src_lang}.csv") and os.path.exists(f"{dict_path}/{self.domain}/{tgt_lang}.csv" ):
|
165 |
+
# TODO: load dictionary
|
166 |
+
self.dict = dict_util.term_dict(f"{dict_path}/{self.domain}", src_lang, tgt_lang)
|
167 |
+
...
|
168 |
+
else:
|
169 |
+
logging.error(f"domain {self.domain} or related dictionary({src_lang} or {tgt_lang}) doesn't exist, fallback to general domain, this will disable correct_with_force_term and spell_check_term")
|
170 |
+
self.domain = "General"
|
171 |
+
|
172 |
|
173 |
@classmethod
|
174 |
+
def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
|
175 |
with open(path, 'r', encoding="utf-8") as f:
|
176 |
script_lines = [line.rstrip() for line in f.readlines()]
|
177 |
bilingual = False
|
|
|
185 |
for i in range(0, len(script_lines), 4):
|
186 |
segments.append(list(script_lines[i:i + 4]))
|
187 |
|
188 |
+
return cls(src_lang, tgt_lang, segments)
|
189 |
|
190 |
def merge_segs(self, idx_list) -> SrtSegment:
|
191 |
"""
|
|
|
213 |
logging.info("Forming whole sentences...")
|
214 |
merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
|
215 |
sentence = []
|
216 |
+
ending_puncs = punctuation_dict[self.src_lang]["sentence_end"]
|
217 |
# Get each entire sentence of distinct segments, fill indices to merge_list
|
218 |
for i, seg in enumerate(self.segments):
|
219 |
+
if seg.source_text[-1] in ending_puncs and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
|
220 |
sentence.append(i)
|
221 |
merge_list.append(sentence)
|
222 |
sentence = []
|
|
|
251 |
src_text += '\n\n'
|
252 |
|
253 |
def inner_func(target, input_str):
|
254 |
+
# handling merge sentences issue.
|
255 |
response = openai.ChatCompletion.create(
|
|
|
256 |
model="gpt-4",
|
257 |
messages=[
|
258 |
{"role": "system",
|
259 |
+
"content": "Your task is to merge or split sentences into a specified number of lines as required. You need to ensure the meaning of the sentences as much as possible, but when necessary, a sentence can be divided into two lines for output"},
|
260 |
+
{"role": "system", "content": "Note: You only need to output the processed {} sentences. If you need to output a sequence number, please separate it with a colon.".format(self.tgt_lang)},
|
261 |
+
{"role": "user", "content": 'Please split or combine the following sentences into {} sentences:\n{}'.format(target, input_str)}
|
262 |
],
|
263 |
temperature=0.15
|
264 |
)
|
265 |
return response['choices'][0]['message']['content'].strip()
|
266 |
|
267 |
+
# handling merge sentences issue.
|
268 |
lines = translate.split('\n\n')
|
269 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
270 |
count = 0
|
|
|
272 |
while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
|
273 |
count += 1
|
274 |
print("Solving Unmatched Lines|iteration {}".format(count))
|
275 |
+
logging.error("Solving Unmatched Lines|iteration {}".format(count))
|
276 |
|
277 |
flag = True
|
278 |
while flag:
|
279 |
flag = False
|
|
|
|
|
280 |
try:
|
|
|
|
|
281 |
translate = inner_func(end_seg_id - start_seg_id + 1, translate)
|
282 |
except Exception as e:
|
283 |
print("An error has occurred during solving unmatched lines:", e)
|
284 |
print("Retrying...")
|
285 |
+
logging.error("An error has occurred during solving unmatched lines:", e)
|
286 |
+
logging.error("Retrying...")
|
287 |
flag = True
|
288 |
lines = translate.split('\n')
|
|
|
|
|
289 |
|
290 |
if len(lines) < (end_seg_id - start_seg_id + 1):
|
291 |
solved = False
|
292 |
print("Failed Solving unmatched lines, Manually parse needed")
|
293 |
+
logging.error("Failed Solving unmatched lines, Manually parse needed")
|
294 |
|
295 |
+
# FIXME: put the error log in our log file
|
296 |
if not os.path.exists("./logs"):
|
297 |
os.mkdir("./logs")
|
298 |
if video_link:
|
|
|
311 |
log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
|
312 |
log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
|
313 |
len(self.segments)) + ',' + video_name + "\n")
|
314 |
+
# print(lines)
|
315 |
|
316 |
for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
|
317 |
# naive way to due with merge translation problem
|
|
|
329 |
|
330 |
def split_seg(self, seg, text_threshold, time_threshold):
|
331 |
# evenly split seg to 2 parts and add new seg into self.segments
|
|
|
332 |
# ignore the initial comma to solve the recursion problem
|
333 |
+
src_comma_str = punctuation_dict[self.src_lang]["comma"]
|
334 |
+
tgt_comma_str = punctuation_dict[self.tgt_lang]["comma"]
|
335 |
+
|
336 |
if len(seg.source_text) > 2:
|
337 |
+
if seg.source_text[:2] == src_comma_str:
|
338 |
seg.source_text = seg.source_text[2:]
|
339 |
+
if seg.translation[0] == tgt_comma_str:
|
340 |
seg.translation = seg.translation[1:]
|
341 |
|
342 |
source_text = seg.source_text
|
343 |
translation = seg.translation
|
344 |
|
345 |
# split the text based on commas
|
346 |
+
src_commas = [m.start() for m in re.finditer(src_comma_str, source_text)]
|
347 |
+
trans_commas = [m.start() for m in re.finditer(tgt_comma_str, translation)]
|
348 |
if len(src_commas) != 0:
|
349 |
src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
|
350 |
len(src_commas) // 2 - 1]
|
351 |
else:
|
352 |
+
# split the text based on spaces
|
353 |
src_space = [m.start() for m in re.finditer(' ', source_text)]
|
354 |
if len(src_space) > 0:
|
355 |
src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
|
|
|
385 |
seg1_dict['text'] = src_seg1
|
386 |
seg1_dict['start'] = start_seg1
|
387 |
seg1_dict['end'] = end_seg1
|
388 |
+
seg1 = SrtSegment(self.src_lang, self.tgt_lang, seg1_dict)
|
389 |
seg1.translation = trans_seg1
|
390 |
|
391 |
seg2_dict = {}
|
392 |
seg2_dict['text'] = src_seg2
|
393 |
seg2_dict['start'] = start_seg2
|
394 |
seg2_dict['end'] = end_seg2
|
395 |
+
seg2 = SrtSegment(self.src_lang, self.tgt_lang, seg2_dict)
|
396 |
seg2.translation = trans_seg2
|
397 |
|
398 |
result_list = []
|
|
|
423 |
self.segments = segments
|
424 |
logging.info("check_len_and_split finished")
|
425 |
|
|
|
|
|
426 |
def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
|
427 |
# DEPRECATED
|
428 |
# if sentence length >= text_threshold, split this segments to two
|
|
|
444 |
def correct_with_force_term(self):
|
445 |
## force term correction
|
446 |
logging.info("performing force term correction")
|
|
|
|
|
|
|
447 |
|
448 |
+
# check domain
|
449 |
+
if self.domain == "General":
|
450 |
+
logging.info("General domain could not perform correct_with_force_term. skip this step.")
|
451 |
+
pass
|
452 |
+
else:
|
453 |
+
keywords = list(self.dict.keys())
|
454 |
+
keywords.sort(key=lambda x: len(x), reverse=True)
|
455 |
+
|
456 |
+
for word in keywords:
|
457 |
+
for i, seg in enumerate(self.segments):
|
458 |
+
if word in seg.source_text.lower():
|
459 |
+
seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(self.dict.get(word)),
|
460 |
+
seg.source_text, flags=re.IGNORECASE)
|
461 |
+
logging.info(
|
462 |
+
"replace term: " + word + " --> " + self.dict.get(word) + " in time stamp {}".format(
|
463 |
+
i + 1))
|
464 |
+
logging.info("source text becomes: " + seg.source_text)
|
465 |
|
466 |
comp_dict = []
|
467 |
|
|
|
495 |
|
496 |
def spell_check_term(self):
|
497 |
logging.info("performing spell check")
|
498 |
+
|
499 |
+
# check domain
|
500 |
+
if self.domain == "General":
|
501 |
+
logging.info("General domain could not perform spell_check_term. skip this step.")
|
502 |
+
pass
|
503 |
+
|
504 |
import enchant
|
505 |
dict = enchant.Dict('en_US')
|
506 |
term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
|
|
|
608 |
f.write(f'{i + idx}\n')
|
609 |
f.write(seg.get_bilingual_str())
|
610 |
pass
|
611 |
+
|
612 |
+
def split_script(script_in, chunk_size=1000):
    """Greedily pack the blank-line-separated sentences of *script_in* into
    chunks of at most *chunk_size* characters.

    Returns a pair ``(script_arr, range_arr)``: the chunk texts and, for each
    chunk, a 1-indexed ``(start, end)`` sentence range it covers.
    """
    sentences = script_in.split('\n\n')
    script_arr = []
    range_arr = []
    start = 1
    end = 0
    buffer = ""
    for sentence in sentences:
        if len(buffer) + len(sentence) + 1 <= chunk_size:
            # The sentence still fits into the chunk being assembled.
            buffer += sentence + '\n\n'
            end += 1
        else:
            # Flush the current chunk and start a new one with this sentence.
            range_arr.append((start, end))
            start = end + 1
            end += 1
            script_arr.append(buffer.strip())
            buffer = sentence + '\n\n'
    if buffer.strip():
        # Flush the trailing, partially-filled chunk.
        script_arr.append(buffer.strip())
        range_arr.append((start, len(sentences) - 1))

    assert len(script_arr) == len(range_arr)
    return script_arr, range_arr
|
src/task.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import time
|
3 |
+
|
4 |
+
import openai
|
5 |
+
from pytube import YouTube
|
6 |
+
from os import getenv, getcwd
|
7 |
+
from pathlib import Path
|
8 |
+
from enum import Enum, auto
|
9 |
+
import logging
|
10 |
+
import subprocess
|
11 |
+
from src.srt_util.srt import SrtScript
|
12 |
+
from src.srt_util.srt2ass import srt2ass
|
13 |
+
from time import time, strftime, gmtime, sleep
|
14 |
+
from src.translators.translation import get_translation, prompt_selector
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import stable_whisper
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
"""
|
21 |
+
Youtube link
|
22 |
+
- link
|
23 |
+
- model
|
24 |
+
- output type
|
25 |
+
|
26 |
+
Video file
|
27 |
+
- path
|
28 |
+
- model
|
29 |
+
- output type
|
30 |
+
|
31 |
+
Audio file
|
32 |
+
- path
|
33 |
+
- model
|
34 |
+
- output type
|
35 |
+
|
36 |
+
"""
|
37 |
+
"""
|
38 |
+
TaskID
|
39 |
+
Progress: Enum
|
40 |
+
Computing resrouce status
|
41 |
+
SRT_Script : SrtScript
|
42 |
+
- input module -> initialize (ASR module)
|
43 |
+
- Pre-process
|
44 |
+
- Translation (%)
|
45 |
+
- Post process (time stamp)
|
46 |
+
- Output module: SRT_Script --> output(.srt)
|
47 |
+
- (Optional) mp4
|
48 |
+
"""
|
49 |
+
|
50 |
+
class TaskStatus(str, Enum):
    """Lifecycle states of a Task, listed in pipeline order.

    Inherits from ``str`` so a status serializes directly into JSON
    responses from the web layer.
    """
    CREATED = 'CREATED'
    INITIALIZING_ASR = 'INITIALIZING_ASR'
    PRE_PROCESSING = 'PRE_PROCESSING'
    TRANSLATING = 'TRANSLATING'
    POST_PROCESSING = 'POST_PROCESSING'
    OUTPUT_MODULE = 'OUTPUT_MODULE'
|
57 |
+
|
58 |
+
|
59 |
+
class Task:
    """Base class for one subtitle-translation job.

    Orchestrates the pipeline: ASR -> preprocess -> translation -> postprocess
    -> output rendering.  Subclasses (YoutubeTask / AudioTask / VideoTask)
    prepare ``self.video_path`` / ``self.audio_path`` in their ``run()`` and
    then invoke ``run_pipeline()``.

    ``status`` is guarded by a lock so the web layer can poll it from another
    thread while the pipeline runs in a worker thread.
    """

    @property
    def status(self):
        # Thread-safe read of the current TaskStatus.
        with self.__status_lock:
            return self.__status

    @status.setter
    def status(self, new_status):
        # Thread-safe update of the current TaskStatus.
        with self.__status_lock:
            self.__status = new_status

    def __init__(self, task_id, task_local_dir, task_cfg):
        """Initialize shared task state from the per-task config dict.

        :param task_id: unique id (uuid string) for this task
        :param task_local_dir: Path of this task's working directory
        :param task_cfg: dict with keys ``ASR``, ``translation``,
            ``output_type``, ``target_lang``, ``source_lang``, ``field``,
            ``pre_process`` and ``post_process``
        """
        self.__status_lock = threading.Lock()
        self.__status = TaskStatus.CREATED
        self.gpu_status = 0
        openai.api_key = getenv("OPENAI_API_KEY")
        self.task_id = task_id

        self.task_local_dir = task_local_dir
        self.ASR_setting = task_cfg["ASR"]
        self.translation_setting = task_cfg["translation"]
        self.translation_model = self.translation_setting["model"]

        self.output_type = task_cfg["output_type"]
        self.target_lang = task_cfg["target_lang"]
        self.source_lang = task_cfg["source_lang"]
        self.field = task_cfg["field"]
        self.pre_setting = task_cfg["pre_process"]
        self.post_setting = task_cfg["post_process"]

        self.audio_path = None
        # FIX: initialize video_path in the base class; output_render() reads
        # it unconditionally, but only some subclasses used to set it.
        self.video_path = None
        self.SRT_Script = None
        self.result = None
        # FIX: was ``self.s_t``, which nothing else read — get_srt_class() and
        # output_render() use ``self.t_s`` as the pipeline start timestamp.
        self.t_s = None
        self.t_e = None

        print(f"Task ID: {self.task_id}")
        logging.info(f"Task ID: {self.task_id}")
        logging.info(f"{self.source_lang} -> {self.target_lang} task in {self.field}")
        logging.info(f"Translation Model: {self.translation_model}")
        logging.info(f"subtitle_type: {self.output_type['subtitle']}")
        logging.info(f"video_ouput: {self.output_type['video']}")
        logging.info(f"bilingual_ouput: {self.output_type['bilingual']}")
        logging.info("Pre-process setting:")
        for key in self.pre_setting:
            logging.info(f"{key}: {self.pre_setting[key]}")
        logging.info("Post-process setting:")
        for key in self.post_setting:
            logging.info(f"{key}: {self.post_setting[key]}")

    @staticmethod
    def fromYoutubeLink(youtube_url, task_id, task_dir, task_cfg):
        """Factory: create a task that downloads its media from YouTube."""
        logging.info("Task Creation method: Youtube Link")
        return YoutubeTask(task_id, task_dir, task_cfg, youtube_url)

    @staticmethod
    def fromAudioFile(audio_path, task_id, task_dir, task_cfg):
        """Factory: create a task from a local audio file."""
        logging.info("Task Creation method: Audio File")
        return AudioTask(task_id, task_dir, task_cfg, audio_path)

    @staticmethod
    def fromVideoFile(video_path, task_id, task_dir, task_cfg):
        """Factory: create a task from a local video file."""
        logging.info("Task Creation method: Video File")
        return VideoTask(task_id, task_dir, task_cfg, video_path)

    # Module 1 ASR: audio --> SRT_script
    def get_srt_class(self):
        """Transcribe ``self.audio_path`` and build ``self.SRT_Script``.

        Skips transcription entirely when the source-language .srt already
        exists in the task directory.
        """
        # TODO: setup ASR module like translator
        self.status = TaskStatus.INITIALIZING_ASR
        self.t_s = time()

        method = self.ASR_setting["whisper_config"]["method"]
        whisper_model = self.ASR_setting["whisper_config"]["whisper_model"]
        src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id}_{self.source_lang}.srt")
        if not Path.exists(src_srt_path):
            # extract script from audio
            logging.info("extract script from audio")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            if method == "api":
                # NOTE(review): the API path returns an srt-formatted string,
                # not a dict — confirm SrtScript handling before relying on it.
                with open(self.audio_path, 'rb') as audio_file:
                    transcript = openai.Audio.transcribe(model="whisper-1", file=audio_file, response_format="srt")
            elif method == "stable":
                model = stable_whisper.load_model(whisper_model, device)
                transcript = model.transcribe(str(self.audio_path), regroup=False,
                                              initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
                (
                    transcript
                    .split_by_punctuation(['.', '。', '?'])
                    .merge_by_gap(.15, max_words=3)
                    .merge_by_punctuation([' '])
                    .split_by_punctuation(['.', '。', '?'])
                )
                transcript = transcript.to_dict()

            # after get the transcript, release the gpu resource
            torch.cuda.empty_cache()

            self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'], self.field)
            # save the srt script to local
            # NOTE(review): when the .srt already exists, SRT_Script stays
            # None and later pipeline stages will fail — the existing file is
            # never loaded back.  Confirm intended behavior.
            self.SRT_Script.write_srt_file_src(src_srt_path)

    # Module 2: SRT preprocess: perform preprocess steps
    def preprocess(self):
        """Run the configured preprocessing passes on the SRT script."""
        self.status = TaskStatus.PRE_PROCESSING
        logging.info("--------------------Start Preprocessing SRT class--------------------")
        if self.pre_setting["sentence_form"]:
            self.SRT_Script.form_whole_sentence()
        if self.pre_setting["spell_check"]:
            self.SRT_Script.spell_check_term()
        if self.pre_setting["term_correct"]:
            self.SRT_Script.correct_with_force_term()
        processed_srt_path_src = str(Path(self.task_local_dir) / f'{self.task_id}_processed.srt')
        self.SRT_Script.write_srt_file_src(processed_srt_path_src)

        if self.output_type["subtitle"] == "ass":
            logging.info("write English .srt file to .ass")
            assSub_src = srt2ass(processed_srt_path_src, "default", "No", "Modest")
            logging.info('ASS subtitle saved as: ' + assSub_src)
        self.script_input = self.SRT_Script.get_source_only()

    def update_translation_progress(self, new_progress):
        # FIXME(review): ``self.progress`` is never initialized anywhere, so
        # calling this raises AttributeError, and it compares a status against
        # a tuple on subsequent calls.  Left unchanged pending a decision on
        # the intended progress-reporting design.
        if self.progress == TaskStatus.TRANSLATING:
            self.progress = TaskStatus.TRANSLATING.value[0], new_progress

    # Module 3: perform srt translation
    def translation(self):
        """Translate the SRT script in place using the configured LLM."""
        logging.info("---------------------Start Translation--------------------")
        prompt = prompt_selector(self.source_lang, self.target_lang, self.field)
        get_translation(self.SRT_Script, self.translation_model, self.task_id, prompt, self.translation_setting['chunk_size'])

    # Module 4: perform srt post process steps
    def postprocess(self):
        """Run the configured postprocessing passes on the translated script."""
        self.status = TaskStatus.POST_PROCESSING

        logging.info("---------------------Start Post-processing SRT class---------------------")
        if self.post_setting["check_len_and_split"]:
            self.SRT_Script.check_len_and_split()
        if self.post_setting["remove_trans_punctuation"]:
            self.SRT_Script.remove_trans_punctuation()
        logging.info("---------------------Post-processing SRT class finished---------------------")

    # Module 5: output module
    def output_render(self):
        """Write subtitle (and optionally video) results; return the final path."""
        self.status = TaskStatus.OUTPUT_MODULE
        video_out = self.output_type["video"]
        subtitle_type = self.output_type["subtitle"]
        is_bilingual = self.output_type["bilingual"]

        results_dir = f"{self.task_local_dir}/results"
        # FIX: ensure the results directory exists before writing into it.
        Path(results_dir).mkdir(parents=True, exist_ok=True)

        subtitle_path = f"{results_dir}/{self.task_id}_{self.target_lang}.srt"
        self.SRT_Script.write_srt_file_translate(subtitle_path)
        if is_bilingual:
            subtitle_path = f"{results_dir}/{self.task_id}_{self.source_lang}_{self.target_lang}.srt"
            self.SRT_Script.write_srt_file_bilingual(subtitle_path)

        if subtitle_type == "ass":
            logging.info("write .srt file to .ass")
            subtitle_path = srt2ass(subtitle_path, "default", "No", "Modest")
            logging.info('ASS subtitle saved as: ' + subtitle_path)

        final_res = subtitle_path

        # encode to .mp4 video file
        if video_out and self.video_path is not None:
            logging.info("encoding video file")
            logging.info(f'ffmpeg comand: \nffmpeg -i {self.video_path} -vf "subtitles={subtitle_path}" {results_dir}/{self.task_id}.mp4')
            subprocess.run(
                ["ffmpeg",
                 "-i", self.video_path,
                 "-vf", f"subtitles={subtitle_path}",
                 f"{results_dir}/{self.task_id}.mp4"])
            final_res = f"{results_dir}/{self.task_id}.mp4"

        self.t_e = time()
        logging.info(
            "Pipeline finished, time duration:{}".format(strftime("%H:%M:%S", gmtime(self.t_e - self.t_s))))
        return final_res

    def run_pipeline(self):
        """Execute all pipeline stages in order and store the result path."""
        self.get_srt_class()
        self.preprocess()
        self.translation()
        self.postprocess()
        self.result = self.output_render()
|
251 |
+
|
252 |
+
class YoutubeTask(Task):
    """Task whose media comes from a YouTube link.

    ``run()`` downloads the video and audio streams into the task directory
    before starting the shared pipeline.
    """

    def __init__(self, task_id, task_local_dir, task_cfg, youtube_url):
        super().__init__(task_id, task_local_dir, task_cfg)
        self.youtube_url = youtube_url

    def run(self):
        yt = YouTube(self.youtube_url)
        # Highest-resolution progressive (pre-muxed audio+video) mp4 stream.
        video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()

        if video:
            video.download(str(self.task_local_dir), filename=f"task_{self.task_id}.mp4")
            logging.info(f'Video Name: {video.default_filename}')
        else:
            raise FileNotFoundError(f" Video stream not found for link {self.youtube_url}")

        # Prefer a dedicated audio stream; fall back to extracting the track
        # from the downloaded mp4 with ffmpeg.
        audio = yt.streams.filter(only_audio=True).first()
        if audio:
            audio.download(str(self.task_local_dir), filename=f"task_{self.task_id}.mp3")
        else:
            logging.info(" download audio failed, using ffmpeg to extract audio")
            subprocess.run(
                ['ffmpeg', '-i', self.task_local_dir.joinpath(f"task_{self.task_id}.mp4"), '-f', 'mp3',
                 '-ab', '192000', '-vn', self.task_local_dir.joinpath(f"task_{self.task_id}.mp3")])
            logging.info("audio extraction finished")

        self.video_path = self.task_local_dir.joinpath(f"task_{self.task_id}.mp4")
        self.audio_path = self.task_local_dir.joinpath(f"task_{self.task_id}.mp3")

        logging.info(f" Video File Dir: {self.video_path}")
        logging.info(f" Audio File Dir: {self.audio_path}")
        logging.info(" Data Prep Complete. Start pipeline")

        super().run_pipeline()
|
285 |
+
|
286 |
+
class AudioTask(Task):
    """Task created directly from a local audio file.

    ``video_path`` is None, so no video rendering is possible for this task.
    """

    def __init__(self, task_id, task_local_dir, task_cfg, audio_path):
        super().__init__(task_id, task_local_dir, task_cfg)
        # TODO: check audio format
        self.audio_path = audio_path
        self.video_path = None

    def run(self):
        logging.info(f"Video File Dir: {self.video_path}")
        logging.info(f"Audio File Dir: {self.audio_path}")
        logging.info("Data Prep Complete. Start pipeline")
        super().run_pipeline()
|
298 |
+
|
299 |
+
class VideoTask(Task):
    """Task created from a local video file.

    The file is copied into the task directory on construction; ``run()``
    extracts its audio track with ffmpeg before starting the pipeline.
    """

    def __init__(self, task_id, task_local_dir, task_cfg, video_path):
        super().__init__(task_id, task_local_dir, task_cfg)
        # TODO: check video format {.mp4}
        new_video_path = f"{task_local_dir}/task_{self.task_id}.mp4"
        print(new_video_path)
        logging.info(f"Copy video file to: {new_video_path}")
        shutil.copyfile(video_path, new_video_path)
        self.video_path = new_video_path

    def run(self):
        logging.info("using ffmpeg to extract audio")
        # 192 kbps mp3, video stream dropped (-vn).
        subprocess.run(
            ['ffmpeg', '-i', self.video_path, '-f', 'mp3',
             '-ab', '192000', '-vn', self.task_local_dir.joinpath(f"task_{self.task_id}.mp3")])
        logging.info("audio extraction finished")

        self.audio_path = self.task_local_dir.joinpath(f"task_{self.task_id}.mp3")
        logging.info(f" Video File Dir: {self.video_path}")
        logging.info(f" Audio File Dir: {self.audio_path}")
        logging.info("Data Prep Complete. Start pipeline")
        super().run_pipeline()
|
src/translators/LLM_task.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import openai
|
3 |
+
|
4 |
+
|
5 |
+
def LLM_task(model_name, input, task, temp=0.15):
    """Send one translation request to the selected LLM backend.

    :param model_name: The name of the translation model to be used.
    :param input: Sentence for translation.  (Name shadows the builtin but is
        part of the public signature, so it is kept.)
    :param task: Prompt.
    :param temp: Model temperature.
    :raises NotImplementedError: for any backend other than the OpenAI chat models.
    """
    # Only the OpenAI chat models are wired up so far.
    if model_name not in ("gpt-3.5-turbo", "gpt-4"):
        raise NotImplementedError

    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[
            {"role": "system", "content": task},
            {"role": "user", "content": input},
        ],
        temperature=temp,
    )
    return response['choices'][0]['message']['content'].strip()
|
src/translators/__init__.py
ADDED
File without changes
|
src/translators/translation.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import getenv
|
2 |
+
import logging
|
3 |
+
from time import sleep
|
4 |
+
from tqdm import tqdm
|
5 |
+
from src.srt_util.srt import split_script
|
6 |
+
from .LLM_task import LLM_task
|
7 |
+
|
8 |
+
def get_translation(srt, model, video_name, prompt, chunk_size=1000):
    """Chunk the SRT source text and translate each chunk into *srt* in place."""
    chunks, ranges = split_script(srt.get_source_only(), chunk_size)
    translate(srt, chunks, ranges, model, video_name, task=prompt)
|
12 |
+
|
13 |
+
def check_translation(sentence, translation):
    """Detect the merged-sentence issue in an LLM translation.

    Returns True iff *translation* contains the same number of
    blank-line-separated segments as *sentence*.
    """
    return sentence.count('\n\n') == translation.count('\n\n')
|
24 |
+
|
25 |
+
# TODO{david}: prompts selector
|
26 |
+
def prompt_selector(src_lang, tgt_lang, domain):
    """Build the system prompt for the translation LLM.

    :param src_lang: source language code (e.g. "EN")
    :param tgt_lang: target language code (e.g. "ZH")
    :param domain: subject domain of the video (e.g. "SC2", "General")
    :return: prompt string with the languages spelled out in full.

    FIX: language codes outside the map used to raise KeyError; unmapped
    codes now fall back to the raw code so the prompt is still usable.
    """
    language_map = {
        "EN": "English",
        "ZH": "Chinese",
    }
    src_lang = language_map.get(src_lang, src_lang)
    tgt_lang = language_map.get(tgt_lang, tgt_lang)
    prompt = f"""
    you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
    you will be provided with a segement in {src_lang} parsed by line, where your translation text should keep the original
    meaning and the number of lines.
    """
    return prompt
|
39 |
+
|
40 |
+
def translate(srt, script_arr, range_arr, model_name, video_name=None, attempts_count=5, task=None, temp=0.15):
    """
    Translates the given script array into another language using the chatgpt and writes to the SRT file.

    This function takes a script array, a range array, a model name, a video name, and a video link as input. It iterates
    through sentences and range in the script and range arrays. If the translation check fails for five times, the function
    will attempt to resolve merge sentence issues and split the sentence into smaller tokens for a better translation.

    :param srt: An instance of the Subtitle class representing the SRT file.
    :param script_arr: A list of strings representing the original script sentences to be translated.
    :param range_arr: A list of tuples representing the start and end positions of sentences in the script.
    :param model_name: The name of the translation model to be used.
    :param video_name: The name of the video.
    :param attempts_count: Number of attemps of failures for unmatched sentences.
    :param task: Prompt.
    :param temp: Model temperature.
    """
    # FIX: the original guard was ``if input is None`` which tests the *builtin*
    # ``input`` function and can never fire; validate the actual arguments.
    if script_arr is None or range_arr is None:
        raise Exception("Warning! No Input have passed to LLM!")
    if task is None:
        task = "你是一个翻译助理,你的任务是翻译视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
    logging.info(f"translation prompt: {task}")
    # NOTE(review): previous_length is read but never updated below — confirm
    # whether the range offset accumulation was intended.
    previous_length = 0
    for sentence, range_ in tqdm(zip(script_arr, range_arr)):
        # update the range based on previous length
        range_ = (range_[0] + previous_length, range_[1] + previous_length)
        # using chatgpt model
        print(f"now translating sentences {range_}")
        logging.info(f"now translating sentences {range_}")
        flag = True
        while flag:
            flag = False
            try:
                translate = LLM_task(model_name, sentence, task, temp)
                # detect merge sentence issue and try to solve for five times:
                while not check_translation(sentence, translate) and attempts_count > 0:
                    translate = LLM_task(model_name, sentence, task, temp)
                    attempts_count -= 1

                # if failure still happen, split into smaller tokens
                if attempts_count == 0:
                    single_sentences = sentence.split("\n\n")
                    # FIX: logging.info was called with a stray positional arg
                    # ("...found for range", range_) — use lazy %s formatting.
                    logging.info("merge sentence issue found for range %s", range_)
                    translate = ""
                    for i, single_sentence in enumerate(single_sentences):
                        # FIX: each fallback request used to send the whole
                        # ``sentence`` chunk instead of ``single_sentence``,
                        # defeating the per-sentence retry entirely.
                        if i == len(single_sentences) - 1:
                            translate += LLM_task(model_name, single_sentence, task, temp)
                        else:
                            translate += LLM_task(model_name, single_sentence, task, temp) + "\n\n"
                    logging.info("solved by individually translation!")

            except Exception as e:
                logging.debug("An error has occurred during translation:", e)
                print("An error has occurred during translation:", e)
                print("Retrying... the script will continue after 30 seconds.")
                sleep(30)
                flag = True

        srt.set_translation(translate, range_, model_name, video_name)
|
src/web/api_specs.yaml
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
openapi: 3.0.3
|
2 |
+
info:
|
3 |
+
title: Pigeon AI
|
4 |
+
description: Pigeon AI
|
5 |
+
version: 1.0.0
|
6 |
+
servers:
|
7 |
+
- url: 'https'
|
8 |
+
paths:
|
9 |
+
/api/task:
|
10 |
+
post:
|
11 |
+
summary: Create a task
|
12 |
+
operationId: createTask
|
13 |
+
requestBody:
|
14 |
+
content:
|
15 |
+
application/json:
|
16 |
+
schema:
|
17 |
+
$ref: '#/components/schemas/youtubeLink'
|
18 |
+
responses:
|
19 |
+
'200':
|
20 |
+
description: OK
|
21 |
+
content:
|
22 |
+
application/json:
|
23 |
+
schema:
|
24 |
+
$ref: '#/components/schemas/task'
|
25 |
+
/api/task/{taskId}/status:
|
26 |
+
get:
|
27 |
+
summary: Get task status
|
28 |
+
operationId: getTask
|
29 |
+
parameters:
|
30 |
+
- name: taskId
|
31 |
+
in: path
|
32 |
+
required: true
|
33 |
+
description: task id
|
34 |
+
schema:
|
35 |
+
type: string
|
36 |
+
responses:
|
37 |
+
'200':
|
38 |
+
description: OK
|
39 |
+
content:
|
40 |
+
application/json:
|
41 |
+
schema:
|
42 |
+
$ref: '#/components/schemas/taskStatus'
|
43 |
+
'404':
|
44 |
+
description: Not Found
|
45 |
+
content:
|
46 |
+
application/json:
|
47 |
+
schema:
|
48 |
+
$ref: '#/components/schemas/error'
|
49 |
+
|
50 |
+
components:
|
51 |
+
schemas:
|
52 |
+
youtubeLink:
|
53 |
+
type: object
|
54 |
+
properties:
|
55 |
+
youtubeLink:
|
56 |
+
type: string
|
57 |
+
description: youtube link
|
58 |
+
example: https://www.youtube.com/watch?v=5qap5aO4i9A
|
59 |
+
task:
|
60 |
+
type: object
|
61 |
+
properties:
|
62 |
+
taskId:
|
63 |
+
type: string
|
64 |
+
description: task id generated by uuid
|
65 |
+
example: 7a765280-1a72-47e4-8747-8a38cdbaca91
|
66 |
+
taskStatus:
|
67 |
+
type: object
|
68 |
+
properties:
|
69 |
+
status:
|
70 |
+
type: string
|
71 |
+
description: task status
|
72 |
+
example: TRANSLATING
|
73 |
+
error:
|
74 |
+
type: object
|
75 |
+
properties:
|
76 |
+
error:
|
77 |
+
type: string
|
78 |
+
description: error message
|
79 |
+
example: 'Invalid youtube link'
|
src/web/web.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
from flask import Flask, request, jsonify
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
+
from src.task import Task
|
5 |
+
from uuid import uuid4
|
6 |
+
|
7 |
+
app = Flask(__name__)
|
8 |
+
|
9 |
+
# Global thread pool
|
10 |
+
executor = ThreadPoolExecutor(max_workers=4) # Adjust max_workers as per your requirement
|
11 |
+
|
12 |
+
# thread safe task map to store task status
|
13 |
+
task_map = {}
|
14 |
+
|
15 |
+
@app.route('/api/task', methods=['POST'])
def create_task_youtube():
    """Create a translation task from a YouTube link.

    Expects JSON ``{"youtubeLink": ...}``; returns ``{"taskId": ...}`` and
    runs the task asynchronously on the shared thread pool.
    """
    global task_map
    data = request.get_json()
    if not data or 'youtubeLink' not in data:
        return jsonify({'error': 'YouTube link not provided'}), 400
    youtube_link = data['youtubeLink']
    # FIX: close the config file handle (was left open by yaml.load(open(...))).
    with open("./configs/local_launch.yaml") as cfg_file:
        launch_config = yaml.load(cfg_file, Loader=yaml.Loader)
    task_id = str(uuid4())
    # FIX: Task.fromYoutubeLink expects (url, task_id, task_dir, task_cfg);
    # the per-task directory argument was missing, so every request raised
    # TypeError.  Derive the directory from the launch config's local_dump.
    # NOTE(review): confirm whether task_cfg should come from
    # configs/task_config.yaml rather than the launch config.
    from pathlib import Path
    task_dir = Path(launch_config["local_dump"]) / f"task_{task_id}"
    task_dir.mkdir(parents=True, exist_ok=True)
    task = Task.fromYoutubeLink(youtube_link, task_id, task_dir, launch_config)
    task_map[task_id] = task
    # Submit task to thread pool
    executor.submit(task.run)

    return jsonify({'taskId': task.task_id})
|
30 |
+
|
31 |
+
@app.route('/api/task/<taskId>/status', methods=['GET'])
def get_task_status(taskId):
    """Return the status of a previously created task, or 404 if unknown."""
    global task_map
    task = task_map.get(taskId)
    if task is None:
        return jsonify({'error': 'Task not found'}), 404
    return jsonify({'status': task.status})
|
37 |
+
|
38 |
+
if __name__ == '__main__':
|
39 |
+
app.run(debug=True)
|
tests/test_remove_punc.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('./src')
|
3 |
+
from srt_util.srt import SrtScript, SrtSegment
|
4 |
+
|
5 |
+
zh_test1 = "再次,如果你对一些福利感兴趣,你也可以。"
|
6 |
+
zh_en_test1 = "GG。Classic在我今年解说的最奇葩的系列赛中获得了胜利。"
|
7 |
+
|
8 |
+
def form_srt_class(src_lang, tgt_lang, source_text="", translation="", duration="00:00:00,740 --> 00:00:08,779"):
    """Build a one-segment SrtScript fixture for the punctuation tests."""
    # Segment layout mirrors raw .srt parsing:
    # [index, time range, source text, translation, trailing blank].
    segment = [0, duration, source_text, translation, ""]
    return SrtScript(src_lang, tgt_lang, [segment])
|
11 |
+
|
12 |
+
def test_zh():
    # Chinese punctuation ("," and "。") should be replaced by spaces.
    srt = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_test1)
    srt.remove_trans_punctuation()
    assert srt.segments[0].translation == "再次 如果你对一些福利感兴趣 你也可以 "
|
16 |
+
|
17 |
+
def test_zh_en():
    # Mixed Chinese/Latin text: punctuation is stripped while ASCII words
    # such as "GG" and "Classic" are preserved.
    srt = form_srt_class(src_lang="EN", tgt_lang="ZH", translation=zh_en_test1)
    srt.remove_trans_punctuation()
    assert srt.segments[0].translation == "GG Classic在我今年解说的最奇葩的系列赛中获得了胜利 "
|
21 |
+
|