import asyncio
import logging
import os
import re
from collections import Counter
from typing import Callable, Dict, List, Optional, Tuple, Union
from venv import logger

try:
    import openai
except ImportError:
    openai = None

from .common import CommonTranslator
from .keys import SAKURA_API_BASE, SAKURA_VERSION, SAKURA_DICT_PATH


class SakuraDict():
    """Glossary (term dictionary) that is injected into Sakura 0.10 prompts.

    Two on-disk formats are understood:

    * Galtransl: ``src<TAB>dst[<TAB>info]`` (four spaces count as a Tab).
    * Sakura:    ``src->dst [#info]``.

    Both are normalised to the Sakura text form and stored in ``dict_str``.
    Sakura 0.9 does not support glossaries, so for that version the glossary
    is always reported as an empty string.
    """

    # Lines starting with one of these prefixes are comments.
    _COMMENT_PREFIXES = ("\\\\", "//")

    def __init__(self, path: str, logger: logging.Logger, version: str = "0.9") -> None:
        """Remember the glossary location and load it eagerly for version 0.10.

        Args:
            path: Path of the glossary file; it may not exist.
            logger: Logger used for all diagnostics of this object.
            version: Sakura model version string ('0.9' or '0.10').
        """
        self.logger = logger
        self.dict_str = ""
        self.version = version
        # Always remember the path so that a later get_dict_str() can retry
        # loading instead of failing on a missing attribute.
        self.path = path
        if not path or not os.path.exists(path):
            if self.version == '0.10':
                self.logger.warning(f"字典文件不存在: {path}")
            return
        if self.version == '0.10':
            self.dict_str = self.get_dict_from_file(path)
        if self.version == '0.9':
            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")

    @staticmethod
    def _read_lines(dic_path: str) -> List[str]:
        """Return all lines of the UTF-8 glossary file, line endings included."""
        with open(dic_path, encoding="utf8") as f:
            return f.readlines()

    @classmethod
    def _is_skippable(cls, line: str) -> bool:
        """Return True for blank (whitespace-only) lines and comment lines."""
        return not line.strip() or line.startswith(cls._COMMENT_PREFIXES)

    @staticmethod
    def _format_entry(src: str, dst: str, info: Optional[str]) -> str:
        """Render one glossary entry in the Sakura text form."""
        if info:
            return f"{src}->{dst} #{info}"
        return f"{src}->{dst}"

    def load_galtransl_dic(self, dic_path: str):
        """Load a Galtransl dictionary into ``dict_str``.

        An empty file leaves ``dict_str`` untouched. Lines with fewer than
        two Tab-separated fields are ignored.
        """
        dic_lines = self._read_lines(dic_path)
        if not dic_lines:
            return
        dic_name = os.path.basename(os.path.abspath(dic_path))

        entries = []
        for line in dic_lines:
            if self._is_skippable(line):
                continue
            # Four spaces are treated as a Tab; then strip the line ending
            # and split into fields.
            fields = line.replace("    ", "\t").rstrip("\r\n").split("\t")
            if len(fields) < 2:  # need at least source and target
                continue
            info = fields[2] if len(fields) > 2 else None
            entries.append(self._format_entry(fields[0], fields[1], info))

        self.dict_str = "\n".join(entries)
        self.logger.info(
            f"载入 Galtransl 字典: {dic_name} {len(entries)}普通词条"
        )

    def load_sakura_dict(self, dic_path: str):
        """Load a standard Sakura dictionary into ``dict_str``.

        An empty file leaves ``dict_str`` untouched. Lines without ``->``
        are ignored.
        """
        dic_lines = self._read_lines(dic_path)
        if not dic_lines:
            return
        dic_name = os.path.basename(os.path.abspath(dic_path))

        entries = []
        for line in dic_lines:
            if self._is_skippable(line):
                continue
            parts = line.rstrip("\r\n").split("->")
            if len(parts) < 2:  # need at least source and target
                continue
            # '#' separates the target text from the optional note.
            dst_info = parts[1].split("#")
            dst = dst_info[0].strip()
            info = dst_info[1].strip() if len(dst_info) > 1 else None
            entries.append(self._format_entry(parts[0], dst, info))

        self.dict_str = "\n".join(entries)
        self.logger.info(
            f"载入标准Sakura字典: {dic_name} {len(entries)}普通词条"
        )

    def detect_type(self, dic_path: str):
        """Guess the dictionary format of ``dic_path``.

        Returns:
            "galtransl" when every content line contains a Tab or four
            spaces, otherwise "sakura" when every content line contains
            ``->``, otherwise "unknown" (also for an empty file).
        """
        dic_lines = self._read_lines(dic_path)
        self.logger.debug(f"检测字典类型: {dic_path}")
        if not dic_lines:
            return "unknown"

        content_lines = [line for line in dic_lines if not self._is_skippable(line)]
        if all("\t" in line or "    " in line for line in content_lines):
            return "galtransl"
        if all("->" in line for line in content_lines):
            return "sakura"
        return "unknown"

    def get_dict_str(self):
        """Return the glossary text, loading it lazily if it is still empty.

        Returns an empty string for Sakura 0.9 and whenever loading fails.
        """
        if self.version == '0.9':
            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")
            return ""
        if self.dict_str == "":
            # Best effort: a missing or unreadable glossary must not abort
            # the translation, so every failure degrades to "no glossary".
            try:
                self.dict_str = self.get_dict_from_file(self.path)
                return self.dict_str
            except Exception as e:
                if self.version == '0.10':
                    self.logger.warning(f"载入字典失败: {e}")
                return ""
        return self.dict_str

    def get_dict_from_file(self, dic_path: str):
        """Detect the format of ``dic_path``, load it and return the glossary text.

        The result is returned directly rather than through get_dict_str():
        get_dict_str() calls this method when the glossary is empty, so
        delegating back to it would recurse without bound for files that
        yield no entries.
        """
        dic_type = self.detect_type(dic_path)
        if dic_type == "galtransl":
            self.load_galtransl_dic(dic_path)
        elif dic_type == "sakura":
            self.load_sakura_dict(dic_path)
        else:
            self.logger.warning(f"未知的字典类型: {dic_path}")
        if self.version == '0.9':
            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")
            return ""
        return self.dict_str


