File size: 11,851 Bytes
9dce458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 |
import re
import time
import asyncio
from typing import List, Tuple
from abc import abstractmethod
from ..utils import InfererModule, ModelWrapper, repeating_sequence, is_valuable_text
try:
import readline
except Exception:
readline = None
# Language codes accepted by the translators in this module, mapped to a
# human-readable name (the name is used in log output).
VALID_LANGUAGES = {
    'CHS': 'Chinese (Simplified)',
    'CHT': 'Chinese (Traditional)',
    'CSY': 'Czech',
    'NLD': 'Dutch',
    'ENG': 'English',
    'FRA': 'French',
    'DEU': 'German',
    'HUN': 'Hungarian',
    'ITA': 'Italian',
    'JPN': 'Japanese',
    'KOR': 'Korean',
    'PLK': 'Polish',
    'PTB': 'Portuguese (Brazil)',
    'ROM': 'Romanian',
    'RUS': 'Russian',
    'ESP': 'Spanish',
    'TRK': 'Turkish',
    'UKR': 'Ukrainian',
    'VIN': 'Vietnamese',
    'ARA': 'Arabic',
    'CNR': 'Montenegrin',
    'SRP': 'Serbian',
    'HRV': 'Croatian',
    'THA': 'Thai',
    'IND': 'Indonesian',
    'FIL': 'Filipino (Tagalog)',
}
# Maps ISO 639-1 two-letter language codes to the keys of VALID_LANGUAGES.
# Each key must appear exactly once: a repeated key in a dict literal is
# silently overwritten by the later entry.
ISO_639_1_TO_VALID_LANGUAGES = {
    'zh': 'CHS',
    'ja': 'JPN',
    'en': 'ENG',
    'ko': 'KOR',
    'vi': 'VIN',
    'cs': 'CSY',
    'nl': 'NLD',
    'fr': 'FRA',
    'de': 'DEU',
    'hu': 'HUN',
    'it': 'ITA',
    'pl': 'PLK',
    'pt': 'PTB',
    'ro': 'ROM',
    'ru': 'RUS',
    'es': 'ESP',
    'tr': 'TRK',
    'uk': 'UKR',
    'ar': 'ARA',
    # Montenegrin has no two-letter ISO 639-1 code; 'cnr' is its three-letter code.
    'cnr': 'CNR',
    'sr': 'SRP',
    'hr': 'HRV',
    'th': 'THA',
    'id': 'IND',
    'tl': 'FIL',
}
class InvalidServerResponse(Exception):
    """Signals an invalid server response; not raised within the visible part of this module."""
class MissingAPIKeyException(Exception):
    """Signals a missing API key; not raised within the visible part of this module."""
class LanguageUnsupportedException(Exception):
    """Error for a language code that a translator does not support."""

    def __init__(self, language_code: str, translator: str = None, supported_languages: List[str] = None):
        """
        Build the error message.

        language_code: the rejected language code.
        translator: name of the translator; a generic placeholder is used when falsy.
        supported_languages: when non-empty, appended to the message as a comma-separated list.
        """
        translator_name = translator or 'chosen translator'
        message = 'Language not supported for %s: "%s"' % (translator_name, language_code)
        if supported_languages:
            listing = ','.join(supported_languages)
            message = message + '. Supported languages: "%s"' % listing
        super().__init__(message)
class MTPEAdapter():
    """Lets a human review and correct machine translations on the terminal."""

    async def dispatch(self, queries: List[str], translations: List[str]) -> List[str]:
        """
        Prompt on stdin for a corrected version of every translation.

        Each prompt is pre-filled with the machine translation through
        readline's startup hook; newlines are shown as a literal backslash-n
        and converted back after editing. When the readline module is not
        available the translations are returned unchanged.

        NOTE(review): input() is a blocking call inside a coroutine.
        """
        # TODO: Make it work in windows (e.g. through os.startfile)
        if not readline:
            print('MTPE is currently only supported on linux')
            return translations

        print('Running Machine Translation Post Editing (MTPE)')
        total = len(queries)
        edited = []
        for number, (source, suggestion) in enumerate(zip(queries, translations), start=1):
            print(f'\n[{number}/{total}] {source}:')

            def prefill(text=suggestion):
                # Runs when the prompt opens and seeds the edit buffer.
                return readline.insert_text(text.replace('\n', '\\n'))

            readline.set_startup_hook(prefill)
            answer = ''
            try:
                answer = input(' -> ').replace('\\n', '\n')
            finally:
                # Always clear the hook so later prompts are not pre-filled.
                readline.set_startup_hook()
            edited.append(answer)
        print()
        return edited
