import re
import os
import asyncio
import logging
from typing import Callable, List, Tuple, Union

try:
    import openai
except ImportError:
    openai = None

from .common import CommonTranslator
from .keys import SAKURA_API_BASE, SAKURA_VERSION, SAKURA_DICT_PATH

class SakuraDict:
    def __init__(self, path: str, logger: logging.Logger, version: str = "0.9") -> None:
        self.logger = logger
        self.dict_str = ""
        self.version = version
        self.path = path
        if not os.path.exists(path):
            if self.version == '0.10':
                self.logger.warning(f"字典文件不存在: {path}")
            return
        if self.version == '0.10':
            self.dict_str = self.get_dict_from_file(path)
        elif self.version == '0.9':
            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")

    def load_galtransl_dic(self, dic_path: str):
        """
        Load a Galtransl-format dictionary.
        """

        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()
        if len(dic_lines) == 0:
            return
        dic_path = os.path.abspath(dic_path)
        dic_name = os.path.basename(dic_path)
        normalDic_count = 0

        gpt_dict = []
        for line in dic_lines:
            if line.startswith("\n"):
                continue
            elif line.startswith("\\\\") or line.startswith("//"):  # skip comment lines
                continue

            # normalize four spaces to a tab
            line = line.replace("    ", "\t")

            sp = line.rstrip("\r\n").split("\t")  # strip trailing newlines, split on tabs
            len_sp = len(sp)

            if len_sp < 2:  # need at least two fields
                continue

            src = sp[0]
            dst = sp[1]
            info = sp[2] if len_sp > 2 else None
            gpt_dict.append({"src": src, "dst": dst, "info": info})
            normalDic_count += 1

        gpt_dict_text_list = []
        for gpt in gpt_dict:
            src = gpt['src']
            dst = gpt['dst']
            info = gpt.get('info')
            if info:
                single = f"{src}->{dst} #{info}"
            else:
                single = f"{src}->{dst}"
            gpt_dict_text_list.append(single)

        gpt_dict_raw_text = "\n".join(gpt_dict_text_list)
        self.dict_str = gpt_dict_raw_text
        self.logger.info(
            f"载入 Galtransl 字典: {dic_name} {normalDic_count}普通词条"
        )

    def load_sakura_dict(self, dic_path: str):
        """
        Load a dictionary that is already in the standard Sakura format.
        """

        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()

        if len(dic_lines) == 0:
            return
        dic_path = os.path.abspath(dic_path)
        dic_name = os.path.basename(dic_path)
        normalDic_count = 0

        gpt_dict_text_list = []
        for line in dic_lines:
            if line.startswith("\n"):
                continue
            elif line.startswith("\\\\") or line.startswith("//"):  # skip comment lines
                continue

            sp = line.rstrip("\r\n").split("->")  # strip trailing newlines, split on "->"
            len_sp = len(sp)

            if len_sp < 2:  # need at least two fields
                continue

            src = sp[0]
            dst_info = sp[1].split("#")  # split target and note on "#"
            dst = dst_info[0].strip()
            info = dst_info[1].strip() if len(dst_info) > 1 else None
            if info:
                single = f"{src}->{dst} #{info}"
            else:
                single = f"{src}->{dst}"
            gpt_dict_text_list.append(single)
            normalDic_count += 1

        gpt_dict_raw_text = "\n".join(gpt_dict_text_list)
        self.dict_str = gpt_dict_raw_text
        self.logger.info(
            f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条"
        )

    def detect_type(self, dic_path: str):
        """
        Detect the dictionary format ("galtransl", "sakura", or "unknown").
        """
        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()
        self.logger.debug(f"检测字典类型: {dic_path}")
        if len(dic_lines) == 0:
            return "unknown"

        # check whether this looks like a Galtransl dictionary
        is_galtransl = True
        for line in dic_lines:
            if line.startswith("\n"):
                continue
            elif line.startswith("\\\\") or line.startswith("//"):
                continue

            if "\t" not in line and "    " not in line:
                is_galtransl = False
                break

        if is_galtransl:
            return "galtransl"

        # check whether this looks like a Sakura dictionary
        is_sakura = True
        for line in dic_lines:
            if line.startswith("\n"):
                continue
            elif line.startswith("\\\\") or line.startswith("//"):
                continue

            if "->" not in line:
                is_sakura = False
                break

        if is_sakura:
            return "sakura"

        return "unknown"

    def get_dict_str(self):
        """
        Return the dictionary text, loading it from file on first use.
        """
        if self.version == '0.9':
            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")
            return ""
        if self.dict_str == "":
            try:
                self.dict_str = self.get_dict_from_file(self.path)
                return self.dict_str
            except Exception as e:
                if self.version == '0.10':
                    self.logger.warning(f"载入字典失败: {e}")
                return ""
        return self.dict_str

    def get_dict_from_file(self, dic_path: str):
        """
        Load the dictionary from a file, dispatching on the detected format.
        """
        dic_type = self.detect_type(dic_path)
        if dic_type == "galtransl":
            self.load_galtransl_dic(dic_path)
        elif dic_type == "sakura":
            self.load_sakura_dict(dic_path)
        else:
            self.logger.warning(f"未知的字典类型: {dic_path}")
        return self.get_dict_str()


class SakuraTranslator(CommonTranslator):

    _TIMEOUT = 999  # timeout (seconds) when waiting for the server response
    _RETRY_ATTEMPTS = 3  # number of retries on request errors
    _TIMEOUT_RETRY_ATTEMPTS = 3  # number of retries on request timeouts
    _RATELIMIT_RETRY_ATTEMPTS = 3  # number of retries when rate-limited
    _REPEAT_DETECT_THRESHOLD = 20  # threshold for repeat detection

    _CHAT_SYSTEM_TEMPLATE_009 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。'
    )
    _CHAT_SYSTEM_TEMPLATE_010 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。'
    )

    _LANGUAGE_CODE_MAP = {
        'CHS': 'Simplified Chinese',
        'JPN': 'Japanese'
    }

    def __init__(self):
        super().__init__()
        if openai is None:
            raise ImportError('The "openai" package is required for SakuraTranslator.')
        if "/v1" not in SAKURA_API_BASE:
            base_url = SAKURA_API_BASE + "/v1"
        else:
            base_url = SAKURA_API_BASE
        # the API key is a placeholder; the Sakura endpoint is not expected to validate it
        self.client = openai.AsyncOpenAI(api_key="sk-114514", base_url=base_url)
        self.temperature = 0.3
        self.top_p = 0.3
        self.frequency_penalty = 0.1
        self._current_style = "precise"
        self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]')
        self._heart_pattern = re.compile(r'❤')
        self.sakura_dict = SakuraDict(self.get_dict_path(), self.logger, SAKURA_VERSION)

    def get_sakura_version(self):
        return SAKURA_VERSION

    def get_dict_path(self):
        return SAKURA_DICT_PATH

    def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_THRESHOLD, remove_all=True) -> Tuple[bool, str, int, str, int]:
        """
        Detect repeated patterns in the text and count how often they repeat.
        Returns: (is_repeated, text with repeats removed, repeat count, repeated pattern, effective threshold)
        """
        repeated = False
        counts = []
        count = 0
        pattern = ''
        for pattern_length in range(1, len(s) // 2 + 1):
            i = 0
            while i < len(s) - pattern_length:
                pattern = s[i:i + pattern_length]
                count = 1
                j = i + pattern_length
                while j <= len(s) - pattern_length:
                    if s[j:j + pattern_length] == pattern:
                        count += 1
                        j += pattern_length
                    else:
                        break
                counts.append(count)
                if count >= threshold:
                    self.logger.warning(f"检测到重复模式: {pattern},重复次数: {count}")
                    repeated = True
                    if remove_all:
                        s = s[:i + pattern_length] + s[j:]
                    break
                i += 1
            if repeated:
                break

        # mode (most common value) of the observed repeat counts
        if counts:
            mode_count = max(set(counts), key=counts.count)
        else:
            mode_count = 0

        # effective threshold: the larger of the default threshold and the mode
        actual_threshold = max(threshold, mode_count)

        return repeated, s, count, pattern, actual_threshold
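
    # Illustrative sketch of what detect_and_caculate_repeats reports (worked
    # out by hand, not taken from a real run; `translator` is assumed to be a
    # SakuraTranslator instance):
    #
    #     repeated, cleaned, count, pattern, thr = translator.detect_and_caculate_repeats(
    #         'あ' + 'は' * 20, threshold=10)
    #     # repeated -> True, pattern -> 'は', count -> 20,
    #     # and cleaned -> 'あは' because remove_all=True keeps one occurrence.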

    @staticmethod
    def enlarge_small_kana(text, ignore=''):
        """将小写平假名或片假名转换为普通大小

        参数
        ----------
        text : str
            全角平假名或片假名字符串。
        ignore : str, 可选
            转换时要忽略的字符。

        返回
        ------
        str
            平假名或片假名字符串,小写假名已转换为大写

        示例
        --------
        >>> print(enlarge_small_kana('さくらきょうこ'))
        さくらきようこ
        >>> print(enlarge_small_kana('キュゥべえ'))
        キユウべえ
        """
        SMALL_KANA = list('ぁぃぅぇぉゃゅょっァィゥェォヵヶャュョッ')
        SMALL_KANA_NORMALIZED = list('あいうえおやゆよつアイウエオカケヤユヨツ')
        SMALL_KANA2BIG_KANA = dict(zip(map(ord, SMALL_KANA), SMALL_KANA_NORMALIZED))

        def _exclude_ignorechar(ignore, conv_map):
            for character in map(ord, ignore):
                del conv_map[character]
            return conv_map

        def _convert(text, conv_map):
            return text.translate(conv_map)

        def _translate(text, ignore, conv_map):
            if ignore:
                _conv_map = _exclude_ignorechar(ignore, conv_map.copy())
                return _convert(text, _conv_map)
            return _convert(text, conv_map)

        return _translate(text, ignore, SMALL_KANA2BIG_KANA)

    def _format_prompt_log(self, prompt: str) -> str:
        """
        Format the prompt text for log output.
        """
        gpt_dict_raw_text = self.sakura_dict.get_dict_str()
        prompt_009 = '\n'.join([
            'System:',
            self._CHAT_SYSTEM_TEMPLATE_009,
            'User:',
            '将下面的日文文本翻译成中文:',
            prompt,
        ])
        prompt_010 = '\n'.join([
            'System:',
            self._CHAT_SYSTEM_TEMPLATE_010,
            'User:',
            "根据以下术语表:",
            gpt_dict_raw_text,
            "将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:",
            prompt,
        ])
        return prompt_009 if SAKURA_VERSION == '0.9' else prompt_010

    def _split_text(self, text: Union[str, List[str]]) -> List[str]:
        """
        Split text into a list of lines on newline characters.
        """
        if isinstance(text, list):
            return text
        return text.split('\n')

    def _preprocess_queries(self, queries: List[str]) -> List[str]:
        """
        Preprocess the queries: enlarge small kana, strip emoji, replace special characters, and wrap each line in 「」.
        """
        queries = [self.enlarge_small_kana(query) for query in queries]
        queries = [self._emoji_pattern.sub('', query) for query in queries]
        queries = [self._heart_pattern.sub('♥', query) for query in queries]
        queries = [f'「{query}」' for query in queries]
        self.logger.debug(f'预处理后的查询文本:{queries}')
        return queries
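
    # For example (illustrative): _preprocess_queries(['こんにちは😊ょ']) yields
    # ['「こんにちはよ」']: the small kana 'ょ' is enlarged to 'よ', the emoji is
    # stripped, and the line is wrapped in 「」.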

    async def _check_translation_quality(self, queries: List[str], response: str) -> List[str]:
        """
        Check the quality of the translation (repetition and line alignment); retry when problems are found, falling back to per-line translation or the original text.
        """
        async def _retry_translation(queries: List[str], check_func: Callable[[str], bool], error_message: str) -> str:
            styles = ["precise", "normal", "aggressive", ]
            for i in range(self._RETRY_ATTEMPTS):
                self._set_gpt_style(styles[i])
                self.logger.warning(f'{error_message} 尝试次数: {i + 1}。当前参数风格:{self._current_style}。')
                response = await self._handle_translation_request(queries)
                if not check_func(response):
                    return response
            return None

        # check whether the source text itself already contains repeats above the default threshold
        if self._detect_repeats(''.join(queries), self._REPEAT_DETECT_THRESHOLD):
            self.logger.warning(f'请求内容本身含有超过默认阈值{self._REPEAT_DETECT_THRESHOLD}的重复内容。')

        # derive the effective threshold from the default and the highest repeat count found in the source lines
        actual_threshold = max(max(self._get_repeat_count(query) for query in queries), self._REPEAT_DETECT_THRESHOLD)

        if self._detect_repeats(response, actual_threshold):
            response = await _retry_translation(queries, lambda r: self._detect_repeats(r, actual_threshold), f'检测到大量重复内容(当前阈值:{actual_threshold}),疑似模型退化,重新翻译。')
            if response is None:
                self.logger.warning(f'疑似模型退化,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。')
                return await self._translate_single_lines(queries)

        if not self._check_align(queries, response):
            response = await _retry_translation(queries, lambda r: not self._check_align(queries, r), '因为检测到原文与译文行数不匹配,重新翻译。')
            if response is None:
                self.logger.warning(f'原文与译文行数不匹配,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。')
                return await self._translate_single_lines(queries)

        return self._split_text(response)

    def _detect_repeats(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> bool:
        """
        Check whether the text contains a repeated pattern above the threshold.
        """
        is_repeated, text, count, pattern, actual_threshold = self.detect_and_caculate_repeats(text, threshold, remove_all=False)
        return is_repeated

    def _get_repeat_count(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> int:
        """
        Return the repeat count of the most repeated pattern found in the text.
        """
        is_repeated, text, count, pattern, actual_threshold = self.detect_and_caculate_repeats(text, threshold, remove_all=False)
        return count

    def _check_align(self, queries: List[str], response: str) -> bool:
        """
        Check that the source text and the translation have the same number of lines.
        """
        translations = self._split_text(response)
        is_aligned = len(queries) == len(translations)
        if not is_aligned:
            self.logger.warning(f"行数不匹配 - 原文行数: {len(queries)},译文行数: {len(translations)}")
        return is_aligned

    async def _translate_single_lines(self, queries: List[str]) -> List[str]:
        """
        Translate the queries one line at a time.
        """
        translations = []
        for query in queries:
            response = await self._handle_translation_request(query)
            if self._detect_repeats(response):
                self.logger.warning(f"单行翻译结果存在重复内容: {response},返回原文。")
                translations.append(query)
            else:
                translations.append(response)
        return translations

    def _delete_quotation_mark(self, texts: List[str]) -> List[str]:
        """
        Strip the leading/trailing 「」 markers from each line.
        """
        new_texts = []
        for text in texts:
            text = text.strip('「」')
            new_texts.append(text)
        return new_texts

    async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}')
        self.logger.debug(f'原文: {queries}')
        text_prompt = '\n'.join(queries)
        self.logger.debug('-- Sakura Prompt --\n' + self._format_prompt_log(text_prompt) + '\n\n')

        # preprocess the queries
        queries = self._preprocess_queries(queries)

        # send the translation request
        response = await self._handle_translation_request(queries)
        self.logger.debug('-- Sakura Response --\n' + response + '\n\n')

        # check the result for repeated content or line-count mismatches
        translations = await self._check_translation_quality(queries, response)

        return self._delete_quotation_mark(translations)

    async def _handle_translation_request(self, prompt: Union[str, List[str]]) -> str:
        """
        Handle a translation request, including error handling and retry logic.
        """
        ratelimit_attempt = 0
        server_error_attempt = 0
        timeout_attempt = 0
        while True:
            try:
                request_task = asyncio.create_task(self._request_translation(prompt))
                response = await asyncio.wait_for(request_task, timeout=self._TIMEOUT)
                break
            except asyncio.TimeoutError:
                timeout_attempt += 1
                if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS:
                    raise Exception('Sakura超时。')
                self.logger.warning(f'Sakura因超时而进行重试。尝试次数: {timeout_attempt}')
            except openai.RateLimitError:
                ratelimit_attempt += 1
                if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS:
                    raise
                self.logger.warning(f'Sakura因被限速而进行重试。尝试次数: {ratelimit_attempt}')
                await asyncio.sleep(2)
            except (openai.APIError, openai.APIConnectionError) as e:
                server_error_attempt += 1
                if server_error_attempt >= self._RETRY_ATTEMPTS:
                    self.logger.error(f'Sakura API请求失败。错误信息: {e}')
                    return prompt
                self.logger.warning(f'Sakura因服务器错误而进行重试。尝试次数: {server_error_attempt},错误信息: {e}')

        return response

    async def _request_translation(self, input_text_list) -> str:
        """
        Send a translation request to the Sakura API.
        """
        if isinstance(input_text_list, list):
            raw_text = "\n".join(input_text_list)
        else:
            raw_text = input_text_list
        raw_length = len(raw_text)
        max_length = 512
        max_token_num = max(raw_length * 2, max_length)
        extra_query = {
            'do_sample': False,
            'num_beams': 1,
            'repetition_penalty': 1.0,
        }
        if SAKURA_VERSION == "0.9":
            messages = [
                {
                    "role": "system",
                    "content": f"{self._CHAT_SYSTEM_TEMPLATE_009}"
                },
                {
                    "role": "user",
                    "content": f"将下面的日文文本翻译成中文:{raw_text}"
                }
            ]
        else:
            gpt_dict_raw_text = self.sakura_dict.get_dict_str()
            self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}")
            messages = [
                {
                    "role": "system",
                    "content": f"{self._CHAT_SYSTEM_TEMPLATE_010}"
                },
                {
                    "role": "user",
                    "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}"
                }
            ]
        response = await self.client.chat.completions.create(
            model="sukinishiro",
            messages=messages,
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=max_token_num,
            frequency_penalty=self.frequency_penalty,
            seed=-1,
            extra_query=extra_query,
        )
        # extract and return the response text; some OpenAI-compatible servers
        # expose completion-style choices with a `text` field instead of `message`
        for choice in response.choices:
            text = getattr(choice, 'text', None)
            if text:
                return text

        return response.choices[0].message.content

    def _set_gpt_style(self, style_name: str):
        """
        Set the sampling style used for generation.
        """
        if self._current_style == style_name:
            return
        self._current_style = style_name
        if style_name == "precise":
            temperature, top_p = 0.1, 0.3
            frequency_penalty = 0.05
        elif style_name == "normal":
            temperature, top_p = 0.3, 0.3
            frequency_penalty = 0.2
        elif style_name == "aggressive":
            temperature, top_p = 0.3, 0.3
            frequency_penalty = 0.3
        else:
            # unknown style name: fall back to the "normal" settings
            temperature, top_p = 0.3, 0.3
            frequency_penalty = 0.2

        self.temperature = temperature
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
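
# A minimal usage sketch (assumptions: a Sakura-compatible server is reachable at
# SAKURA_API_BASE, and calling the internal coroutine directly is acceptable for a
# quick smoke test; the surrounding pipeline normally drives this class instead):
#
#     import asyncio
#     translator = SakuraTranslator()
#     lines = asyncio.run(translator._translate('JPN', 'CHS', ['こんにちは、世界']))
#     print(lines)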