Merge branch 'SRT_cleanup' into eason/main
Former-commit-id: 3cedf7bb4e826122d3227968510ee9811a86bcb5
- doc/Installation.md +7 -0
- doc/struct.md +7 -0
- pipeline.py +68 -49
- srt_util/__init__.py +0 -0
- SRT.py → srt_util/srt.py +40 -46
- srt2ass.py → srt_util/srt2ass.py +0 -0
doc/Installation.md
ADDED
@@ -0,0 +1,7 @@
+### **Recommended:**
+We recommend you to configure your environment using [mamba](https://pypi.org/project/mamba/). The following packages are required:
+```
+openai
+openai-whisper
+
+```
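Note that the two packages listed in the diff install under different import names: `openai-whisper` is imported as `whisper`, while `openai` keeps its own name. A minimal sanity check of the environment (not part of this commit; the model name is only an example) might look like this:

```python
# Hypothetical environment check, not part of the repository.
import openai            # the `openai` package
import whisper           # provided by the `openai-whisper` package

# Loading a small model is enough to confirm openai-whisper works end to end.
model = whisper.load_model("base")
print(type(model))
```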
doc/struct.md
ADDED
@@ -0,0 +1,7 @@
+# Structure of Repository
+```
+├── doc                # Baseline implementation of SpMM algorithm.
+├────── struct.md      # Document of repository structure.
+├── finetune_data      #
+└── README.md
+```
pipeline.py
CHANGED
@@ -3,10 +3,10 @@ from pytube import YouTube
import argparse
import os
from tqdm import tqdm
+from srt_util.srt import SrtScript
import stable_whisper
import whisper
-from srt2ass import srt2ass
+from srt_util.srt2ass import srt2ass
import logging
from datetime import datetime
import torch
@@ -15,23 +15,29 @@ import subprocess

import time

+
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--link", help="youtube video link here", default=None, type=str, required=False)
    parser.add_argument("--video_file", help="local video path here", default=None, type=str, required=False)
    parser.add_argument("--audio_file", help="local audio path here", default=None, type=str, required=False)
+    parser.add_argument("--srt_file", help="srt file input path here", default=None, type=str,
+                        required=False)  # New argument
    parser.add_argument("--download", help="download path", default='./downloads', type=str, required=False)
    parser.add_argument("--output_dir", help="translate result path", default='./results', type=str, required=False)
+    parser.add_argument("--video_name",
+                        help="video name, if use video link as input, the name will auto-filled by youtube video name",
+                        default='placeholder', type=str, required=False)
+    parser.add_argument("--model_name", help="model name only support gpt-4 and gpt-3.5-turbo", type=str,
+                        required=False, default="gpt-4")  # default change to gpt-4
    parser.add_argument("--log_dir", help="log path", default='./logs', type=str, required=False)
    parser.add_argument("-only_srt", help="set script output to only .srt file", action='store_true')
    parser.add_argument("-v", help="auto encode script with video", action='store_true')
    args = parser.parse_args()
+
    return args

+
def get_sources(args, download_path, result_path, video_name):
    # get source audio
    audio_path = None
@@ -59,9 +65,9 @@ def get_sources(args, download_path, result_path, video_name):
            print("Error: Audio stream not found")
        except Exception as e:
            print("Connection Error")
+            print(e)
            exit()
+
        video_path = f'{download_path}/video/{video.default_filename}'
        audio_path = '{}/audio/{}'.format(download_path, audio.default_filename)
        audio_file = open(audio_path, "rb")
@@ -72,7 +78,7 @@ def get_sources(args, download_path, result_path, video_name):
        video_path = args.video_file

        if args.audio_file is not None:
+            audio_file = open(args.audio_file, "rb")
            audio_path = args.audio_file
        else:
            output_audio_path = f'{download_path}/audio/{video_name}.mp3'
@@ -84,37 +90,41 @@ def get_sources(args, download_path, result_path, video_name):
        os.mkdir(f'{result_path}/{video_name}')

    if args.audio_file is not None:
+        audio_file = open(args.audio_file, "rb")
        audio_path = args.audio_file
        pass

    return audio_path, audio_file, video_path, video_name

+
+def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file=None, whisper_model='large',
+                  method="stable"):
    # Instead of using the script_en variable directly, we'll use script_input
+    if srt_file_en is not None:
+        srt = SrtScript.parse_from_srt_file(srt_file_en)
    else:
        # using whisper to perform speech-to-text and save it in <video name>_en.txt under RESULT PATH.
        srt_file_en = "{}/{}/{}_en.srt".format(result_path, video_name, video_name)
        if not os.path.exists(srt_file_en):
+
+            devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # use OpenAI API for transcribe
            if method == "api":
+                transcript = openai.Audio.transcribe("whisper-1", audio_file)

+            # use local whisper model
            elif method == "basic":
+                model = whisper.load_model(whisper_model,
+                                            device=devices)  # using base model in local machine (may use large model on our server)
                transcript = model.transcribe(audio_path)

            # use stable-whisper
            elif method == "stable":

                # use cuda if available
+                model = stable_whisper.load_model(whisper_model, device=devices)
+                transcript = model.transcribe(audio_path, regroup=False,
+                                              initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
                (
                    transcript
                    .split_by_punctuation(['.', '。', '?'])
@@ -126,14 +136,15 @@ def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file =
            else:
                raise ValueError("invalid speech to text method")

+            srt = SrtScript(transcript['segments'])  # read segments to SRT class

        else:
+            srt = SrtScript.parse_from_srt_file(srt_file_en)
    return srt_file_en, srt

+
# Split the video script by sentences and create chunks within the token limit
+def script_split(script_in, chunk_size=1000):
    script_split = script_in.split('\n\n')
    script_arr = []
    range_arr = []
@@ -143,20 +154,21 @@ def script_split(script_in, chunk_size = 1000):
    for sentence in script_split:
        if len(script) + len(sentence) + 1 <= chunk_size:
            script += sentence + '\n\n'
+            end += 1
        else:
            range_arr.append((start, end))
+            start = end + 1
            end += 1
            script_arr.append(script.strip())
            script = sentence + '\n\n'
    if script.strip():
        script_arr.append(script.strip())
+        range_arr.append((start, len(script_split) - 1))

    assert len(script_arr) == len(range_arr)
    return script_arr, range_arr

+
def check_translation(sentence, translation):
    """
    check merge sentence issue from openai translation
@@ -187,24 +199,25 @@ def get_response(model_name, sentence):
    if model_name == "gpt-3.5-turbo" or model_name == "gpt-4":
        response = openai.ChatCompletion.create(
            model=model_name,
+            messages=[
+                # {"role": "system", "content": "You are a helpful assistant that translates English to Chinese and have decent background in starcraft2."},
+                # {"role": "system", "content": "Your translation has to keep the orginal format and be as accurate as possible."},
+                # {"role": "system", "content": "Your translation needs to be consistent with the number of sentences in the original."},
+                # {"role": "system", "content": "There is no need for you to add any comments or notes."},
+                # {"role": "user", "content": 'Translate the following English text to Chinese: "{}"'.format(sentence)}
+
+                {"role": "system",
+                 "content": "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"},
                {"role": "user", "content": sentence}
            ],
            temperature=0.15
        )

    return response['choices'][0]['message']['content'].strip()
+
+
# Translate and save
+def translate(srt, script_arr, range_arr, model_name, video_name, video_link, attempts_count=5):
    """
    Translates the given script array into another language using the chatgpt and writes to the SRT file.

@@ -226,7 +239,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
    previous_length = 0
    for sentence, range in tqdm(zip(script_arr, range_arr)):
        # update the range based on previous length
+        range = (range[0] + previous_length, range[1] + previous_length)

        # using chatgpt model
        print(f"now translating sentences {range}")
@@ -240,7 +253,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
            while not check_translation(sentence, translate) and attempts_count > 0:
                translate = get_response(model_name, sentence)
                attempts_count -= 1
+
            # if failure still happen, split into smaller tokens
            if attempts_count == 0:
                single_sentences = sentence.split("\n\n")
@@ -252,11 +265,11 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
                    else:
                        translate += get_response(model_name, single_sentence) + "\n\n"
                    # print(single_sentence, translate.split("\n\n")[-2])
+                logging.info("solved by individually translation!")

        except Exception as e:
+            logging.debug("An error has occurred during translation:", e)
+            print("An error has occurred during translation:", e)
            print("Retrying... the script will continue after 30 seconds.")
            time.sleep(30)
            flag = True
@@ -284,9 +297,9 @@ def main():
    RESULT_PATH = args.output_dir
    if not os.path.exists(RESULT_PATH):
        os.mkdir(RESULT_PATH)
+
    # set video name as the input file name if not specified
+    if args.video_name == 'placeholder':
        # set video name to upload file name
        if args.video_file is not None:
            VIDEO_NAME = args.video_file.split('/')[-1].split('.')[0]
@@ -303,7 +316,9 @@ def main():

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
+    logging.basicConfig(level=logging.INFO, handlers=[
+        logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")),
+                            'w', encoding='utf-8')])
    logging.info("---------------------Video Info---------------------")
    logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))

@@ -346,12 +361,16 @@ def main():
    if args.v:
        logging.info("encoding video file")
        if args.only_srt:
+            os.system(
+                f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
        else:
+            os.system(
+                f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.ass" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')

    end_time = time.time()
+    logging.info(
+        "Pipeline finished, time duration:{}".format(time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))))
+

if __name__ == "__main__":
+    main()
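For orientation, after this refactor pipeline.py obtains its subtitle object through `srt_util.srt.SrtScript` along the two branches of `get_srt_class`: parse an existing file when `--srt_file` is given, otherwise transcribe and wrap the segments. A condensed sketch of that flow (the helper name and default model are illustrative; the calls mirror the ones in the diff, using the "basic" whisper branch):

```python
# Illustrative sketch of the two branches of get_srt_class after the refactor.
import whisper
from srt_util.srt import SrtScript

def load_script(srt_file_en, audio_path, whisper_model="base"):
    if srt_file_en is not None:
        # --srt_file was provided: reuse the existing subtitles, no transcription.
        return SrtScript.parse_from_srt_file(srt_file_en)
    # Otherwise run speech-to-text and wrap the resulting segments.
    model = whisper.load_model(whisper_model)
    transcript = model.transcribe(audio_path)
    return SrtScript(transcript['segments'])
```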
srt_util/__init__.py
ADDED
File without changes
SRT.py → srt_util/srt.py
RENAMED
@@ -8,7 +8,7 @@ import openai
from tqdm import tqdm


-class SRT_segment(object):
+class SrtSegment(object):
    def __init__(self, *args) -> None:
        if isinstance(args[0], dict):
            segment = args[0]
@@ -64,28 +64,23 @@ class SRT_segment(object):
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"
-        pass

    def __add__(self, other):
        """
        Merge the segment seg with the current segment, and return the new constructed segment.
        No in-place modification.
+        This is used for '+' operator.
        :param other: Another segment that is strictly next to added segment.
        :return: new segment of the two sub-segments
        """

        result = deepcopy(self)
-        result.translation += f' {other.translation}'
-        result.end_time_str = other.end_time_str
-        result.end = other.end
-        result.end_ms = other.end_ms
-        result.duration = f"{self.start_time_str} --> {result.end_time_str}"
+        result.merge_seg(other)
        return result

-    def remove_trans_punc(self):
+    def remove_trans_punc(self) -> None:
        """
-        remove punctuations in translation text
+        remove CN punctuations in translation text
        :return: None
        """
        punc_cn = ",。!?"
@@ -102,12 +97,9 @@ class SRT_segment(object):
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'


-class SRT_script():
+class SrtScript(object):
    def __init__(self, segments) -> None:
-        self.segments = []
-        for seg in segments:
-            srt_seg = SRT_segment(seg)
-            self.segments.append(srt_seg)
+        self.segments = [SrtSegment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
@@ -115,13 +107,12 @@ class SRT_script():
            script_lines = [line.rstrip() for line in f.readlines()]

        segments = []
+        for i in range(0, len(script_lines), 4):
+            segments.append(list(script_lines[i:i + 4]))

        return cls(segments)

+    def merge_segs(self, idx_list) -> SrtSegment:
        """
        Merge entire segment list to a single segment
        :param idx_list: List of index to merge
@@ -147,6 +138,7 @@ class SRT_script():
        logging.info("Forming whole sentences...")
        merge_list = []  # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
+        # Get each entire sentence of distinct segments, fill indices to merge_list
        for i, seg in enumerate(self.segments):
            if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
                sentence.append(i)
@@ -155,6 +147,7 @@ class SRT_script():
            else:
                sentence.append(i)

+        # Reconstruct segments, each with an entire sentence
        segments = []
        for idx_list in merge_list:
            if len(idx_list) > 1:
@@ -254,11 +247,10 @@ class SRT_script():
                max_num -= 1
                if i == len(lines) - 1:
                    break
+            if lines[i][0] in [' ', '\n']:
                lines[i] = lines[i][1:]
            seg.translation = lines[i]

    def split_seg(self, seg, text_threshold, time_threshold):
        # evenly split seg to 2 parts and add new seg into self.segments

@@ -314,14 +306,14 @@ class SRT_script():
        seg1_dict['text'] = src_seg1
        seg1_dict['start'] = start_seg1
        seg1_dict['end'] = end_seg1
+        seg1 = SrtSegment(seg1_dict)
        seg1.translation = trans_seg1

        seg2_dict = {}
        seg2_dict['text'] = src_seg2
        seg2_dict['start'] = start_seg2
        seg2_dict['end'] = end_seg2
+        seg2 = SrtSegment(seg2_dict)
        seg2.translation = trans_seg2

        result_list = []
@@ -344,7 +336,7 @@ class SRT_script():
        for i, seg in enumerate(self.segments):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
+                logging.info("splitting segment {} in to {} parts".format(i + 1, len(seg_list)))
                segments += seg_list
            else:
                segments.append(seg)
@@ -376,39 +368,41 @@ class SRT_script():
        ## force term correction
        logging.info("performing force term correction")
        # load term dictionary
+        with open("../finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
            term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
+
        keywords = list(term_enzh_dict.keys())
        keywords.sort(key=lambda x: len(x), reverse=True)

        for word in keywords:
            for i, seg in enumerate(self.segments):
                if word in seg.source_text.lower():
+                    seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(term_enzh_dict.get(word)),
+                                             seg.source_text, flags=re.IGNORECASE)
+                    logging.info(
+                        "replace term: " + word + " --> " + term_enzh_dict.get(word) + " in time stamp {}".format(
+                            i + 1))
                    logging.info("source text becomes: " + seg.source_text)
+
    comp_dict = []
+
+    def fetchfunc(self, word, threshold):
        import enchant
        result = word
        distance = 0
+        threshold = threshold * len(word)
+        if len(self.comp_dict) == 0:
            with open("./finetune_data/dict_freq.txt", 'r', encoding='utf-8') as f:
+                self.comp_dict = {rows[0]: 1 for rows in reader(f)}
        temp = ""
        for matched in self.comp_dict:
            if (" " in matched and " " in word) or (" " not in matched and " " not in word):
+                if enchant.utils.levenshtein(word, matched) < enchant.utils.levenshtein(word, temp):
                    temp = matched
        if enchant.utils.levenshtein(word, temp) < threshold:
            distance = enchant.utils.levenshtein(word, temp)
            result = temp
+        return distance, result

    def extract_words(self, sentence, n):
        # this function split the sentence to chunks by n of words
@@ -417,9 +411,9 @@ class SRT_script():
        words = sentence.split()
        res = []
        for j in range(n, 0, -1):
+            res += [words[i:i + j] for i in range(len(words) - j + 1)]
+        return res
+
    def spell_check_term(self):
        logging.info("performing spell check")
        import enchant
@@ -435,14 +429,14 @@ class SRT_script():
        distance, correct_term = self.fetchfunc(real_word, 0.3)
        if distance != 0:
            seg.source_text = re.sub(word[:pos], correct_term, seg.source_text, flags=re.IGNORECASE)
+            logging.info(
+                "replace: " + word[:pos] + " to " + correct_term + "\t distance = " + str(distance))

+    def get_real_word(self, word_list: list):
        word = ""
        for w in word_list:
            word += f"{w} "
+        word = word[:-1]  # "this, is"
        if word[-2:] == ".\n":
            real_word = word[:-2].lower()
            n = -2
@@ -460,8 +454,8 @@ class SRT_script():
        # return a string with pure source text
        result = ""
        for i, seg in enumerate(self.segments):
+            result += f'{seg.source_text}\n\n\n'  # f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
+
        return result

    def reform_src_str(self):
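The renamed `SrtSegment` keeps the `__add__` path shown above: copy the left segment, then call `merge_seg` on the copy, so neither input is modified. A small usage sketch, assuming the dict constructor accepts the same `text`/`start`/`end` keys that `split_seg` passes (times in seconds; the printed values are expectations, not output copied from the repo):

```python
from srt_util.srt import SrtSegment

# Two adjacent segments, built from dicts the same way split_seg does.
first = SrtSegment({'text': 'Hello and welcome', 'start': 0.0, 'end': 1.5})
second = SrtSegment({'text': 'to the lecture.', 'start': 1.5, 'end': 3.0})

# __add__ deep-copies `first` and merges `second` into the copy via merge_seg.
merged = first + second
print(merged.source_text)  # expected: the two source texts joined
print(merged.duration)     # expected: start of `first` --> end of `second`
```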
srt2ass.py → srt_util/srt2ass.py
RENAMED
File without changes