Update modules/translation/translation_base.py
Browse files
modules/translation/translation_base.py
CHANGED
|
@@ -136,7 +136,8 @@ class TranslationBase(ABC):
|
|
| 136 |
finally:
|
| 137 |
self.release_cuda_memory()
|
| 138 |
|
| 139 |
-
def translate_text(
|
|
|
|
| 140 |
model_size: str,
|
| 141 |
src_lang: str,
|
| 142 |
tgt_lang: str,
|
|
@@ -170,13 +171,13 @@ class TranslationBase(ABC):
|
|
| 170 |
List[dict] with translation
|
| 171 |
"""
|
| 172 |
try:
|
| 173 |
-
cache_parameters(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,max_length=max_length,add_timestamp=add_timestamp)
|
| 174 |
-
update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
|
| 175 |
|
| 176 |
total_progress = len(input_list_dict)
|
| 177 |
for index, dic in enumerate(input_list_dict):
|
| 178 |
progress(index / total_progress, desc="Translating..")
|
| 179 |
-
translated_text = translate(dic["text"], max_length=max_length)
|
| 180 |
dic["text"] = translated_text
|
| 181 |
|
| 182 |
return input_list_dict
|
|
@@ -184,7 +185,7 @@ class TranslationBase(ABC):
|
|
| 184 |
except Exception as e:
|
| 185 |
print(f"Error: {str(e)}")
|
| 186 |
finally:
|
| 187 |
-
release_cuda_memory()
|
| 188 |
|
| 189 |
@staticmethod
|
| 190 |
def get_device():
|
|
|
|
| 136 |
finally:
|
| 137 |
self.release_cuda_memory()
|
| 138 |
|
| 139 |
+
def translate_text(self,
|
| 140 |
+
input_list_dict: list,
|
| 141 |
model_size: str,
|
| 142 |
src_lang: str,
|
| 143 |
tgt_lang: str,
|
|
|
|
| 171 |
List[dict] with translation
|
| 172 |
"""
|
| 173 |
try:
|
| 174 |
+
self.cache_parameters(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,max_length=max_length,add_timestamp=add_timestamp)
|
| 175 |
+
self.update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
|
| 176 |
|
| 177 |
total_progress = len(input_list_dict)
|
| 178 |
for index, dic in enumerate(input_list_dict):
|
| 179 |
progress(index / total_progress, desc="Translating..")
|
| 180 |
+
translated_text = self.translate(dic["text"], max_length=max_length)
|
| 181 |
dic["text"] = translated_text
|
| 182 |
|
| 183 |
return input_list_dict
|
|
|
|
| 185 |
except Exception as e:
|
| 186 |
print(f"Error: {str(e)}")
|
| 187 |
finally:
|
| 188 |
+
self.release_cuda_memory()
|
| 189 |
|
| 190 |
@staticmethod
|
| 191 |
def get_device():
|