Sunday01's picture
up
9dce458
import re
import time
import asyncio
from typing import List, Tuple
from abc import abstractmethod
from ..utils import InfererModule, ModelWrapper, repeating_sequence, is_valuable_text
try:
import readline
except Exception:
readline = None
VALID_LANGUAGES = {
'CHS': 'Chinese (Simplified)',
'CHT': 'Chinese (Traditional)',
'CSY': 'Czech',
'NLD': 'Dutch',
'ENG': 'English',
'FRA': 'French',
'DEU': 'German',
'HUN': 'Hungarian',
'ITA': 'Italian',
'JPN': 'Japanese',
'KOR': 'Korean',
'PLK': 'Polish',
'PTB': 'Portuguese (Brazil)',
'ROM': 'Romanian',
'RUS': 'Russian',
'ESP': 'Spanish',
'TRK': 'Turkish',
'UKR': 'Ukrainian',
'VIN': 'Vietnamese',
'ARA': 'Arabic',
'CNR': 'Montenegrin',
'SRP': 'Serbian',
'HRV': 'Croatian',
'THA': 'Thai',
'IND': 'Indonesian',
'FIL': 'Filipino (Tagalog)'
}
ISO_639_1_TO_VALID_LANGUAGES = {
'zh': 'CHS',
'ja': 'JPN',
'en': 'ENG',
'ko': 'KOR',
'vi': 'VIN',
'cs': 'CSY',
'nl': 'NLD',
'fr': 'FRA',
'de': 'DEU',
'hu': 'HUN',
'it': 'ITA',
'pl': 'PLK',
'pt': 'PTB',
'ro': 'ROM',
'ru': 'RUS',
'es': 'ESP',
'tr': 'TRK',
'uk': 'UKR',
'vi': 'VIN',
'ar': 'ARA',
'cnr': 'CNR',
'sr': 'SRP',
'hr': 'HRV',
'th': 'THA',
'id': 'IND',
'tl': 'FIL'
}
class InvalidServerResponse(Exception):
pass
class MissingAPIKeyException(Exception):
pass
class LanguageUnsupportedException(Exception):
def __init__(self, language_code: str, translator: str = None, supported_languages: List[str] = None):
error = 'Language not supported for %s: "%s"' % (translator if translator else 'chosen translator', language_code)
if supported_languages:
error += '. Supported languages: "%s"' % ','.join(supported_languages)
super().__init__(error)
class MTPEAdapter():
async def dispatch(self, queries: List[str], translations: List[str]) -> List[str]:
# TODO: Make it work in windows (e.g. through os.startfile)
if not readline:
print('MTPE is currently only supported on linux')
return translations
new_translations = []
print('Running Machine Translation Post Editing (MTPE)')
for i, (query, translation) in enumerate(zip(queries, translations)):
print(f'\n[{i + 1}/{len(queries)}] {query}:')
readline.set_startup_hook(lambda: readline.insert_text(translation.replace('\n', '\\n')))
new_translation = ''
try:
new_translation = input(' -> ').replace('\\n', '\n')
finally:
readline.set_startup_hook()
new_translations.append(new_translation)
print()
return new_translations
class CommonTranslator(InfererModule):
# Translator has to support all languages listed in here. The language codes will be resolved into
# _LANGUAGE_CODE_MAP[lang_code] automatically if _LANGUAGE_CODE_MAP is a dict.
# If it is a list it will simply return the language code as is.
_LANGUAGE_CODE_MAP = {}
# The amount of repeats upon detecting an invalid translation.
# Use with _is_translation_invalid and _modify_invalid_translation_query.
_INVALID_REPEAT_COUNT = 0
# Will sleep for the rest of the minute if the request count is over this number.
_MAX_REQUESTS_PER_MINUTE = -1
def __init__(self):
super().__init__()
self.mtpe_adapter = MTPEAdapter()
self._last_request_ts = 0
def supports_languages(self, from_lang: str, to_lang: str, fatal: bool = False) -> bool:
supported_src_languages = ['auto'] + list(self._LANGUAGE_CODE_MAP)
supported_tgt_languages = list(self._LANGUAGE_CODE_MAP)
if from_lang not in supported_src_languages:
if fatal:
raise LanguageUnsupportedException(from_lang, self.__class__.__name__, supported_src_languages)
return False
if to_lang not in supported_tgt_languages:
if fatal:
raise LanguageUnsupportedException(to_lang, self.__class__.__name__, supported_tgt_languages)
return False
return True
def parse_language_codes(self, from_lang: str, to_lang: str, fatal: bool = False) -> Tuple[str, str]:
if not self.supports_languages(from_lang, to_lang, fatal):
return None, None
if type(self._LANGUAGE_CODE_MAP) is list:
return from_lang, to_lang
_from_lang = self._LANGUAGE_CODE_MAP.get(from_lang) if from_lang != 'auto' else 'auto'
_to_lang = self._LANGUAGE_CODE_MAP.get(to_lang)
return _from_lang, _to_lang
async def translate(self, from_lang: str, to_lang: str, queries: List[str], use_mtpe: bool = False) -> List[str]:
"""
Translates list of queries of one language into another.
"""
if to_lang not in VALID_LANGUAGES:
raise ValueError('Invalid language code: "%s". Choose from the following: %s' % (to_lang, ', '.join(VALID_LANGUAGES)))
if from_lang not in VALID_LANGUAGES and from_lang != 'auto':
raise ValueError('Invalid language code: "%s". Choose from the following: auto, %s' % (from_lang, ', '.join(VALID_LANGUAGES)))
self.logger.info(f'Translating into {VALID_LANGUAGES[to_lang]}')
if from_lang == to_lang:
return queries
# Dont translate queries without text
query_indices = []
final_translations = []
for i, query in enumerate(queries):
if not is_valuable_text(query):
final_translations.append(queries[i])
else:
final_translations.append(None)
query_indices.append(i)
queries = [queries[i] for i in query_indices]
translations = [''] * len(queries)
untranslated_indices = list(range(len(queries)))
for i in range(1 + self._INVALID_REPEAT_COUNT): # Repeat until all translations are considered valid
if i > 0:
self.logger.warn(f'Repeating because of invalid translation. Attempt: {i+1}')
await asyncio.sleep(0.1)
# Sleep if speed is over the ratelimit
await self._ratelimit_sleep()
# Translate
_translations = await self._translate(*self.parse_language_codes(from_lang, to_lang, fatal=True), queries)
# Extend returned translations list to have the same size as queries
if len(_translations) < len(queries):
_translations.extend([''] * (len(queries) - len(_translations)))
elif len(_translations) > len(queries):
_translations = _translations[:len(queries)]
# Only overwrite yet untranslated indices
for j in untranslated_indices:
translations[j] = _translations[j]
if self._INVALID_REPEAT_COUNT == 0:
break
new_untranslated_indices = []
for j in untranslated_indices:
q, t = queries[j], translations[j]
# Repeat invalid translations with slightly modified queries
if self._is_translation_invalid(q, t):
new_untranslated_indices.append(j)
queries[j] = self._modify_invalid_translation_query(q, t)
untranslated_indices = new_untranslated_indices
if not untranslated_indices:
break
translations = [self._clean_translation_output(q, r, to_lang) for q, r in zip(queries, translations)]
if to_lang == 'ARA':
import arabic_reshaper
translations = [arabic_reshaper.reshape(t) for t in translations]
if use_mtpe:
translations = await self.mtpe_adapter.dispatch(queries, translations)
# Merge with the queries without text
for i, trans in enumerate(translations):
final_translations[query_indices[i]] = trans
self.logger.info(f'{i}: {queries[i]} => {trans}')
return final_translations
@abstractmethod
async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
pass
async def _ratelimit_sleep(self):
if self._MAX_REQUESTS_PER_MINUTE > 0:
now = time.time()
ratelimit_timeout = self._last_request_ts + 60 / self._MAX_REQUESTS_PER_MINUTE
if ratelimit_timeout > now:
self.logger.info(f'Ratelimit sleep: {(ratelimit_timeout-now):.2f}s')
await asyncio.sleep(ratelimit_timeout-now)
self._last_request_ts = time.time()
def _is_translation_invalid(self, query: str, trans: str) -> bool:
if not trans and query:
return True
if not query or not trans:
return False
query_symbols_count = len(set(query))
trans_symbols_count = len(set(trans))
if query_symbols_count > 6 and trans_symbols_count < 6 and trans_symbols_count < 0.25 * len(trans):
return True
return False
def _modify_invalid_translation_query(self, query: str, trans: str) -> str:
"""
Can be overwritten if _INVALID_REPEAT_COUNT was set. It modifies the query
for the next translation attempt.
"""
return query
def _clean_translation_output(self, query: str, trans: str, to_lang: str) -> str:
"""
Tries to spot and skim down invalid translations.
"""
if not query or not trans:
return ''
# ' ' -> ' '
trans = re.sub(r'\s+', r' ', trans)
# 'text.text' -> 'text. text'
trans = re.sub(r'(?<![.,;!?])([.,;!?])(?=\w)', r'\1 ', trans)
# ' ! ! . . ' -> ' !!.. '
trans = re.sub(r'([.,;!?])\s+(?=[.,;!?]|$)', r'\1', trans)
if to_lang != 'ARA':
# 'text .' -> 'text.'
trans = re.sub(r'(?<=[.,;!?\w])\s+([.,;!?])', r'\1', trans)
# ' ... text' -> ' ...text'
trans = re.sub(r'((?:\s|^)\.+)\s+(?=\w)', r'\1', trans)
seq = repeating_sequence(trans.lower())
# 'aaaaaaaaaaaaa' -> 'aaaaaa'
if len(trans) < len(query) and len(seq) < 0.5 * len(trans):
# Shrink sequence to length of original query
trans = seq * max(len(query) // len(seq), 1)
# Transfer capitalization of query to translation
nTrans = ''
for i in range(min(len(trans), len(query))):
nTrans += trans[i].upper() if query[i].isupper() else trans[i]
trans = nTrans
# words = text.split()
# elements = list(set(words))
# if len(elements) / len(words) < 0.1:
# words = words[:int(len(words) / 1.75)]
# text = ' '.join(words)
# # For words that appear more then four times consecutively, remove the excess
# for el in elements:
# el = re.escape(el)
# text = re.sub(r'(?: ' + el + r'){4} (' + el + r' )+', ' ', text)
return trans
class OfflineTranslator(CommonTranslator, ModelWrapper):
_MODEL_SUB_DIR = 'translators'
async def _translate(self, *args, **kwargs):
return await self.infer(*args, **kwargs)
@abstractmethod
async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
pass
async def load(self, from_lang: str, to_lang: str, device: str):
return await super().load(device, *self.parse_language_codes(from_lang, to_lang))
@abstractmethod
async def _load(self, from_lang: str, to_lang: str, device: str):
pass
async def reload(self, from_lang: str, to_lang: str, device: str):
return await super().reload(device, from_lang, to_lang)