|
import re |
|
import os |
|
from venv import logger |
|
|
|
try: |
|
import openai |
|
except ImportError: |
|
openai = None |
|
import asyncio |
|
from typing import List, Dict, Callable, Tuple |
|
|
|
from .common import CommonTranslator |
|
from .keys import SAKURA_API_BASE, SAKURA_VERSION, SAKURA_DICT_PATH |
|
|
|
import logging |
|
|
|
|
|
class SakuraDict(): |
|
def __init__(self, path: str, logger: logging.Logger, version: str = "0.9") -> None: |
|
self.logger = logger |
|
self.dict_str = "" |
|
self.version = version |
|
if not os.path.exists(path): |
|
if self.version == '0.10': |
|
self.logger.warning(f"字典文件不存在: {path}") |
|
return |
|
else: |
|
self.path = path |
|
if self.version == '0.10': |
|
self.dict_str = self.get_dict_from_file(path) |
|
if self.version == '0.9': |
|
self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表") |
|
|
|
def load_galtransl_dic(self, dic_path: str): |
|
""" |
|
载入Galtransl词典。 |
|
""" |
|
|
|
with open(dic_path, encoding="utf8") as f: |
|
dic_lines = f.readlines() |
|
if len(dic_lines) == 0: |
|
return |
|
dic_path = os.path.abspath(dic_path) |
|
dic_name = os.path.basename(dic_path) |
|
normalDic_count = 0 |
|
|
|
gpt_dict = [] |
|
for line in dic_lines: |
|
if line.startswith("\n"): |
|
continue |
|
elif line.startswith("\\\\") or line.startswith("//"): |
|
continue |
|
|
|
|
|
line = line.replace(" ", "\t") |
|
|
|
sp = line.rstrip("\r\n").split("\t") |
|
len_sp = len(sp) |
|
|
|
if len_sp < 2: |
|
continue |
|
|
|
src = sp[0] |
|
dst = sp[1] |
|
info = sp[2] if len_sp > 2 else None |
|
gpt_dict.append({"src": src, "dst": dst, "info": info}) |
|
normalDic_count += 1 |
|
|
|
gpt_dict_text_list = [] |
|
for gpt in gpt_dict: |
|
src = gpt['src'] |
|
dst = gpt['dst'] |
|
info = gpt['info'] if "info" in gpt.keys() else None |
|
if info: |
|
single = f"{src}->{dst} #{info}" |
|
else: |
|
single = f"{src}->{dst}" |
|
gpt_dict_text_list.append(single) |
|
|
|
gpt_dict_raw_text = "\n".join(gpt_dict_text_list) |
|
self.dict_str = gpt_dict_raw_text |
|
self.logger.info( |
|
f"载入 Galtransl 字典: {dic_name} {normalDic_count}普通词条" |
|
) |
|
|
|
def load_sakura_dict(self, dic_path: str): |
|
""" |
|
直接载入标准的Sakura字典。 |
|
""" |
|
|
|
with open(dic_path, encoding="utf8") as f: |
|
dic_lines = f.readlines() |
|
|
|
if len(dic_lines) == 0: |
|
return |
|
dic_path = os.path.abspath(dic_path) |
|
dic_name = os.path.basename(dic_path) |
|
normalDic_count = 0 |
|
|
|
gpt_dict_text_list = [] |
|
for line in dic_lines: |
|
if line.startswith("\n"): |
|
continue |
|
elif line.startswith("\\\\") or line.startswith("//"): |
|
continue |
|
|
|
sp = line.rstrip("\r\n").split("->") |
|
len_sp = len(sp) |
|
|
|
if len_sp < 2: |
|
continue |
|
|
|
src = sp[0] |
|
dst_info = sp[1].split("#") |
|
dst = dst_info[0].strip() |
|
info = dst_info[1].strip() if len(dst_info) > 1 else None |
|
if info: |
|
single = f"{src}->{dst} #{info}" |
|
else: |
|
single = f"{src}->{dst}" |
|
gpt_dict_text_list.append(single) |
|
normalDic_count += 1 |
|
|
|
gpt_dict_raw_text = "\n".join(gpt_dict_text_list) |
|
self.dict_str = gpt_dict_raw_text |
|
self.logger.info( |
|
f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条" |
|
) |
|
|
|
def detect_type(self, dic_path: str): |
|
""" |
|
检测字典类型。 |
|
""" |
|
with open(dic_path, encoding="utf8") as f: |
|
dic_lines = f.readlines() |
|
self.logger.debug(f"检测字典类型: {dic_path}") |
|
if len(dic_lines) == 0: |
|
return "unknown" |
|
|
|
|
|
is_galtransl = True |
|
for line in dic_lines: |
|
if line.startswith("\n"): |
|
continue |
|
elif line.startswith("\\\\") or line.startswith("//"): |
|
continue |
|
|
|
if "\t" not in line and " " not in line: |
|
is_galtransl = False |
|
break |
|
|
|
if is_galtransl: |
|
return "galtransl" |
|
|
|
|
|
is_sakura = True |
|
for line in dic_lines: |
|
if line.startswith("\n"): |
|
continue |
|
elif line.startswith("\\\\") or line.startswith("//"): |
|
continue |
|
|
|
if "->" not in line: |
|
is_sakura = False |
|
break |
|
|
|
if is_sakura: |
|
return "sakura" |
|
|
|
return "unknown" |
|
|
|
def get_dict_str(self): |
|
""" |
|
获取字典内容。 |
|
""" |
|
if self.version == '0.9': |
|
self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表") |
|
return "" |
|
if self.dict_str == "": |
|
try: |
|
self.dict_str = self.get_dict_from_file(self.path) |
|
return self.dict_str |
|
except Exception as e: |
|
if self.version == '0.10': |
|
self.logger.warning(f"载入字典失败: {e}") |
|
return "" |
|
return self.dict_str |
|
|
|
def get_dict_from_file(self, dic_path: str): |
|
""" |
|
从文件载入字典。 |
|
""" |
|
dic_type = self.detect_type(dic_path) |
|
if dic_type == "galtransl": |
|
self.load_galtransl_dic(dic_path) |
|
elif dic_type == "sakura": |
|
self.load_sakura_dict(dic_path) |
|
else: |
|
self.logger.warning(f"未知的字典类型: {dic_path}") |
|
return self.get_dict_str() |
|
|
|
|
|
class SakuraTranslator(CommonTranslator): |
|
|
|
_TIMEOUT = 999 |
|
_RETRY_ATTEMPTS = 3 |
|
_TIMEOUT_RETRY_ATTEMPTS = 3 |
|
_RATELIMIT_RETRY_ATTEMPTS = 3 |
|
_REPEAT_DETECT_THRESHOLD = 20 |
|
|
|
_CHAT_SYSTEM_TEMPLATE_009 = ( |
|
'你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。' |
|
) |
|
_CHAT_SYSTEM_TEMPLATE_010 = ( |
|
'你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。' |
|
) |
|
|
|
_LANGUAGE_CODE_MAP = { |
|
'CHS': 'Simplified Chinese', |
|
'JPN': 'Japanese' |
|
} |
|
|
|
def __init__(self): |
|
super().__init__() |
|
self.client = openai.AsyncOpenAI() |
|
if "/v1" not in SAKURA_API_BASE: |
|
self.client.base_url = SAKURA_API_BASE + "/v1" |
|
else: |
|
self.client.base_url = SAKURA_API_BASE |
|
self.client.api_key = "sk-114514" |
|
self.temperature = 0.3 |
|
self.top_p = 0.3 |
|
self.frequency_penalty = 0.1 |
|
self._current_style = "precise" |
|
self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]') |
|
self._heart_pattern = re.compile(r'❤') |
|
self.sakura_dict = SakuraDict(self.get_dict_path(), self.logger, SAKURA_VERSION) |
|
|
|
def get_sakura_version(self): |
|
return SAKURA_VERSION |
|
|
|
def get_dict_path(self): |
|
return SAKURA_DICT_PATH |
|
|
|
def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_THRESHOLD, remove_all=True) -> Tuple[bool, str, int, str]: |
|
""" |
|
检测文本中是否存在重复模式,并计算重复次数。 |
|
返回值: (是否重复, 去除重复后的文本, 重复次数, 重复模式) |
|
""" |
|
repeated = False |
|
counts = [] |
|
for pattern_length in range(1, len(s) // 2 + 1): |
|
i = 0 |
|
while i < len(s) - pattern_length: |
|
pattern = s[i:i + pattern_length] |
|
count = 1 |
|
j = i + pattern_length |
|
while j <= len(s) - pattern_length: |
|
if s[j:j + pattern_length] == pattern: |
|
count += 1 |
|
j += pattern_length |
|
else: |
|
break |
|
counts.append(count) |
|
if count >= threshold: |
|
self.logger.warning(f"检测到重复模式: {pattern},重复次数: {count}") |
|
repeated = True |
|
if remove_all: |
|
s = s[:i + pattern_length] + s[j:] |
|
break |
|
i += 1 |
|
if repeated: |
|
break |
|
|
|
|
|
if counts: |
|
mode_count = max(set(counts), key=counts.count) |
|
else: |
|
mode_count = 0 |
|
|
|
|
|
actual_threshold = max(threshold, mode_count) |
|
|
|
return repeated, s, count, pattern, actual_threshold |
|
|
|
@staticmethod |
|
def enlarge_small_kana(text, ignore=''): |
|
"""将小写平假名或片假名转换为普通大小 |
|
|
|
参数 |
|
---------- |
|
text : str |
|
全角平假名或片假名字符串。 |
|
ignore : str, 可选 |
|
转换时要忽略的字符。 |
|
|
|
返回 |
|
------ |
|
str |
|
平假名或片假名字符串,小写假名已转换为大写 |
|
|
|
示例 |
|
-------- |
|
>>> print(enlarge_small_kana('さくらきょうこ')) |
|
さくらきようこ |
|
>>> print(enlarge_small_kana('キュゥべえ')) |
|
キユウべえ |
|
""" |
|
SMALL_KANA = list('ぁぃぅぇぉゃゅょっァィゥェォヵヶャュョッ') |
|
SMALL_KANA_NORMALIZED = list('あいうえおやゆよつアイウエオカケヤユヨツ') |
|
SMALL_KANA2BIG_KANA = dict(zip(map(ord, SMALL_KANA), SMALL_KANA_NORMALIZED)) |
|
|
|
def _exclude_ignorechar(ignore, conv_map): |
|
for character in map(ord, ignore): |
|
del conv_map[character] |
|
return conv_map |
|
|
|
def _convert(text, conv_map): |
|
return text.translate(conv_map) |
|
|
|
def _translate(text, ignore, conv_map): |
|
if ignore: |
|
_conv_map = _exclude_ignorechar(ignore, conv_map.copy()) |
|
return _convert(text, _conv_map) |
|
return _convert(text, conv_map) |
|
|
|
return _translate(text, ignore, SMALL_KANA2BIG_KANA) |
|
|
|
def _format_prompt_log(self, prompt: str) -> str: |
|
""" |
|
格式化日志输出的提示文本。 |
|
""" |
|
gpt_dict_raw_text = self.sakura_dict.get_dict_str() |
|
prompt_009 = '\n'.join([ |
|
'System:', |
|
self._CHAT_SYSTEM_TEMPLATE_009, |
|
'User:', |
|
'将下面的日文文本翻译成中文:', |
|
prompt, |
|
]) |
|
prompt_010 = '\n'.join([ |
|
'System:', |
|
self._CHAT_SYSTEM_TEMPLATE_010, |
|
'User:', |
|
"根据以下术语表:", |
|
gpt_dict_raw_text, |
|
"将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:", |
|
prompt, |
|
]) |
|
return prompt_009 if SAKURA_VERSION == '0.9' else prompt_010 |
|
|
|
def _split_text(self, text: str) -> List[str]: |
|
""" |
|
将字符串按换行符分割为列表。 |
|
""" |
|
if isinstance(text, list): |
|
return text |
|
return text.split('\n') |
|
|
|
def _preprocess_queries(self, queries: List[str]) -> List[str]: |
|
""" |
|
预处理查询文本,去除emoji,替换特殊字符,并添加「」标记。 |
|
""" |
|
queries = [self.enlarge_small_kana(query) for query in queries] |
|
queries = [self._emoji_pattern.sub('', query) for query in queries] |
|
queries = [self._heart_pattern.sub('♥', query) for query in queries] |
|
queries = [f'「{query}」' for query in queries] |
|
self.logger.debug(f'预处理后的查询文本:{queries}') |
|
return queries |
|
|
|
async def _check_translation_quality(self, queries: List[str], response: str) -> List[str]: |
|
""" |
|
检查翻译结果的质量,包括重复和行数对齐问题,如果存在问题则尝试重新翻译或返回原始文本。 |
|
""" |
|
async def _retry_translation(queries: List[str], check_func: Callable[[str], bool], error_message: str) -> str: |
|
styles = ["precise", "normal", "aggressive", ] |
|
for i in range(self._RETRY_ATTEMPTS): |
|
self._set_gpt_style(styles[i]) |
|
self.logger.warning(f'{error_message} 尝试次数: {i + 1}。当前参数风格:{self._current_style}。') |
|
response = await self._handle_translation_request(queries) |
|
if not check_func(response): |
|
return response |
|
return None |
|
|
|
|
|
if self._detect_repeats(''.join(queries), self._REPEAT_DETECT_THRESHOLD): |
|
self.logger.warning(f'请求内容本身含有超过默认阈值{self._REPEAT_DETECT_THRESHOLD}的重复内容。') |
|
|
|
|
|
actual_threshold = max(max(self._get_repeat_count(query) for query in queries), self._REPEAT_DETECT_THRESHOLD) |
|
|
|
if self._detect_repeats(response, actual_threshold): |
|
response = await _retry_translation(queries, lambda r: self._detect_repeats(r, actual_threshold), f'检测到大量重复内容(当前阈值:{actual_threshold}),疑似模型退化,重新翻译。') |
|
if response is None: |
|
self.logger.warning(f'疑似模型退化,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。') |
|
return await self._translate_single_lines(queries) |
|
|
|
if not self._check_align(queries, response): |
|
response = await _retry_translation(queries, lambda r: not self._check_align(queries, r), '因为检测到原文与译文行数不匹配,重新翻译。') |
|
if response is None: |
|
self.logger.warning(f'原文与译文行数不匹配,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。') |
|
return await self._translate_single_lines(queries) |
|
|
|
return self._split_text(response) |
|
|
|
def _detect_repeats(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> bool: |
|
""" |
|
检测文本中是否存在重复模式。 |
|
""" |
|
is_repeated, text, count, pattern, actual_threshold = self.detect_and_caculate_repeats(text, threshold, remove_all=False) |
|
return is_repeated |
|
|
|
def _get_repeat_count(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> bool: |
|
""" |
|
计算文本中重复模式的次数。 |
|
""" |
|
is_repeated, text, count, pattern, actual_threshold = self.detect_and_caculate_repeats(text, threshold, remove_all=False) |
|
return count |
|
|
|
def _check_align(self, queries: List[str], response: str) -> bool: |
|
""" |
|
检查原始文本和翻译结果的行数是否对齐。 |
|
""" |
|
translations = self._split_text(response) |
|
is_aligned = len(queries) == len(translations) |
|
if not is_aligned: |
|
self.logger.warning(f"行数不匹配 - 原文行数: {len(queries)},译文行数: {len(translations)}") |
|
return is_aligned |
|
|
|
async def _translate_single_lines(self, queries: List[str]) -> List[str]: |
|
""" |
|
逐行翻译查询文本。 |
|
""" |
|
translations = [] |
|
for query in queries: |
|
response = await self._handle_translation_request(query) |
|
if self._detect_repeats(response): |
|
self.logger.warning(f"单行翻译结果存在重复内容: {response},返回原文。") |
|
translations.append(query) |
|
else: |
|
translations.append(response) |
|
return translations |
|
|
|
def _delete_quotation_mark(self, texts: List[str]) -> List[str]: |
|
""" |
|
删除文本中的「」标记。 |
|
""" |
|
new_texts = [] |
|
for text in texts: |
|
text = text.strip('「」') |
|
new_texts.append(text) |
|
return new_texts |
|
|
|
async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]: |
|
self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}') |
|
self.logger.debug(f'原文: {queries}') |
|
text_prompt = '\n'.join(queries) |
|
self.logger.debug('-- Sakura Prompt --\n' + self._format_prompt_log(text_prompt) + '\n\n') |
|
|
|
|
|
queries = self._preprocess_queries(queries) |
|
|
|
|
|
response = await self._handle_translation_request(queries) |
|
self.logger.debug('-- Sakura Response --\n' + response + '\n\n') |
|
|
|
|
|
translations = await self._check_translation_quality(queries, response) |
|
|
|
return self._delete_quotation_mark(translations) |
|
|
|
async def _handle_translation_request(self, prompt: str) -> str: |
|
""" |
|
处理翻译请求,包括错误处理和重试逻辑。 |
|
""" |
|
ratelimit_attempt = 0 |
|
server_error_attempt = 0 |
|
timeout_attempt = 0 |
|
while True: |
|
try: |
|
request_task = asyncio.create_task(self._request_translation(prompt)) |
|
response = await asyncio.wait_for(request_task, timeout=self._TIMEOUT) |
|
break |
|
except asyncio.TimeoutError: |
|
timeout_attempt += 1 |
|
if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS: |
|
raise Exception('Sakura超时。') |
|
self.logger.warning(f'Sakura因超时而进行重试。尝试次数: {timeout_attempt}') |
|
except openai.RateLimitError: |
|
ratelimit_attempt += 1 |
|
if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS: |
|
raise |
|
self.logger.warning(f'Sakura因被限速而进行重试。尝试次数: {ratelimit_attempt}') |
|
await asyncio.sleep(2) |
|
except (openai.APIError, openai.APIConnectionError) as e: |
|
server_error_attempt += 1 |
|
if server_error_attempt >= self._RETRY_ATTEMPTS: |
|
self.logger.error(f'Sakura API请求失败。错误信息: {e}') |
|
return prompt |
|
self.logger.warning(f'Sakura因服务器错误而进行重试。尝试次数: {server_error_attempt},错误信息: {e}') |
|
|
|
return response |
|
|
|
async def _request_translation(self, input_text_list) -> str: |
|
""" |
|
向Sakura API发送翻译请求。 |
|
""" |
|
if isinstance(input_text_list, list): |
|
raw_text = "\n".join(input_text_list) |
|
else: |
|
raw_text = input_text_list |
|
raw_lenth = len(raw_text) |
|
max_lenth = 512 |
|
max_token_num = max(raw_lenth*2, max_lenth) |
|
extra_query = { |
|
'do_sample': False, |
|
'num_beams': 1, |
|
'repetition_penalty': 1.0, |
|
} |
|
if SAKURA_VERSION == "0.9": |
|
messages = [ |
|
{ |
|
"role": "system", |
|
"content": f"{self._CHAT_SYSTEM_TEMPLATE_009}" |
|
}, |
|
{ |
|
"role": "user", |
|
"content": f"将下面的日文文本翻译成中文:{raw_text}" |
|
} |
|
] |
|
else: |
|
gpt_dict_raw_text = self.sakura_dict.get_dict_str() |
|
self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}") |
|
messages = [ |
|
{ |
|
"role": "system", |
|
"content": f"{self._CHAT_SYSTEM_TEMPLATE_010}" |
|
}, |
|
{ |
|
"role": "user", |
|
"content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}" |
|
} |
|
] |
|
response = await self.client.chat.completions.create( |
|
model="sukinishiro", |
|
messages=messages, |
|
temperature=self.temperature, |
|
top_p=self.top_p, |
|
max_tokens=max_token_num, |
|
frequency_penalty=self.frequency_penalty, |
|
seed=-1, |
|
extra_query=extra_query, |
|
) |
|
|
|
for choice in response.choices: |
|
if 'text' in choice: |
|
return choice.text |
|
|
|
return response.choices[0].message.content |
|
|
|
def _set_gpt_style(self, style_name: str): |
|
""" |
|
设置GPT的生成风格。 |
|
""" |
|
if self._current_style == style_name: |
|
return |
|
self._current_style = style_name |
|
if style_name == "precise": |
|
temperature, top_p = 0.1, 0.3 |
|
frequency_penalty = 0.05 |
|
elif style_name == "normal": |
|
temperature, top_p = 0.3, 0.3 |
|
frequency_penalty = 0.2 |
|
elif style_name == "aggressive": |
|
temperature, top_p = 0.3, 0.3 |
|
frequency_penalty = 0.3 |
|
|
|
self.temperature = temperature |
|
self.top_p = top_p |
|
self.frequency_penalty = frequency_penalty |
|
|