class CommonTranslator(InfererModule):
    """
    Shared translation pipeline for all translators.

    Validates language codes, skips queries without valuable text, spaces out
    requests, retries translations judged invalid and cleans up the output.
    Subclasses implement `_translate` and configure the class attributes below.
    """

    # Translator has to support all languages listed in here. The language codes will be resolved into
    # _LANGUAGE_CODE_MAP[lang_code] automatically if _LANGUAGE_CODE_MAP is a dict.
    # If it is a list it will simply return the language code as is.
    _LANGUAGE_CODE_MAP = {}

    # The amount of repeats upon detecting an invalid translation.
    # Use with _is_translation_invalid and _modify_invalid_translation_query.
    _INVALID_REPEAT_COUNT = 0

    # When > 0, consecutive requests are spaced at least 60 / _MAX_REQUESTS_PER_MINUTE
    # seconds apart (see _ratelimit_sleep). Values <= 0 disable rate limiting.
    _MAX_REQUESTS_PER_MINUTE = -1

    def __init__(self):
        super().__init__()
        self.mtpe_adapter = MTPEAdapter()
        # time.time() timestamp of the last rate-limited request.
        self._last_request_ts = 0

    def supports_languages(self, from_lang: str, to_lang: str, fatal: bool = False) -> bool:
        """
        Check whether this translator supports the given language pair.

        'auto' is always accepted as source language. With fatal=True a
        LanguageUnsupportedException is raised instead of returning False.
        """
        supported_tgt_languages = list(self._LANGUAGE_CODE_MAP)
        supported_src_languages = ['auto'] + supported_tgt_languages

        if from_lang not in supported_src_languages:
            if fatal:
                raise LanguageUnsupportedException(from_lang, self.__class__.__name__, supported_src_languages)
            return False
        if to_lang not in supported_tgt_languages:
            if fatal:
                raise LanguageUnsupportedException(to_lang, self.__class__.__name__, supported_tgt_languages)
            return False
        return True

    def parse_language_codes(self, from_lang: str, to_lang: str, fatal: bool = False) -> Tuple[str, str]:
        """
        Resolve the language codes into the translator specific codes.

        Returns (None, None) when the pair is unsupported and fatal is False;
        raises LanguageUnsupportedException when it is unsupported and fatal is True.
        """
        if not self.supports_languages(from_lang, to_lang, fatal):
            return None, None
        # A list-type map means the translator uses the codes unchanged.
        if isinstance(self._LANGUAGE_CODE_MAP, list):
            return from_lang, to_lang

        _from_lang = 'auto' if from_lang == 'auto' else self._LANGUAGE_CODE_MAP.get(from_lang)
        _to_lang = self._LANGUAGE_CODE_MAP.get(to_lang)
        return _from_lang, _to_lang

    async def translate(self, from_lang: str, to_lang: str, queries: List[str], use_mtpe: bool = False) -> List[str]:
        """
        Translates list of queries of one language into another.

        Queries without valuable text are passed through untranslated. The
        result has one entry per input query, in the same order.

        Raises ValueError for codes missing from VALID_LANGUAGES and
        LanguageUnsupportedException for codes this translator cannot handle.
        """
        if to_lang not in VALID_LANGUAGES:
            raise ValueError('Invalid language code: "%s". Choose from the following: %s' % (to_lang, ', '.join(VALID_LANGUAGES)))
        if from_lang not in VALID_LANGUAGES and from_lang != 'auto':
            raise ValueError('Invalid language code: "%s". Choose from the following: auto, %s' % (from_lang, ', '.join(VALID_LANGUAGES)))
        self.logger.info(f'Translating into {VALID_LANGUAGES[to_lang]}')

        if from_lang == to_lang:
            return queries

        # Dont translate queries without text
        query_indices = []
        final_translations = []
        for i, query in enumerate(queries):
            if not is_valuable_text(query):
                final_translations.append(query)
            else:
                final_translations.append(None)
                query_indices.append(i)

        # Work on a fresh list: entries may be rewritten between retries below.
        queries = [queries[i] for i in query_indices]

        # Resolve once up front; the result does not change between attempts.
        lang_codes = self.parse_language_codes(from_lang, to_lang, fatal=True)

        # Nothing translatable: skip the ratelimit wait and the backend request.
        if not queries:
            return final_translations

        translations = [''] * len(queries)
        untranslated_indices = list(range(len(queries)))
        for attempt in range(1 + self._INVALID_REPEAT_COUNT): # Repeat until all translations are considered valid
            if attempt > 0:
                self.logger.warning(f'Repeating because of invalid translation. Attempt: {attempt+1}')
                await asyncio.sleep(0.1)

            # Sleep if speed is over the ratelimit
            await self._ratelimit_sleep()

            # Translate (the whole batch is sent again on every attempt).
            # Copy so that the list returned by the backend is never mutated.
            _translations = list(await self._translate(*lang_codes, queries))

            # Extend returned translations list to have the same size as queries
            missing = len(queries) - len(_translations)
            if missing > 0:
                _translations.extend([''] * missing)
            elif missing < 0:
                _translations = _translations[:len(queries)]

            # Only overwrite yet untranslated indices
            for j in untranslated_indices:
                translations[j] = _translations[j]

            if self._INVALID_REPEAT_COUNT == 0:
                break

            new_untranslated_indices = []
            for j in untranslated_indices:
                q, t = queries[j], translations[j]
                # Repeat invalid translations with slightly modified queries
                if self._is_translation_invalid(q, t):
                    new_untranslated_indices.append(j)
                    queries[j] = self._modify_invalid_translation_query(q, t)
            untranslated_indices = new_untranslated_indices

            if not untranslated_indices:
                break

        translations = [self._clean_translation_output(q, r, to_lang) for q, r in zip(queries, translations)]

        if to_lang == 'ARA':
            # Imported lazily so the dependency is only needed for Arabic output.
            import arabic_reshaper
            translations = [arabic_reshaper.reshape(t) for t in translations]

        if use_mtpe:
            translations = await self.mtpe_adapter.dispatch(queries, translations)

        # Merge with the queries without text
        for i, trans in enumerate(translations):
            final_translations[query_indices[i]] = trans
            self.logger.info(f'{i}: {queries[i]} => {trans}')

        return final_translations

    @abstractmethod
    async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        """Translate queries; receives the codes produced by parse_language_codes."""
        pass

    async def _ratelimit_sleep(self):
        """Wait until 60 / _MAX_REQUESTS_PER_MINUTE seconds have passed since the last request."""
        if self._MAX_REQUESTS_PER_MINUTE > 0:
            now = time.time()
            ratelimit_timeout = self._last_request_ts + 60 / self._MAX_REQUESTS_PER_MINUTE
            if ratelimit_timeout > now:
                self.logger.info(f'Ratelimit sleep: {(ratelimit_timeout-now):.2f}s')
                await asyncio.sleep(ratelimit_timeout-now)
            self._last_request_ts = time.time()

    def _is_translation_invalid(self, query: str, trans: str) -> bool:
        """
        Heuristic check for failed translations.

        Invalid means: an empty translation for a non-empty query, or a long
        translation built from very few distinct characters while the query
        has more than six distinct characters.
        """
        if not trans and query:
            return True
        if not query or not trans:
            return False

        query_symbols_count = len(set(query))
        trans_symbols_count = len(set(trans))
        if query_symbols_count > 6 and trans_symbols_count < 6 and trans_symbols_count < 0.25 * len(trans):
            return True
        return False

    def _modify_invalid_translation_query(self, query: str, trans: str) -> str:
        """
        Can be overwritten if _INVALID_REPEAT_COUNT was set. It modifies the query
        for the next translation attempt.
        """
        return query

    def _clean_translation_output(self, query: str, trans: str, to_lang: str) -> str:
        """
        Tries to spot and skim down invalid translations.

        Normalizes whitespace and punctuation spacing and collapses
        translations that consist of a short repeating sequence. Returns ''
        when either the query or the translation is empty.
        """
        if not query or not trans:
            return ''

        # '  ' -> ' '
        trans = re.sub(r'\s+', r' ', trans)
        # 'text.text' -> 'text. text'
        trans = re.sub(r'(?<![.,;!?])([.,;!?])(?=\w)', r'\1 ', trans)
        # ' ! ! . . ' -> ' !!.. '
        trans = re.sub(r'([.,;!?])\s+(?=[.,;!?]|$)', r'\1', trans)

        # NOTE(review): only this substitution is assumed to be skipped for Arabic - confirm.
        if to_lang != 'ARA':
            # 'text .' -> 'text.'
            trans = re.sub(r'(?<=[.,;!?\w])\s+([.,;!?])', r'\1', trans)
        # ' ... text' -> ' ...text'
        trans = re.sub(r'((?:\s|^)\.+)\s+(?=\w)', r'\1', trans)

        seq = repeating_sequence(trans.lower())

        # 'aaaaaaaaaaaaa' -> 'aaaaaa'
        # `seq` is checked first so an empty sequence cannot cause a division by zero.
        if seq and len(trans) < len(query) and len(seq) < 0.5 * len(trans):
            # Shrink sequence to length of original query
            trans = seq * max(len(query) // len(seq), 1)
            # Transfer capitalization of query to translation
            # (zip stops at the shorter of the two strings).
            trans = ''.join(
                t.upper() if q.isupper() else t
                for t, q in zip(trans, query)
            )

        return trans
class OfflineTranslator(CommonTranslator, ModelWrapper):
    """Translator whose `_translate` delegates to `self.infer`; subclasses provide `_infer` and `_load`."""

    _MODEL_SUB_DIR = 'translators'

    async def _translate(self, *args, **kwargs):
        """Forward the translation request unchanged to `self.infer`."""
        result = await self.infer(*args, **kwargs)
        return result

    @abstractmethod
    async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        """To be implemented by subclasses."""
        pass

    async def load(self, from_lang: str, to_lang: str, device: str):
        """Resolve the language codes (non-fatal) and hand them to the parent `load`."""
        resolved_from, resolved_to = self.parse_language_codes(from_lang, to_lang)
        return await super().load(device, resolved_from, resolved_to)

    @abstractmethod
    async def _load(self, from_lang: str, to_lang: str, device: str):
        """To be implemented by subclasses."""
        pass

    async def reload(self, from_lang: str, to_lang: str, device: str):
        """Hand the unresolved language codes to the parent `reload`."""
        outcome = await super().reload(device, from_lang, to_lang)
        return outcome
|