import re
from datetime import timedelta
from csv import reader

class SRT_segment(object):
    def __init__(self, *args) -> None:
        if isinstance(args[0], dict):
            segment = args[0]
            self.start = segment['start']
            self.end = segment['end']
            # derive the millisecond component from the fractional seconds
            self.start_ms = int((segment['start'] * 100) % 100 * 10)
            self.end_ms = int((segment['end'] * 100) % 100 * 10)
            if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']):  # avoid an empty (zero-length) time stamp
                self.end_ms += 500
            self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
            self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
            # format as SRT timestamps ("HH:MM:SS,mmm"); timedelta renders as "H:MM:SS[.ffffff]"
            if self.start_ms == 0:
                self.start_time_str = '0' + str(self.start_time).split('.')[0] + ',000'
            else:
                self.start_time_str = '0' + str(self.start_time).split('.')[0] + ',' + str(self.start_time).split('.')[1][:3]
            if self.end_ms == 0:
                self.end_time_str = '0' + str(self.end_time).split('.')[0] + ',000'
            else:
                self.end_time_str = '0' + str(self.end_time).split('.')[0] + ',' + str(self.end_time).split('.')[1][:3]
            self.source_text = segment['text']
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
            self.translation = ""
        elif isinstance(args[0], list):
            # built from the 4 lines of one SRT block: [index, duration, text, blank]
            self.source_text = args[0][2]
            self.duration = args[0][1]
            self.start_time_str = self.duration.split(" --> ")[0]
            self.end_time_str = self.duration.split(" --> ")[1]
            self.translation = ""

    def merge_seg(self, seg):
        # absorb the following segment: concatenate text and extend the time range
        self.source_text += seg.source_text
        self.translation += seg.translation
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
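
# Illustrative example (shown as comments so importing this module stays
# side-effect free): building a segment from a transcription dict with
# 'start'/'end'/'text' keys, e.g. a whisper output segment:
#   seg = SRT_segment({'start': 0.0, 'end': 2.5, 'text': 'Hello world.'})
#   print(seg)
#   # 00:00:00,000 --> 00:00:02,500
#   # Hello world.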

class SRT_script():
    def __init__(self, segments) -> None:
        self.segments = []
        for seg in segments:
            srt_seg = SRT_segment(seg)
            self.segments.append(srt_seg)

    @classmethod
    def parse_from_srt_file(cls, path: str):
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()
        segments = []
        for i in range(0, len(script_lines), 4):
            segments.append(list(script_lines[i:i + 4]))
        return cls(segments)
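
    # Illustrative input for parse_from_srt_file, one 4-line block per cue
    # (standard SRT layout, which the fixed 4-line stride above assumes):
    #   1
    #   00:00:00,000 --> 00:00:02,500
    #   Hello world.
    #   <blank line>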

    def merge_segs(self, idx_list) -> SRT_segment:
        final_seg = self.segments[idx_list[0]]
        if len(idx_list) == 1:
            return final_seg
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self):
        merge_list = []  # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            if seg.source_text.endswith('.'):
                merge_list.append(sentence)
                sentence = []
        if sentence:  # keep a trailing partial sentence instead of dropping it
            merge_list.append(sentence)

        segments = []
        for idx_list in merge_list:
            segments.append(self.merge_segs(idx_list))
        self.segments = segments  # the old list is released once rebound

    def set_translation(self, translate: str, id_range: tuple):
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        lines = translate.split('\n\n')
        if len(lines) != (end_seg_id - start_seg_id + 1):
            # block count mismatches the requested range; dump for debugging
            print(id_range)
            for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
                print(seg.source_text)
            print(translate)

        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            # naive way to deal with the merged-translation problem
            # TODO: need a smarter solution
            if i < len(lines):
                if "(Note:" in lines[i]:  # drop translator notes
                    lines.remove(lines[i])
                    if i >= len(lines):  # the removed note was the last block
                        break
                try:
                    # strip a leading "index:"-style label, keep the rest intact
                    seg.translation = lines[i].split(":", 1)[1]
                except IndexError:  # no label present
                    seg.translation = lines[i]
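
    # Assumed shape of `translate` (inferred from the splits on '\n\n' and ':',
    # not confirmed by the source): translated blocks separated by blank lines,
    # each optionally prefixed with a label, e.g. "1: 你好，世界\n\n2: 再见"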

    def split_seg(self, seg_idx):
        # evenly split the segment into 2 parts at the middle comma and put
        # both halves back into self.segments (assumes the segment was built
        # from a dict, so it carries numeric .start/.end)
        seg = self.segments[seg_idx]
        source_text = seg.source_text
        translation = seg.translation
        src_commas = [m.start() for m in re.finditer(',', source_text)]
        trans_commas = [m.start() for m in re.finditer(',', translation)]
        # split at the middle comma (requires at least one comma in each text)
        src_split_idx = src_commas[len(src_commas) // 2]
        trans_split_idx = trans_commas[len(trans_commas) // 2]
        src_seg1 = source_text[:src_split_idx]
        src_seg2 = source_text[src_split_idx + 1:]
        trans_seg1 = translation[:trans_split_idx]
        trans_seg2 = translation[trans_split_idx + 1:]
        # divide the time span evenly between the two halves
        start_seg1 = seg.start
        end_seg1 = start_seg2 = seg.start + (seg.end - seg.start) / 2
        end_seg2 = seg.end
        seg1_dict = {'text': src_seg1, 'start': start_seg1, 'end': end_seg1}
        seg1 = SRT_segment(seg1_dict)
        seg1.translation = trans_seg1
        seg2_dict = {'text': src_seg2, 'start': start_seg2, 'end': end_seg2}
        seg2 = SRT_segment(seg2_dict)
        seg2.translation = trans_seg2
        # replace the original segment with the two new halves
        self.segments[seg_idx:seg_idx + 1] = [seg1, seg2]

    def check_len_and_split(self, threshold):
        # TODO: if sentence length >= threshold, split the segment in two
        pass
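
    # One possible implementation sketch for the TODO above (an assumption,
    # not the author's method); iterating in reverse keeps indices valid as
    # split_seg grows the list:
    #   for i in reversed(range(len(self.segments))):
    #       if len(self.segments[i].source_text) >= threshold:
    #           self.split_seg(i)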

    def get_source_only(self):
        # return a string with pure source text
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
        return result

    def reform_src_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i+1}\n'
            result += str(seg)
        return result

    def reform_trans_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i+1}\n'
            result += seg.get_trans_str()
        return result

    def form_bilingual_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i+1}\n'
            result += seg.get_bilingual_str()
        return result

    def write_srt_file_src(self, path: str):
        # write the source-language srt file to path
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self):
        ## force term correction
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation

        # load the term dictionary
        with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
            term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}

        # replace every known term, preserving trailing punctuation
        for seg in self.segments:
            ready_words = seg.source_text.split(" ")
            for i in range(len(ready_words)):
                word = ready_words[i]
                real_word, pos = self.get_real_word(word)
                if real_word in term_enzh_dict:
                    new_word = word.replace(word[:pos], term_enzh_dict.get(real_word))
                else:
                    new_word = word
                ready_words[i] = new_word
            seg.source_text = " ".join(ready_words)

    def spell_check_term(self):
        ## known bug: "I've" will be replaced because "i've" is not in the dict
        import enchant
        en_dict = enchant.Dict('en_US')
        term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')

        for seg in self.segments:
            ready_words = seg.source_text.split(" ")
            for i in range(len(ready_words)):
                word = ready_words[i]
                real_word, pos = self.get_real_word(word)
                if not en_dict.check(real_word):
                    suggest = term_spellDict.suggest(real_word)
                    if suggest:  # relax spell check
                        new_word = word.replace(word[:pos], suggest[0])
                    else:
                        new_word = word
                    ready_words[i] = new_word
            seg.source_text = " ".join(ready_words)

    def spell_correction(self, word: str, arg: int):
        if arg not in (0, 1):
            raise ValueError('only 0 or 1 for argument')

        # the inline `uncover` helper duplicated get_real_word, so reuse it
        real_word, pos = self.get_real_word(word)
        new_word = word
        if arg == 0:  # term translate mode
            with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
                term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
            if real_word in term_enzh_dict:
                new_word = word.replace(word[:pos], term_enzh_dict.get(real_word))
        elif arg == 1:  # term spell check mode
            import enchant
            en_dict = enchant.Dict('en_US')
            term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
            if not en_dict.check(real_word):
                suggest = term_spellDict.suggest(real_word)
                if suggest:  # relax spell check
                    new_word = word.replace(word[:pos], suggest[0])
        return new_word

    def get_real_word(self, word: str):
        # strip trailing punctuation/newline and lowercase the word; return the
        # bare word plus the length of the prefix a replacement should cover
        if word[-2:] == ".\n":
            real_word = word[:-2].lower()
            n = -2
        elif word[-1:] in [".", "\n", ",", "!", "?"]:
            real_word = word[:-1].lower()
            n = -1
        else:
            real_word = word.lower()
            n = 0
        return real_word, len(word) + n
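

# A minimal usage sketch (illustrative: 'test.srt' is a hypothetical input file
# in the strict 4-line block layout parse_from_srt_file expects).
if __name__ == "__main__":
    script = SRT_script.parse_from_srt_file('test.srt')
    script.form_whole_sentence()  # merge fragments into whole sentences
    script.write_srt_file_src('test_processed.srt')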