class SakuraTranslator(CommonTranslator):
    """Japanese to Simplified Chinese translator backed by a Sakura model
    served through an OpenAI-compatible chat completion API."""

    _TIMEOUT = 999  # Seconds to wait for the server to respond.
    _RETRY_ATTEMPTS = 3  # Retries after a server/API error.
    _TIMEOUT_RETRY_ATTEMPTS = 3  # Retries after a timeout.
    _RATELIMIT_RETRY_ATTEMPTS = 3  # Retries after being rate limited.
    _REPEAT_DETECT_THRESHOLD = 20  # Consecutive repeats that count as degeneration.

    _CHAT_SYSTEM_TEMPLATE_009 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。'
    )
    _CHAT_SYSTEM_TEMPLATE_010 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。'
    )

    _LANGUAGE_CODE_MAP = {
        'CHS': 'Simplified Chinese',
        'JPN': 'Japanese'
    }

    # style name -> (temperature, top_p, frequency_penalty)
    _GPT_STYLES = {
        "precise": (0.1, 0.3, 0.05),
        "normal": (0.3, 0.3, 0.2),
        "aggressive": (0.3, 0.3, 0.3),
    }
    # Styles tried in order by successive retry attempts.
    _RETRY_STYLES = ("precise", "normal", "aggressive")

    def __init__(self):
        """Create the API client, sampling defaults and the glossary."""
        super().__init__()
        if "/v1" not in SAKURA_API_BASE:
            base_url = SAKURA_API_BASE + "/v1"
        else:
            base_url = SAKURA_API_BASE
        # The key is a placeholder. It is passed to the constructor because
        # the client refuses to be built when no key is given and the
        # OPENAI_API_KEY environment variable is unset.
        self.client = openai.AsyncOpenAI(api_key="sk-114514", base_url=base_url)
        # NOTE(review): these initial values differ from the "precise" preset
        # in _GPT_STYLES, and because _current_style already says "precise",
        # the first retry attempt does not change them - confirm intended.
        self.temperature = 0.3
        self.top_p = 0.3
        self.frequency_penalty = 0.1
        self._current_style = "precise"
        self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]')
        self._heart_pattern = re.compile(r'❤')
        self.sakura_dict = SakuraDict(self.get_dict_path(), self.logger, SAKURA_VERSION)

    def get_sakura_version(self):
        """Return the configured Sakura model version."""
        return SAKURA_VERSION

    def get_dict_path(self):
        """Return the configured glossary file path."""
        return SAKURA_DICT_PATH

    def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_THRESHOLD, remove_all=True) -> Tuple[bool, str, int, str, int]:
        """Look for a substring that is repeated back-to-back in ``s``.

        Pattern lengths from 1 up to ``len(s) // 2`` are tried at every start
        offset; the scan stops at the first pattern repeated at least
        ``threshold`` times.

        Args:
            s: Text to inspect.
            threshold: Minimum consecutive repetitions that count as repeated.
            remove_all: When a repeat is found, collapse it to one occurrence
                in the returned text.

        Returns:
            ``(repeated, text, count, pattern, actual_threshold)`` where
            ``count``/``pattern`` describe the last run examined (the
            offending run when ``repeated`` is True) and ``actual_threshold``
            is the larger of ``threshold`` and the most frequent run length.
            Texts shorter than two characters yield ``(False, s, 0, "", threshold)``.
        """
        repeated = False
        counts = []
        # Defaults for inputs too short to enter the scan loops.
        count = 0
        pattern = ""

        for pattern_length in range(1, len(s) // 2 + 1):
            i = 0
            while i < len(s) - pattern_length:
                pattern = s[i:i + pattern_length]
                count = 1
                j = i + pattern_length
                # Count how many times the pattern immediately follows itself.
                while j <= len(s) - pattern_length:
                    if s[j:j + pattern_length] == pattern:
                        count += 1
                        j += pattern_length
                    else:
                        break
                counts.append(count)
                if count >= threshold:
                    self.logger.warning(f"检测到重复模式: {pattern},重复次数: {count}")
                    repeated = True
                    if remove_all:
                        s = s[:i + pattern_length] + s[j:]
                    break
                i += 1
            if repeated:
                break

        # Mode of the run lengths; frequencies are counted once up front
        # instead of calling list.count() for every distinct value.
        if counts:
            frequency = Counter(counts)
            mode_count = max(set(counts), key=frequency.__getitem__)
        else:
            mode_count = 0

        # The effective threshold is the larger of the default and the mode.
        actual_threshold = max(threshold, mode_count)

        return repeated, s, count, pattern, actual_threshold

    @staticmethod
    def enlarge_small_kana(text, ignore=''):
        """Convert small hiragana/katakana to their normal-sized forms.

        Args:
            text: Full-width hiragana or katakana string.
            ignore: Characters that must be left unconverted. Characters
                that are not small kana are accepted and have no effect.

        Returns:
            ``text`` with every small kana (except ignored ones) enlarged.

        Examples:
            >>> print(SakuraTranslator.enlarge_small_kana('さくらきょうこ'))
            さくらきようこ
            >>> print(SakuraTranslator.enlarge_small_kana('キュゥべえ'))
            キユウべえ
        """
        small_kana = 'ぁぃぅぇぉゃゅょっァィゥェォヵヶャュョッ'
        normal_kana = 'あいうえおやゆよつアイウエオカケヤユヨツ'
        conv_map = str.maketrans(small_kana, normal_kana)
        for code_point in map(ord, ignore):
            conv_map.pop(code_point, None)
        return text.translate(conv_map)

    def _format_prompt_log(self, prompt: str) -> str:
        """Build the human-readable prompt text used for debug logging."""
        if SAKURA_VERSION == '0.9':
            return '\n'.join([
                'System:',
                self._CHAT_SYSTEM_TEMPLATE_009,
                'User:',
                '将下面的日文文本翻译成中文:',
                prompt,
            ])
        return '\n'.join([
            'System:',
            self._CHAT_SYSTEM_TEMPLATE_010,
            'User:',
            "根据以下术语表:",
            self.sakura_dict.get_dict_str(),
            "将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:",
            prompt,
        ])

    def _split_text(self, text: str) -> List[str]:
        """Split a string on newlines; a list is returned unchanged."""
        if isinstance(text, list):
            return text
        return text.split('\n')

    def _preprocess_queries(self, queries: List[str]) -> List[str]:
        """Normalise queries before sending them to the model.

        Small kana are enlarged, characters outside the BMP (emoji) are
        removed, the heavy heart is replaced by '♥' and each query is wrapped
        in corner brackets.
        """
        processed = []
        for query in queries:
            query = self.enlarge_small_kana(query)
            query = self._emoji_pattern.sub('', query)
            query = self._heart_pattern.sub('♥', query)
            processed.append(f'「{query}」')
        self.logger.debug(f'预处理后的查询文本:{processed}')
        return processed

    async def _check_translation_quality(self, queries: List[str], response: str) -> List[str]:
        """Validate a translation and retry when it looks degenerate.

        Two problems are checked in order: excessive repetition and a line
        count that differs from the source. Each triggers up to
        ``_RETRY_ATTEMPTS`` retries with different sampling styles; if the
        problem persists the queries are translated one line at a time.

        Returns:
            The translated lines.
        """
        async def _retry_translation(queries: List[str], check_func: Callable[[str], bool], error_message: str) -> Optional[str]:
            """Retry until ``check_func`` reports no problem; None when all attempts fail."""
            last_style = len(self._RETRY_STYLES) - 1
            for attempt in range(self._RETRY_ATTEMPTS):
                # Stay on the last style if more attempts than styles are configured.
                self._set_gpt_style(self._RETRY_STYLES[min(attempt, last_style)])
                self.logger.warning(f'{error_message} 尝试次数: {attempt + 1}。当前参数风格:{self._current_style}。')
                response = await self._handle_translation_request(queries)
                if not check_func(response):
                    return response
            return None

        # Warn when the request itself already repeats beyond the default threshold.
        if self._detect_repeats(''.join(queries), self._REPEAT_DETECT_THRESHOLD):
            self.logger.warning(f'请求内容本身含有超过默认阈值{self._REPEAT_DETECT_THRESHOLD}的重复内容。')

        # Raise the threshold to the repetition found in the source, so that
        # legitimately repetitive input is not reported as degeneration.
        source_repeat_count = max((self._get_repeat_count(query) for query in queries), default=0)
        actual_threshold = max(source_repeat_count, self._REPEAT_DETECT_THRESHOLD)

        if self._detect_repeats(response, actual_threshold):
            response = await _retry_translation(queries, lambda r: self._detect_repeats(r, actual_threshold), f'检测到大量重复内容(当前阈值:{actual_threshold}),疑似模型退化,重新翻译。')
            if response is None:
                self.logger.warning(f'疑似模型退化,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。')
                return await self._translate_single_lines(queries)

        if not self._check_align(queries, response):
            response = await _retry_translation(queries, lambda r: not self._check_align(queries, r), '因为检测到原文与译文行数不匹配,重新翻译。')
            if response is None:
                self.logger.warning(f'原文与译文行数不匹配,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。')
                return await self._translate_single_lines(queries)

        return self._split_text(response)

    def _detect_repeats(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> bool:
        """Return True when ``text`` contains a pattern repeated at least ``threshold`` times."""
        return self.detect_and_caculate_repeats(text, threshold, remove_all=False)[0]

    def _get_repeat_count(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> int:
        """Return the run length reported by detect_and_caculate_repeats for ``text``."""
        return self.detect_and_caculate_repeats(text, threshold, remove_all=False)[2]

    def _check_align(self, queries: List[str], response: str) -> bool:
        """Return True when the response has as many lines as there are queries."""
        translations = self._split_text(response)
        is_aligned = len(queries) == len(translations)
        if not is_aligned:
            self.logger.warning(f"行数不匹配 - 原文行数: {len(queries)},译文行数: {len(translations)}")
        return is_aligned

    async def _translate_single_lines(self, queries: List[str]) -> List[str]:
        """Translate the queries one by one.

        A line whose translation still contains repeats is replaced by its
        source text.
        """
        translations = []
        for query in queries:
            response = await self._handle_translation_request(query)
            if self._detect_repeats(response):
                self.logger.warning(f"单行翻译结果存在重复内容: {response},返回原文。")
                translations.append(query)
            else:
                translations.append(response)
        return translations

    def _delete_quotation_mark(self, texts: List[str]) -> List[str]:
        """Strip leading and trailing corner brackets from every text."""
        return [text.strip('「」') for text in texts]

    async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        """Translate ``queries`` and return one translation per query.

        ``from_lang`` and ``to_lang`` are not used by this method.
        """
        self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}')
        self.logger.debug(f'原文: {queries}')
        text_prompt = '\n'.join(queries)
        self.logger.debug('-- Sakura Prompt --\n' + self._format_prompt_log(text_prompt) + '\n\n')

        # Normalise the queries.
        queries = self._preprocess_queries(queries)

        # Send the translation request.
        response = await self._handle_translation_request(queries)
        self.logger.debug('-- Sakura Response --\n' + response + '\n\n')

        # Check for repetition and line-count mismatch.
        translations = await self._check_translation_quality(queries, response)

        return self._delete_quotation_mark(translations)

    async def _handle_translation_request(self, prompt: Union[str, List[str]]) -> str:
        """Send a translation request with retry handling.

        Args:
            prompt: A single text or a list of lines to translate.

        Returns:
            The model response. When server errors persist, the untranslated
            prompt is returned as a newline-joined string so that callers
            always receive a ``str``.

        Raises:
            Exception: After ``_TIMEOUT_RETRY_ATTEMPTS`` timeouts.
            openai.RateLimitError: After ``_RATELIMIT_RETRY_ATTEMPTS`` rate-limit errors.
        """
        ratelimit_attempt = 0
        server_error_attempt = 0
        timeout_attempt = 0
        while True:
            try:
                response = await asyncio.wait_for(self._request_translation(prompt), timeout=self._TIMEOUT)
                break
            except asyncio.TimeoutError:
                timeout_attempt += 1
                if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS:
                    raise Exception('Sakura超时。')
                self.logger.warning(f'Sakura因超时而进行重试。尝试次数: {timeout_attempt}')
            except openai.RateLimitError:
                ratelimit_attempt += 1
                if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS:
                    raise
                self.logger.warning(f'Sakura因被限速而进行重试。尝试次数: {ratelimit_attempt}')
                await asyncio.sleep(2)
            except (openai.APIError, openai.APIConnectionError) as e:
                server_error_attempt += 1
                if server_error_attempt >= self._RETRY_ATTEMPTS:
                    self.logger.error(f'Sakura API请求失败。错误信息: {e}')
                    # Fall back to the source text, always as a string.
                    return prompt if isinstance(prompt, str) else '\n'.join(prompt)
                self.logger.warning(f'Sakura因服务器错误而进行重试。尝试次数: {server_error_attempt},错误信息: {e}')

        return response

    async def _request_translation(self, input_text_list) -> str:
        """Send one chat completion request to the Sakura API.

        Args:
            input_text_list: A string, or a list of lines that is joined
                with newlines.

        Returns:
            The text of the completion.
        """
        if isinstance(input_text_list, list):
            raw_text = "\n".join(input_text_list)
        else:
            raw_text = input_text_list
        # Allow twice the input length in tokens, but never fewer than 512.
        max_token_num = max(len(raw_text) * 2, 512)
        extra_query = {
            'do_sample': False,
            'num_beams': 1,
            'repetition_penalty': 1.0,
        }
        if SAKURA_VERSION == "0.9":
            messages = [
                {
                    "role": "system",
                    "content": self._CHAT_SYSTEM_TEMPLATE_009
                },
                {
                    "role": "user",
                    "content": f"将下面的日文文本翻译成中文:{raw_text}"
                }
            ]
        else:
            gpt_dict_raw_text = self.sakura_dict.get_dict_str()
            self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}")
            messages = [
                {
                    "role": "system",
                    "content": self._CHAT_SYSTEM_TEMPLATE_010
                },
                {
                    "role": "user",
                    "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}"
                }
            ]
        response = await self.client.chat.completions.create(
            model="sukinishiro",
            messages=messages,
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=max_token_num,
            frequency_penalty=self.frequency_penalty,
            seed=-1,
            extra_query=extra_query,
        )
        # Extract and return the response text.
        # NOTE(review): the membership test is kept as written; whether it
        # can ever be true depends on the type of the choice objects returned
        # by the client - confirm against the installed openai version.
        for choice in response.choices:
            if 'text' in choice:
                return choice.text

        return response.choices[0].message.content

    def _set_gpt_style(self, style_name: str):
        """Switch the sampling parameters to one of the presets in ``_GPT_STYLES``.

        Nothing happens when ``style_name`` is already the current style.

        Raises:
            ValueError: If ``style_name`` is not a known preset; the current
                settings are left untouched in that case.
        """
        if self._current_style == style_name:
            return
        try:
            temperature, top_p, frequency_penalty = self._GPT_STYLES[style_name]
        except KeyError:
            raise ValueError(f"Unknown GPT style: {style_name!r}") from None
        self._current_style = style_name
        self.temperature = temperature
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty