import gc
import os
import re
import torch
import gradio as gr
from abc import ABC, abstractmethod
from typing import List
from datetime import datetime
from modules.whisper.whisper_parameter import *
from modules.utils.subtitle_manager import *
from modules.utils.files_manager import load_yaml, save_yaml
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
class TranslationBase(ABC):
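    """
    Abstract base class for subtitle/text translation backends (e.g. NLLB).
    Creates the model and output directories, selects the torch device,
    and provides shared logic for translating subtitle files and text segments.
    """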
def __init__(self,
model_dir: str = NLLB_MODELS_DIR,
output_dir: str = TRANSLATION_OUTPUT_DIR
):
super().__init__()
self.model = None
self.model_dir = model_dir
self.output_dir = output_dir
os.makedirs(self.model_dir, exist_ok=True)
os.makedirs(self.output_dir, exist_ok=True)
self.current_model_size = None
self.device = self.get_device()
@abstractmethod
def translate(self,
text: str,
max_length: int
):
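        """Translate a single string; must be implemented by concrete subclasses."""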
pass
@abstractmethod
def update_model(self,
model_size: str,
src_lang: str,
tgt_lang: str,
progress: gr.Progress = gr.Progress()
):
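        """Load or switch to the requested model and language pair; must be implemented by concrete subclasses."""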
pass
def translate_file(self,
fileobjs: list,
model_size: str,
src_lang: str,
tgt_lang: str,
max_length: int = 200,
add_timestamp: bool = True,
progress=gr.Progress()) -> list:
"""
Translate subtitle file from source language to target language
Parameters
----------
fileobjs: list
List of files to translate from gr.Files()
model_size: str
Translation model size from gr.Dropdown()
src_lang: str
Source language of the file to translate from gr.Dropdown()
tgt_lang: str
Target language of the file to translate from gr.Dropdown()
max_length: int
Max length per line to translate
add_timestamp: bool
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
A forked version of Whisper is used for this. For more info: https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
Returns
----------
A List of
String to return to gr.Textbox()
Files to return to gr.Files()
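
        A minimal usage sketch (hypothetical subclass and example values, for illustration only):
            translator = SomeNLLBTranslation()  # assumed concrete subclass of TranslationBase
            result_text, output_paths = translator.translate_file(
                fileobjs=["sample.srt"],
                model_size="facebook/nllb-200-distilled-600M",  # example model id
                src_lang="eng_Latn",  # example NLLB-style language code
                tgt_lang="fra_Latn",
            )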
"""
try:
if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
fileobjs = [file.name for file in fileobjs]
self.cache_parameters(model_size=model_size,
src_lang=src_lang,
tgt_lang=tgt_lang,
max_length=max_length,
add_timestamp=add_timestamp)
self.update_model(model_size=model_size,
src_lang=src_lang,
tgt_lang=tgt_lang,
progress=progress)
files_info = {}
for fileobj in fileobjs:
file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
if file_ext == ".srt":
parsed_dicts = parse_srt(file_path=fileobj)
total_progress = len(parsed_dicts)
for index, dic in enumerate(parsed_dicts):
progress(index / total_progress, desc="Translating...")
translated_text = self.translate(dic["sentence"], max_length=max_length)
dic["sentence"] = translated_text
subtitle = get_serialized_srt(parsed_dicts)
elif file_ext == ".vtt":
parsed_dicts = parse_vtt(file_path=fileobj)
total_progress = len(parsed_dicts)
for index, dic in enumerate(parsed_dicts):
progress(index / total_progress, desc="Translating...")
translated_text = self.translate(dic["sentence"], max_length=max_length)
dic["sentence"] = translated_text
                    subtitle = get_serialized_vtt(parsed_dicts)
                else:
                    # Skip unsupported file types; `subtitle` would otherwise be undefined below
                    print(f"Skipping unsupported file type: {file_ext}")
                    continue
if add_timestamp:
timestamp = datetime.now().strftime("%m%d%H%M%S")
file_name += f"-{timestamp}"
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
write_file(subtitle, output_path)
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
total_result = ''
for file_name, info in files_info.items():
total_result += '------------------------------------\n'
total_result += f'{file_name}\n\n'
total_result += f'{info["subtitle"]}'
gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
            output_file_paths = [info["path"] for info in files_info.values()]
return [gr_str, output_file_paths]
        except Exception as e:
            print(f"Error translating file: {e}")
            raise
finally:
self.release_cuda_memory()
def translate_text(self,
input_list_dict: list,
model_size: str,
src_lang: str,
tgt_lang: str,
speaker_diarization: bool = False,
max_length: int = 200,
add_timestamp: bool = True,
progress=gr.Progress()) -> list:
"""
Translate text from source language to target language
Parameters
----------
        input_list_dict: list
            List[dict] of segments (each with a "text" key) to translate
model_size: str
Translation model size from gr.Dropdown()
src_lang: str
Source language of the file to translate from gr.Dropdown()
tgt_lang: str
Target language of the file to translate from gr.Dropdown()
speaker_diarization: bool
Boolean value that determines whether diarization is enabled or not
max_length: int
Max length per line to translate
add_timestamp: bool
Boolean value that determines whether to add a timestamp
progress: gr.Progress
Indicator to show progress directly in gradio.
A forked version of Whisper is used for this. For more info: https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
Returns
----------
        The input List[dict] with each "text" field translated
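
        A minimal usage sketch (hypothetical subclass and example values, for illustration only):
            segments = [{"text": "SPEAKER_00: Hello there."}]
            translated = translator.translate_text(
                input_list_dict=segments,
                model_size="facebook/nllb-200-distilled-600M",  # example model id
                src_lang="eng_Latn",
                tgt_lang="spa_Latn",
                speaker_diarization=True,  # keeps the "SPEAKER_00:" prefix untranslated
            )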
"""
try:
if src_lang != tgt_lang:
                self.cache_parameters(model_size=model_size,
                                      src_lang=src_lang,
                                      tgt_lang=tgt_lang,
                                      max_length=max_length,
                                      add_timestamp=add_timestamp)
                self.update_model(model_size=model_size,
                                  src_lang=src_lang,
                                  tgt_lang=tgt_lang,
                                  progress=progress)
total_progress = len(input_list_dict)
for index, dic in enumerate(input_list_dict):
progress(index / total_progress, desc="Translating...")
                    # Preserve the speaker label ("SPEAKER_XX: ...") when diarization is enabled
                    if speaker_diarization and ":" in dic["text"]:
                        speaker, sentence = dic["text"].split(":", 1)
                        translated_text = f"{speaker.strip()}: {self.translate(sentence.strip(), max_length=max_length)}"
                    else:
                        translated_text = self.translate(dic["text"], max_length=max_length)
dic["text"] = translated_text
            # Return the segments even when src and tgt languages match (no translation needed)
            return input_list_dict
except Exception as e:
print(f"Error translating text: {e}")
raise
finally:
self.release_cuda_memory()
def offload(self):
"""Offload the model and free up the memory"""
if self.model is not None:
del self.model
self.model = None
if self.device == "cuda":
self.release_cuda_memory()
gc.collect()
@staticmethod
def get_device():
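        """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""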
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
@staticmethod
def release_cuda_memory():
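        """Empty the CUDA cache and reset peak memory statistics, if CUDA is available."""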
if torch.cuda.is_available():
torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
@staticmethod
def remove_input_files(file_paths: List[str]):
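        """Delete the given input files from disk, ignoring empty or missing paths."""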
if not file_paths:
return
for file_path in file_paths:
if file_path and os.path.exists(file_path):
os.remove(file_path)
@staticmethod
def cache_parameters(model_size: str,
src_lang: str,
tgt_lang: str,
max_length: int,
add_timestamp: bool):
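        """Persist the last-used translation parameters to the default parameters YAML file."""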
        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
        # Guard against a config file without a "translation" section
        translation_params = cached_params.setdefault("translation", {})
        translation_params["nllb"] = {
            "model_size": model_size,
            "source_lang": src_lang,
            "target_lang": tgt_lang,
            "max_length": max_length,
        }
        translation_params["add_timestamp"] = add_timestamp
save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)