import ctranslate2
import sentencepiece as spm
from typing import List
import re
from .common import OfflineTranslator
from ..utils import chunks

class JparacrawlTranslator(OfflineTranslator):
    _LANGUAGE_CODE_MAP = {
        'JPN': 'ja',
        'ENG': 'en',
    }
    _CT2_MODEL_FOLDERS = {
        'ja-en': 'jparacrawl/base-ja-en',
        'en-ja': 'jparacrawl/base-en-ja',
    }
    _MODEL_MAPPING = {
        'models': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/jparacrawl-base-models.zip',
            'hash': 'e98e0fa35a80d2bc48c16673914639db66da1013ec66cc7b79119cdd3b542ebb',
            'archive': {
                'spm.ja.nopretok.model': 'jparacrawl/',
                'spm.en.nopretok.model': 'jparacrawl/',
                'base-ja-en': f'{_CT2_MODEL_FOLDERS["ja-en"]}',
                'base-en-ja': f'{_CT2_MODEL_FOLDERS["en-ja"]}',
            },
        },
    }
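
    # _MODEL_MAPPING is consumed by the OfflineTranslator base class (defined in
    # .common, not shown here); presumably it downloads 'url', verifies the
    # SHA-256 'hash', and extracts each 'archive' member into the mapped folder.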

    # def _on_download_finished(self, map_key):
    #     print('Converting downloaded models to ct2 format')
    #     self._convert_fairseq_models_to_ct2(
    #         self._get_file_path(self._FAIRSEQ_MODEL_FILES['ja-en']),
    #         self._get_file_path('jparacrawl'),
    #         self._get_file_path(self._CT2_MODEL_FOLDERS['ja-en']),
    #         'ja', 'en',
    #     )
    #     self._convert_fairseq_models_to_ct2(
    #         self._get_file_path(self._FAIRSEQ_MODEL_FILES['en-ja']),
    #         self._get_file_path('jparacrawl'),
    #         self._get_file_path(self._CT2_MODEL_FOLDERS['en-ja']),
    #         'en', 'ja',
    #     )
    #     # os.remove(self._get_file_path(self._MODEL_FILES['en-ja']))
    #     # os.remove(self._get_file_path(self._MODEL_FILES['ja-en']))

    # def _convert_fairseq_models_to_ct2(self, model_path: str, data_dir: str, output_dir: str, from_lang: str, to_lang: str):
    #     cmds = [
    #         'ct2-fairseq-converter',
    #         '--model_path', model_path,
    #         '--data_dir', data_dir,
    #         '--output_dir', output_dir,
    #         '--source_lang', from_lang,
    #         '--target_lang', to_lang,
    #     ]
    #     subprocess.check_call(cmds)
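
    # The disabled conversion above appears obsolete: the release zip already
    # ships pre-converted CTranslate2 folders (base-ja-en/base-en-ja), so nothing
    # needs to run at download time. ct2-fairseq-converter is the converter CLI
    # bundled with the ctranslate2 package; re-enabling this code would also
    # require importing subprocess.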

    async def _load(self, from_lang: str, to_lang: str, device: str):
        # Only ja<->en is supported, so 'auto' resolves to whichever language
        # the target is not
        if from_lang == 'auto':
            from_lang = 'ja' if to_lang == 'en' else 'en'
        self.load_params = {
            'from_lang': from_lang,
            'to_lang': to_lang,
            'device': device,
        }
        self.model = ctranslate2.Translator(
            model_path=self._get_file_path(self._CT2_MODEL_FOLDERS[f'{from_lang}-{to_lang}']),
            device=device,
            device_index=0,
        )
        self.model.load_model()
        self.sentence_piece_processors = {
            'en': spm.SentencePieceProcessor(model_file=self._get_file_path('jparacrawl/spm.en.nopretok.model')),
            'ja': spm.SentencePieceProcessor(model_file=self._get_file_path('jparacrawl/spm.ja.nopretok.model')),
        }
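
    # ctranslate2 notes: device may be 'cpu', 'cuda', or 'auto', and device_index
    # selects among multiple GPUs; load_model()/unload_model() move the model in
    # and out of memory without recreating the Translator.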

    async def _unload(self):
        self.model.unload_model()
        del self.model
        del self.sentence_piece_processors

    async def infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        if from_lang == 'auto':
            from_lang = 'ja' if to_lang == 'en' else 'en'
        # Reload when the requested direction differs from the loaded one
        if self.is_loaded() and to_lang != self.load_params['to_lang']:
            await self.reload(self.load_params['device'], from_lang, to_lang)
        return await super().infer(from_lang, to_lang, queries)

    async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        queries_tokenized = self.tokenize(queries, from_lang)
        translated_tokenized = self.model.translate_batch(
            source=queries_tokenized,
            beam_size=5,
            num_hypotheses=1,
            return_alternatives=False,
            disable_unk=True,
            replace_unknowns=True,
            repetition_penalty=3,
        )
        translated = self.detokenize([t[0]['tokens'] for t in translated_tokenized], to_lang)
        return translated
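
    # Note: indexing `t[0]['tokens']` matches the pre-2.0 ctranslate2 API, where
    # each batch entry is a list of hypothesis dicts; on ctranslate2 >= 2.0,
    # translate_batch returns TranslationResult objects and the equivalent
    # access would be `t.hypotheses[0]`.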

    def tokenize(self, queries, lang):
        # Encode to subword pieces (out_type=str yields raw sentencepiece tokens)
        sp = self.sentence_piece_processors[lang]
        if isinstance(queries, list):
            return sp.encode(queries, out_type=str)
        return [sp.encode(queries, out_type=str)]

    def detokenize(self, queries, lang):
        sp = self.sentence_piece_processors[lang]
        return sp.decode(queries)
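
    # SentencePiece round trip, for reference (the exact pieces vary by model):
    #   sp.encode('Hello there.', out_type=str)  ->  ['▁Hello', '▁there', '.']
    #   sp.decode(['▁Hello', '▁there', '.'])     ->  'Hello there.'
    # '▁' marks a word boundary and is removed again by decode().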


class JparacrawlBigTranslator(JparacrawlTranslator):
    _CT2_MODEL_FOLDERS = {
        'ja-en': 'jparacrawl/big-ja-en',
        'en-ja': 'jparacrawl/big-en-ja',
    }
    _MODEL_MAPPING = {
        'models': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/jparacrawl-big-models.zip',
            'hash': '5e0c4cea5a5098152f566de3694602ed3db52927d3df22d2a7bfb8dba2bebe33',
            'archive': {
                'spm.ja.nopretok.model': 'jparacrawl/',
                'spm.en.nopretok.model': 'jparacrawl/',
                'big-ja-en': f'{_CT2_MODEL_FOLDERS["ja-en"]}',
                'big-en-ja': f'{_CT2_MODEL_FOLDERS["en-ja"]}',
            },
        },
    }


class SugoiTranslator(JparacrawlBigTranslator):
    """
    Sugoi V4.0 model for ja->en translation. For en->ja it falls back to the
    JParaCrawl big model.
    """
    _CT2_MODEL_FOLDERS = {
        'ja-en': 'sugoi/big-ja-en',
        'en-ja': 'jparacrawl/big-en-ja',
    }
    _MODEL_MAPPING = {
        **JparacrawlBigTranslator._MODEL_MAPPING,
        'models-sugoi': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/sugoi-models.zip',
            'hash': '67e060a62dc16211157a5eaa4fa8f72c86db5999fc69322606a6fcdf57f587f7',
            'archive': {
                'spm.ja.nopretok.model': 'sugoi/',
                'spm.en.nopretok.model': 'sugoi/',
                'big-ja-en': f'{_CT2_MODEL_FOLDERS["ja-en"]}',
            },
        },
    }

    def __init__(self):
        self.query_split_sizes = []
        super().__init__()

    async def _load(self, from_lang: str, to_lang: str, device: str):
        await super()._load(from_lang, to_lang, device)
        # Sugoi ships its own sentencepiece models; register them under separate keys
        self.sentence_piece_processors['en-sugoi'] = spm.SentencePieceProcessor(model_file=self._get_file_path('sugoi/spm.en.nopretok.model'))
        self.sentence_piece_processors['ja-sugoi'] = spm.SentencePieceProcessor(model_file=self._get_file_path('sugoi/spm.ja.nopretok.model'))

    def tokenize(self, queries, lang):
        if lang == 'ja':
            lang = 'ja-sugoi'
            new_queries = []
            self.query_split_sizes = []
            for q in queries:
                # Feed the model at most two sentences at a time; longer inputs
                # make it abbreviate or drop content. The capture group keeps the
                # sentence-ending punctuation, so re.split yields alternating
                # [sentence, delimiter, ...] elements and a chunk of 4 elements
                # is roughly two sentences.
                sentences = re.split(r'(\w[.‥…!?。・]+)', q)
                chunk_queries = []
                for chunk in chunks(sentences, 4):
                    s = ''.join(chunk)
                    # Mask periods with '@' so the model treats the chunk as one
                    # unit; detokenize() restores them
                    chunk_queries.append(re.sub(r'[.。]', '@', s))
                self.query_split_sizes.append(len(chunk_queries))
                new_queries.extend(chunk_queries)
            queries = new_queries
        return super().tokenize(queries, lang)

    def detokenize(self, queries, lang):
        if lang == 'en':
            lang = 'en-sugoi'
        translations = super().detokenize(queries, lang)
        if lang == 'en-sugoi':
            new_translations = []
            i = 0
            # Stitch the split chunks back into one translation per original query
            for query_count in self.query_split_sizes:
                sentences = ' '.join(translations[i:i+query_count])
                i += query_count
                # Restore the '@'-masked periods and strip leftover sentencepiece artifacts
                sentences = sentences.replace('@', '.').replace('▁', ' ').replace('<unk>', '')
                new_translations.append(sentences)
            translations = new_translations
        return translations
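

# Usage sketch (hypothetical; model download and loading are orchestrated by the
# OfflineTranslator base class, which is not shown in this file):
#
#   import asyncio
#
#   async def main():
#       translator = SugoiTranslator()
#       # 'auto' resolves to 'ja' when targeting English (see infer())
#       print(await translator.infer('auto', 'en', ['今日はいい天気ですね。']))
#
#   asyncio.run(main())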