Merge branch 'SRT_cleanup' into eason/main

Former-commit-id: 3cedf7bb4e826122d3227968510ee9811a86bcb5
- doc/Installation.md +7 -0
- doc/struct.md +7 -0
- pipeline.py +68 -49
- srt_util/__init__.py +0 -0
- SRT.py → srt_util/srt.py +40 -46
- srt2ass.py → srt_util/srt2ass.py +0 -0
doc/Installation.md  ADDED
@@ -0,0 +1,7 @@
+### **Recommended:**
+We recommend you to configure your environment using [mamba](https://pypi.org/project/mamba/). The following packages are required:
+```
+openai
+openai-whisper
+
+```
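For anyone following the new Installation.md, a minimal setup might look like the sketch below. It is not part of the commit: the environment name is arbitrary, and it simply installs the two packages listed above with pip inside a mamba-managed environment.

```
mamba create -n video-translation python=3.10
conda activate video-translation  # mamba environments are activated through conda
pip install openai openai-whisper
```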
doc/struct.md  ADDED
@@ -0,0 +1,7 @@
+# Structure of Repository
+```
+├── doc                # Baseline implementation of SpMM algorithm.
+├────── struct.md      # Document of repository structure.
+├── finetune_data      #
+└── README.md
+```
pipeline.py  CHANGED
@@ -3,10 +3,10 @@ from pytube import YouTube
 import argparse
 import os
 from tqdm import tqdm
-from
+from srt_util.srt import SrtScript
 import stable_whisper
 import whisper
-from srt2ass import srt2ass
+from srt_util.srt2ass import srt2ass
 import logging
 from datetime import datetime
 import torch
@@ -15,23 +15,29 @@ import subprocess

 import time

+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--link", help="youtube video link here", default=None, type=str, required=False)
     parser.add_argument("--video_file", help="local video path here", default=None, type=str, required=False)
     parser.add_argument("--audio_file", help="local audio path here", default=None, type=str, required=False)
-    parser.add_argument("--srt_file", help="srt file input path here", default=None, type=str,
+    parser.add_argument("--srt_file", help="srt file input path here", default=None, type=str,
+                        required=False)  # New argument
     parser.add_argument("--download", help="download path", default='./downloads', type=str, required=False)
     parser.add_argument("--output_dir", help="translate result path", default='./results', type=str, required=False)
-    parser.add_argument("--video_name",
-
+    parser.add_argument("--video_name",
+                        help="video name, if use video link as input, the name will auto-filled by youtube video name",
+                        default='placeholder', type=str, required=False)
+    parser.add_argument("--model_name", help="model name only support gpt-4 and gpt-3.5-turbo", type=str,
+                        required=False, default="gpt-4")  # default change to gpt-4
     parser.add_argument("--log_dir", help="log path", default='./logs', type=str, required=False)
     parser.add_argument("-only_srt", help="set script output to only .srt file", action='store_true')
     parser.add_argument("-v", help="auto encode script with video", action='store_true')
     args = parser.parse_args()
-
+
     return args

+
 def get_sources(args, download_path, result_path, video_name):
     # get source audio
     audio_path = None
@@ -59,9 +65,9 @@ def get_sources(args, download_path, result_path, video_name):
             print("Error: Audio stream not found")
     except Exception as e:
         print("Connection Error")
-        print(e)
+        print(e)
         exit()
-
+
     video_path = f'{download_path}/video/{video.default_filename}'
     audio_path = '{}/audio/{}'.format(download_path, audio.default_filename)
     audio_file = open(audio_path, "rb")
@@ -72,7 +78,7 @@ def get_sources(args, download_path, result_path, video_name):
         video_path = args.video_file

     if args.audio_file is not None:
-        audio_file= open(args.audio_file, "rb")
+        audio_file = open(args.audio_file, "rb")
         audio_path = args.audio_file
     else:
         output_audio_path = f'{download_path}/audio/{video_name}.mp3'
@@ -84,37 +90,41 @@ def get_sources(args, download_path, result_path, video_name):
         os.mkdir(f'{result_path}/{video_name}')

     if args.audio_file is not None:
-        audio_file= open(args.audio_file, "rb")
+        audio_file = open(args.audio_file, "rb")
         audio_path = args.audio_file
         pass

     return audio_path, audio_file, video_path, video_name

-
+
+def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file=None, whisper_model='large',
+                  method="stable"):
     # Instead of using the script_en variable directly, we'll use script_input
-    if srt_file_en is not None:
-        srt =
+    if srt_file_en is not None:
+        srt = SrtScript.parse_from_srt_file(srt_file_en)
     else:
         # using whisper to perform speech-to-text and save it in <video name>_en.txt under RESULT PATH.
         srt_file_en = "{}/{}/{}_en.srt".format(result_path, video_name, video_name)
         if not os.path.exists(srt_file_en):
-
-            devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+            devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # use OpenAI API for transcribe
             if method == "api":
-                transcript = openai.Audio.transcribe("whisper-1", audio_file)
+                transcript = openai.Audio.transcribe("whisper-1", audio_file)

-
+            # use local whisper model
             elif method == "basic":
-                model = whisper.load_model(whisper_model,
+                model = whisper.load_model(whisper_model,
+                                           device=devices)  # using base model in local machine (may use large model on our server)
                 transcript = model.transcribe(audio_path)

             # use stable-whisper
             elif method == "stable":

                 # use cuda if available
-                model = stable_whisper.load_model(whisper_model, device
-                transcript = model.transcribe(audio_path, regroup
+                model = stable_whisper.load_model(whisper_model, device=devices)
+                transcript = model.transcribe(audio_path, regroup=False,
+                                              initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
                 (
                     transcript
                     .split_by_punctuation(['.', '。', '?'])
@@ -126,14 +136,15 @@ def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file =
             else:
                 raise ValueError("invalid speech to text method")

-            srt =
+            srt = SrtScript(transcript['segments'])  # read segments to SRT class

         else:
-            srt =
+            srt = SrtScript.parse_from_srt_file(srt_file_en)
     return srt_file_en, srt

+
 # Split the video script by sentences and create chunks within the token limit
-def script_split(script_in, chunk_size = 1000):
+def script_split(script_in, chunk_size=1000):
     script_split = script_in.split('\n\n')
     script_arr = []
     range_arr = []
@@ -143,20 +154,21 @@ def script_split(script_in, chunk_size = 1000):
     for sentence in script_split:
         if len(script) + len(sentence) + 1 <= chunk_size:
             script += sentence + '\n\n'
-            end+=1
+            end += 1
         else:
             range_arr.append((start, end))
-            start = end+1
+            start = end + 1
             end += 1
             script_arr.append(script.strip())
             script = sentence + '\n\n'
     if script.strip():
         script_arr.append(script.strip())
-        range_arr.append((start, len(script_split)-1))
+        range_arr.append((start, len(script_split) - 1))

     assert len(script_arr) == len(range_arr)
     return script_arr, range_arr

+
 def check_translation(sentence, translation):
     """
     check merge sentence issue from openai translation
@@ -187,24 +199,25 @@ def get_response(model_name, sentence):
     if model_name == "gpt-3.5-turbo" or model_name == "gpt-4":
         response = openai.ChatCompletion.create(
             model=model_name,
-            messages
-            #{"role": "system", "content": "You are a helpful assistant that translates English to Chinese and have decent background in starcraft2."},
-            #{"role": "system", "content": "Your translation has to keep the orginal format and be as accurate as possible."},
-            #{"role": "system", "content": "Your translation needs to be consistent with the number of sentences in the original."},
-            #{"role": "system", "content": "There is no need for you to add any comments or notes."},
-            #{"role": "user", "content": 'Translate the following English text to Chinese: "{}"'.format(sentence)}
-
-            {"role": "system",
+            messages=[
+                # {"role": "system", "content": "You are a helpful assistant that translates English to Chinese and have decent background in starcraft2."},
+                # {"role": "system", "content": "Your translation has to keep the orginal format and be as accurate as possible."},
+                # {"role": "system", "content": "Your translation needs to be consistent with the number of sentences in the original."},
+                # {"role": "system", "content": "There is no need for you to add any comments or notes."},
+                # {"role": "user", "content": 'Translate the following English text to Chinese: "{}"'.format(sentence)}
+
+                {"role": "system",
+                 "content": "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"},
                 {"role": "user", "content": sentence}
             ],
             temperature=0.15
         )

     return response['choices'][0]['message']['content'].strip()
-
-
+
+
 # Translate and save
-def translate(srt, script_arr, range_arr, model_name, video_name, video_link, attempts_count
+def translate(srt, script_arr, range_arr, model_name, video_name, video_link, attempts_count=5):
     """
     Translates the given script array into another language using the chatgpt and writes to the SRT file.

@@ -226,7 +239,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
     previous_length = 0
     for sentence, range in tqdm(zip(script_arr, range_arr)):
         # update the range based on previous length
-        range = (range[0]+previous_length, range[1]+previous_length)
+        range = (range[0] + previous_length, range[1] + previous_length)

         # using chatgpt model
         print(f"now translating sentences {range}")
@@ -240,7 +253,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
             while not check_translation(sentence, translate) and attempts_count > 0:
                 translate = get_response(model_name, sentence)
                 attempts_count -= 1
-
+
             # if failure still happen, split into smaller tokens
             if attempts_count == 0:
                 single_sentences = sentence.split("\n\n")
@@ -252,11 +265,11 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
                     else:
                         translate += get_response(model_name, single_sentence) + "\n\n"
                         # print(single_sentence, translate.split("\n\n")[-2])
-                logging.info("solved by individually translation!")
+                logging.info("solved by individually translation!")

         except Exception as e:
-            logging.debug("An error has occurred during translation:",e)
-            print("An error has occurred during translation:",e)
+            logging.debug("An error has occurred during translation:", e)
+            print("An error has occurred during translation:", e)
             print("Retrying... the script will continue after 30 seconds.")
             time.sleep(30)
             flag = True
@@ -284,9 +297,9 @@ def main():
     RESULT_PATH = args.output_dir
     if not os.path.exists(RESULT_PATH):
         os.mkdir(RESULT_PATH)
-
+
     # set video name as the input file name if not specified
-    if args.video_name == 'placeholder'
+    if args.video_name == 'placeholder':
         # set video name to upload file name
         if args.video_file is not None:
             VIDEO_NAME = args.video_file.split('/')[-1].split('.')[0]
@@ -303,7 +316,9 @@ def main():

     if not os.path.exists(args.log_dir):
         os.makedirs(args.log_dir)
-    logging.basicConfig(level=logging.INFO, handlers=[
+    logging.basicConfig(level=logging.INFO, handlers=[
+        logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")),
+                            'w', encoding='utf-8')])
     logging.info("---------------------Video Info---------------------")
     logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))

@@ -346,12 +361,16 @@ def main():
     if args.v:
         logging.info("encoding video file")
         if args.only_srt:
-            os.system(
+            os.system(
+                f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
         else:
-            os.system(
+            os.system(
+                f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.ass" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')

     end_time = time.time()
-    logging.info(
+    logging.info(
+        "Pipeline finished, time duration:{}".format(time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))))
+

 if __name__ == "__main__":
-    main()
+    main()
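For orientation, the reworked parse_args above implies invocations along the following lines. This is illustrative only, not part of the commit: the link and paths are placeholders, and the flags shown are exactly the ones defined in the diff (--link, --srt_file, --video_name, --model_name, -only_srt).

```
# Transcribe and translate a YouTube video, emitting only the .srt result
python pipeline.py --link "<youtube-url>" --model_name gpt-4 -only_srt

# Skip transcription and translate an existing English SRT file instead
python pipeline.py --srt_file ./results/<name>/<name>_en.srt --video_name <name>
```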
srt_util/__init__.py  ADDED
File without changes
SRT.py → srt_util/srt.py  RENAMED
@@ -8,7 +8,7 @@ import openai
 from tqdm import tqdm


-class SRT_segment(object):
+class SrtSegment(object):
     def __init__(self, *args) -> None:
         if isinstance(args[0], dict):
             segment = args[0]
@@ -64,28 +64,23 @@ class SRT_segment(object):
         self.end = seg.end
         self.end_ms = seg.end_ms
         self.duration = f"{self.start_time_str} --> {self.end_time_str}"
-        pass

     def __add__(self, other):
         """
         Merge the segment seg with the current segment, and return the new constructed segment.
         No in-place modification.
+        This is used for '+' operator.
         :param other: Another segment that is strictly next to added segment.
         :return: new segment of the two sub-segments
         """

         result = deepcopy(self)
-        result.
-        result.translation += f' {other.translation}'
-        result.end_time_str = other.end_time_str
-        result.end = other.end
-        result.end_ms = other.end_ms
-        result.duration = f"{self.start_time_str} --> {result.end_time_str}"
+        result.merge_seg(other)
         return result

-    def remove_trans_punc(self):
+    def remove_trans_punc(self) -> None:
         """
-        remove punctuations in translation text
+        remove CN punctuations in translation text
         :return: None
         """
         punc_cn = ",。!?"
@@ -102,12 +97,9 @@ class SRT_segment(object):
         return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'


-class SRT_script():
+class SrtScript(object):
     def __init__(self, segments) -> None:
-        self.segments = []
-        for seg in segments:
-            srt_seg = SRT_segment(seg)
-            self.segments.append(srt_seg)
+        self.segments = [SrtSegment(seg) for seg in segments]

     @classmethod
     def parse_from_srt_file(cls, path: str):
@@ -115,13 +107,12 @@ class SRT_script():
             script_lines = [line.rstrip() for line in f.readlines()]

         segments = []
-        for i in range(len(script_lines)
-
-            segments.append(list(script_lines[i:i + 4]))
+        for i in range(0, len(script_lines), 4):
+            segments.append(list(script_lines[i:i + 4]))

         return cls(segments)

-    def merge_segs(self, idx_list) ->
+    def merge_segs(self, idx_list) -> SrtSegment:
         """
         Merge entire segment list to a single segment
         :param idx_list: List of index to merge
@@ -147,6 +138,7 @@ class SRT_script():
         logging.info("Forming whole sentences...")
         merge_list = []  # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
         sentence = []
+        # Get each entire sentence of distinct segments, fill indices to merge_list
         for i, seg in enumerate(self.segments):
             if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
                 sentence.append(i)
@@ -155,6 +147,7 @@ class SRT_script():
             else:
                 sentence.append(i)

+        # Reconstruct segments, each with an entire sentence
         segments = []
         for idx_list in merge_list:
             if len(idx_list) > 1:
@@ -254,11 +247,10 @@ class SRT_script():
                 max_num -= 1
                 if i == len(lines) - 1:
                     break
-            if lines[i][0] in [' ', '\n']:
+            if lines[i][0] in [' ', '\n']:
                 lines[i] = lines[i][1:]
             seg.translation = lines[i]

-
     def split_seg(self, seg, text_threshold, time_threshold):
         # evenly split seg to 2 parts and add new seg into self.segments

@@ -314,14 +306,14 @@ class SRT_script():
         seg1_dict['text'] = src_seg1
         seg1_dict['start'] = start_seg1
         seg1_dict['end'] = end_seg1
-        seg1 =
+        seg1 = SrtSegment(seg1_dict)
         seg1.translation = trans_seg1

         seg2_dict = {}
         seg2_dict['text'] = src_seg2
         seg2_dict['start'] = start_seg2
         seg2_dict['end'] = end_seg2
-        seg2 =
+        seg2 = SrtSegment(seg2_dict)
         seg2.translation = trans_seg2

         result_list = []
@@ -344,7 +336,7 @@ class SRT_script():
         for i, seg in enumerate(self.segments):
             if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                 seg_list = self.split_seg(seg, text_threshold, time_threshold)
-                logging.info("splitting segment {} in to {} parts".format(i+1, len(seg_list)))
+                logging.info("splitting segment {} in to {} parts".format(i + 1, len(seg_list)))
                 segments += seg_list
             else:
                 segments.append(seg)
@@ -376,39 +368,41 @@ class SRT_script():
         ## force term correction
         logging.info("performing force term correction")
         # load term dictionary
-        with open("
+        with open("../finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
             term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
-
+
         keywords = list(term_enzh_dict.keys())
         keywords.sort(key=lambda x: len(x), reverse=True)

         for word in keywords:
             for i, seg in enumerate(self.segments):
                 if word in seg.source_text.lower():
-                    seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(term_enzh_dict.get(word)),
-
+                    seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(term_enzh_dict.get(word)),
+                                             seg.source_text, flags=re.IGNORECASE)
+                    logging.info(
+                        "replace term: " + word + " --> " + term_enzh_dict.get(word) + " in time stamp {}".format(
+                            i + 1))
                     logging.info("source text becomes: " + seg.source_text)
-
-
+
     comp_dict = []
-
-    def fetchfunc(self,word,threshold):
+
+    def fetchfunc(self, word, threshold):
         import enchant
         result = word
         distance = 0
-        threshold = threshold*len(word)
-        if len(self.comp_dict)==0:
+        threshold = threshold * len(word)
+        if len(self.comp_dict) == 0:
             with open("./finetune_data/dict_freq.txt", 'r', encoding='utf-8') as f:
-
+                self.comp_dict = {rows[0]: 1 for rows in reader(f)}
         temp = ""
         for matched in self.comp_dict:
             if (" " in matched and " " in word) or (" " not in matched and " " not in word):
-                if enchant.utils.levenshtein(word, matched)<enchant.utils.levenshtein(word, temp):
+                if enchant.utils.levenshtein(word, matched) < enchant.utils.levenshtein(word, temp):
                     temp = matched
         if enchant.utils.levenshtein(word, temp) < threshold:
             distance = enchant.utils.levenshtein(word, temp)
             result = temp
-        return distance, result
+        return distance, result

     def extract_words(self, sentence, n):
         # this function split the sentence to chunks by n of words
@@ -417,9 +411,9 @@ class SRT_script():
         words = sentence.split()
         res = []
         for j in range(n, 0, -1):
-            res += [words[i:i+j] for i in range(len(words)-j+1)]
-        return
-
+            res += [words[i:i + j] for i in range(len(words) - j + 1)]
+        return res
+
     def spell_check_term(self):
         logging.info("performing spell check")
         import enchant
@@ -435,14 +429,14 @@ class SRT_script():
                     distance, correct_term = self.fetchfunc(real_word, 0.3)
                     if distance != 0:
                         seg.source_text = re.sub(word[:pos], correct_term, seg.source_text, flags=re.IGNORECASE)
-                        logging.info(
+                        logging.info(
+                            "replace: " + word[:pos] + " to " + correct_term + "\t distance = " + str(distance))

-
-    def get_real_word(self, word_list:list):
+    def get_real_word(self, word_list: list):
         word = ""
         for w in word_list:
             word += f"{w} "
-        word = word[:-1]
+        word = word[:-1]  # "this, is"
         if word[-2:] == ".\n":
             real_word = word[:-2].lower()
             n = -2
@@ -460,8 +454,8 @@ class SRT_script():
         # return a string with pure source text
         result = ""
         for i, seg in enumerate(self.segments):
-            result+=f'{seg.source_text}\n\n\n'#f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
-
+            result += f'{seg.source_text}\n\n\n'  # f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
+
         return result

     def reform_src_str(self):
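As a quick sanity check of the renamed module, the new SrtScript/SrtSegment classes can be exercised directly. A sketch only, not part of the commit: it assumes an English SRT file in the standard four-line block layout that parse_from_srt_file reads, and that srt_util is importable from the repository root; the path below is a placeholder.

```
from srt_util.srt import SrtScript

# Parse an existing SRT file into SrtSegment objects.
srt = SrtScript.parse_from_srt_file("results/my_video/my_video_en.srt")
print(len(srt.segments))            # number of subtitle blocks read
print(srt.segments[0].source_text)  # source text of the first block
```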
srt2ass.py → srt_util/srt2ass.py  RENAMED
File without changes