Eason Lu committed on
Commit
e3f9642
·
1 Parent(s): 1a902ed

solve puncs for multilingual

Browse files

Former-commit-id: 2d6150e7eadae2735d85204e283cace5c12789e4

Files changed (1) hide show
  1. src/srt_util/srt.py +65 -22
src/srt_util/srt.py CHANGED
@@ -8,15 +8,48 @@ import logging
8
  import openai
9
  from tqdm import tqdm
10
 
 
11
  punctuation_dict = {
12
- "EN": ". , ? ! : ; - ( ) [ ] { } ' \"",
13
- "ES": ". , ? ! : ; - ( ) [ ] { } ' \" ¡ ¿",
14
- "FR": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
15
- "DE": ". , ? ! : ; - ( ) [ ] { } ' \" „ “ –",
16
- "RU": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
17
- "ZH": "。 , ? ! : ; — ( ) ​``【oaicite:1】``​ 《 》 “ ”",
18
- "JA": " ​``【oaicite:0】``​ ",
19
- "AR": ". , ? ! : ; - ( ) [ ] { } ، ؛ ؟ « »",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  }
21
 
22
  class SrtSegment(object):
@@ -99,7 +132,7 @@ class SrtSegment(object):
99
  remove punctuations in translation text
100
  :return: None
101
  """
102
- punc = punctuation_dict[self.tgt_lang]
103
  translator = str.maketrans(punc, ' ' * len(punc))
104
  self.translation = self.translation.translate(translator)
105
 
@@ -162,9 +195,10 @@ class SrtScript(object):
162
  logging.info("Forming whole sentences...")
163
  merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
164
  sentence = []
 
165
  # Get each entire sentence of distinct segments, fill indices to merge_list
166
  for i, seg in enumerate(self.segments):
167
- if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
168
  sentence.append(i)
169
  merge_list.append(sentence)
170
  sentence = []
@@ -199,6 +233,7 @@ class SrtScript(object):
199
  src_text += '\n\n'
200
 
201
  def inner_func(target, input_str):
 
202
  response = openai.ChatCompletion.create(
203
  model="gpt-4",
204
  messages=[
@@ -270,25 +305,27 @@ class SrtScript(object):
270
 
271
  def split_seg(self, seg, text_threshold, time_threshold):
272
  # evenly split seg to 2 parts and add new seg into self.segments
273
-
274
  # ignore the initial comma to solve the recursion problem
275
- # FIXME: accomodate multilingual setting
 
 
276
  if len(seg.source_text) > 2:
277
- if seg.source_text[:2] == ', ':
278
  seg.source_text = seg.source_text[2:]
279
- if seg.translation[0] == ',':
280
  seg.translation = seg.translation[1:]
281
 
282
  source_text = seg.source_text
283
  translation = seg.translation
284
 
285
  # split the text based on commas
286
- src_commas = [m.start() for m in re.finditer(',', source_text)]
287
- trans_commas = [m.start() for m in re.finditer(',', translation)]
288
  if len(src_commas) != 0:
289
  src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
290
  len(src_commas) // 2 - 1]
291
  else:
 
292
  src_space = [m.start() for m in re.finditer(' ', source_text)]
293
  if len(src_space) > 0:
294
  src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
@@ -300,13 +337,19 @@ class SrtScript(object):
300
  trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
301
  len(trans_commas) // 2 - 1]
302
  else:
303
- trans_split_idx = len(translation) // 2
 
 
 
 
 
 
304
 
305
- # to avoid split English word
306
- for i in range(trans_split_idx, len(translation)):
307
- if not translation[i].encode('utf-8').isalpha():
308
- trans_split_idx = i
309
- break
310
 
311
  # split the time duration based on text length
312
  time_split_ratio = trans_split_idx / (len(seg.translation) - 1)
 
8
  import openai
9
  from tqdm import tqdm
10
 
11
+ # punctuation dictionary for supported languages
12
  punctuation_dict = {
13
+ "EN": {
14
+ "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \"",
15
+ "comma": ", ",
16
+ "sentence_end": [".", "!", "?", ";"]
17
+ },
18
+ "ES": {
19
+ "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \" ¡ ¿",
20
+ "comma": ", ",
21
+ "sentence_end": [".", "!", "?", ";", "¡", "¿"]
22
+ },
23
+ "FR": {
24
+ "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
25
+ "comma": ", ",
26
+ "sentence_end": [".", "!", "?", ";"]
27
+ },
28
+ "DE": {
29
+ "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \" „ “ –",
30
+ "comma": ", ",
31
+ "sentence_end": [".", "!", "?", ";"]
32
+ },
33
+ "RU": {
34
+ "punc_str": ". , ? ! : ; - ( ) [ ] { } ' \" « » —",
35
+ "comma": ", ",
36
+ "sentence_end": [".", "!", "?", ";"]
37
+ },
38
+ "ZH": {
39
+ "punc_str": "。 , ? ! : ; — ( ) ​``【oaicite:1】``​ 《 》 “ ”",
40
+ "comma": ",",
41
+ "sentence_end": ["。", "!", "?"]
42
+ },
43
+ "JA": {
44
+ "punc_str": "。 、 ? ! : ; ー ( ) ​``【oaicite:0】``​ 「 」 『 』",
45
+ "comma": "、",
46
+ "sentence_end": ["。", "!", "?"]
47
+ },
48
+ "AR": {
49
+ "punc_str": ". , ? ! : ; - ( ) [ ] { } ، ؛ ؟ « »",
50
+ "comma": "، ",
51
+ "sentence_end": [".", "!", "?", ";", "؟"]
52
+ },
53
  }
54
 
55
  class SrtSegment(object):
 
132
  remove punctuations in translation text
133
  :return: None
134
  """
135
+ punc = punctuation_dict[self.tgt_lang]["punc_str"]
136
  translator = str.maketrans(punc, ' ' * len(punc))
137
  self.translation = self.translation.translate(translator)
138
 
 
195
  logging.info("Forming whole sentences...")
196
  merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
197
  sentence = []
198
+ ending_puncs = punctuation_dict[self.src_lang]["sentence_end"]
199
  # Get each entire sentence of distinct segments, fill indices to merge_list
200
  for i, seg in enumerate(self.segments):
201
+ if seg.source_text[-1] in ending_puncs and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
202
  sentence.append(i)
203
  merge_list.append(sentence)
204
  sentence = []
 
233
  src_text += '\n\n'
234
 
235
  def inner_func(target, input_str):
236
+ # TODO: accommodate different languages
237
  response = openai.ChatCompletion.create(
238
  model="gpt-4",
239
  messages=[
 
305
 
306
  def split_seg(self, seg, text_threshold, time_threshold):
307
  # evenly split seg to 2 parts and add new seg into self.segments
 
308
  # ignore the initial comma to solve the recursion problem
309
+ src_comma_str = punctuation_dict[self.src_lang]["comma"]
310
+ tgt_comma_str = punctuation_dict[self.tgt_lang]["comma"]
311
+
312
  if len(seg.source_text) > 2:
313
+ if seg.source_text[:2] == src_comma_str:
314
  seg.source_text = seg.source_text[2:]
315
+ if seg.translation[0] == tgt_comma_str:
316
  seg.translation = seg.translation[1:]
317
 
318
  source_text = seg.source_text
319
  translation = seg.translation
320
 
321
  # split the text based on commas
322
+ src_commas = [m.start() for m in re.finditer(src_comma_str, source_text)]
323
+ trans_commas = [m.start() for m in re.finditer(tgt_comma_str, translation)]
324
  if len(src_commas) != 0:
325
  src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
326
  len(src_commas) // 2 - 1]
327
  else:
328
+ # split the text based on spaces
329
  src_space = [m.start() for m in re.finditer(' ', source_text)]
330
  if len(src_space) > 0:
331
  src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
 
337
  trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
338
  len(trans_commas) // 2 - 1]
339
  else:
340
+ # split the text based on spaces
341
+ trans_space = [m.start() for m in re.finditer(' ', translation)]
342
+ if len(trans_space) > 0:
343
+ trans_split_idx = trans_space[len(trans_space) // 2] if len(trans_space) % 2 == 1 else trans_space[
344
+ len(trans_space) // 2 - 1]
345
+ else:
346
+ trans_split_idx = len(translation) // 2
347
 
348
+ # to avoid split English word
349
+ for i in range(trans_split_idx, len(translation)):
350
+ if not translation[i].encode('utf-8').isalpha():
351
+ trans_split_idx = i
352
+ break
353
 
354
  # split the time duration based on text length
355
  time_split_ratio = trans_split_idx / (len(seg.translation) - 1)