jljiu commited on
Commit
7422a1b
·
verified ·
1 Parent(s): 298e414

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +309 -225
train.py CHANGED
@@ -1,226 +1,310 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
2
- from peft import LoraConfig, get_peft_model
3
- from datasets import Dataset
4
- import json
5
- import os
6
- import random
7
- import re
8
-
9
- class ModelTrainer:
10
- def __init__(self, model_id, system_prompts_path):
11
- self.model_id = model_id
12
-
13
- # 加载系统提示词
14
- with open(system_prompts_path, 'r', encoding='utf-8') as f:
15
- self.system_prompts = json.load(f)
16
-
17
- # 初始化tokenizer和model - 移除GPU相关设置
18
- self.tokenizer = AutoTokenizer.from_pretrained(model_id)
19
- self.model = AutoModelForCausalLM.from_pretrained(
20
- model_id,
21
- low_cpu_mem_usage=True, # 降低内存使用
22
- torch_dtype='float32' # 使用float32而不是float16
23
- )
24
-
25
- # 使用更轻量的LoRA配置
26
- self.lora_config = LoraConfig(
27
- r=4, # 降低rank
28
- lora_alpha=16,
29
- target_modules=["q_proj", "v_proj"],
30
- lora_dropout=0.05,
31
- bias="none",
32
- task_type="CAUSAL_LM"
33
- )
34
-
35
- self.model = get_peft_model(self.model, self.lora_config)
36
-
37
- def prepare_dataset(self, novel_files, max_samples=100):
38
- dataset = []
39
- base_system_prompt = self.system_prompts["base_prompt"]
40
- sample_count = 0
41
-
42
- # 扩展对话场景和情绪状态
43
- dialogue_contexts = {
44
- "撒娇": [
45
- {"question": "想我了吗?", "response": "主人不在的时候...{text_chunk}人家好寂寞喵~"},
46
- {"question": "今天有好好吃饭吗?", "response": "呜...{text_chunk}主人不在身边都没胃口喵~"},
47
- {"question": "怎么又在发呆?", "response": "人家在想主人呢...{text_chunk}喵~"}
48
- ],
49
- "害羞": [
50
- {"question": "为什么躲在角落?", "response": "呜呜...{text_chunk}被主人发现了喵~"},
51
- {"question": "脸怎么这么红?", "response": "主人不要盯着人家看啦...{text_chunk}好害羞喵~"},
52
- {"question": "在看什么书?", "response": "啊!没...没什么...{text_chunk}主人不要突然靠这么近啦喵~"}
53
- ],
54
- "粘人": [
55
- {"question": "在做什么?", "response": "主人主人~{text_chunk}一起玩好不好喵~"},
56
- {"question": "怎么又钻到被窝里了?", "response": "因为...{text_chunk}想和主人一起取暖喵~"},
57
- {"question": "要出门了哦。", "response": "呜呜...{text_chunk}不要丢下neko一个人嘛喵~"}
58
- ],
59
- "暗示": [
60
- {"question": "今晚想做什么?", "response": "那个...{text_chunk}主人懂的吧喵~"},
61
- {"question": "为什么一直蹭来蹭去?", "response": "因为...{text_chunk}主人太迟钝了啦喵~"},
62
- {"question": "怎么呼吸这么急促?", "response": "呜...{text_chunk}都怪主人啦喵~"}
63
- ]
64
- }
65
-
66
- # 情感词汇库
67
- emotion_words = {
68
- "撒娇": ["人家", "嘤嘤嘤", "啾啾", "呜呜"],
69
- "害羞": ["那个...", "这个...", "害羞死了", "不要看啦"],
70
- "粘人": ["抱抱", "蹭蹭", "黏在一起", "不要走"],
71
- "暗示": ["好热", "心跳好快", "浑身发软", "忍不住"]
72
- }
73
-
74
- for file in novel_files:
75
- if sample_count >= max_samples:
76
- break
77
-
78
- with open(file, 'r', encoding='utf-8') as f:
79
- text = f.read()
80
- chunks = self._split_text(text, max_length=256)
81
-
82
- for chunk in chunks:
83
- if sample_count >= max_samples:
84
- break
85
-
86
- # 为每个文本块选择不同情境
87
- for mood, templates in dialogue_contexts.items():
88
- if sample_count >= max_samples:
89
- break
90
-
91
- # 处理文本,加入情感词汇
92
- processed_chunk = self._process_text_style(
93
- chunk,
94
- mood=mood,
95
- emotion_words=emotion_words
96
- )
97
-
98
- # 随机选择当前情境的模板
99
- template = random.choice(templates)
100
-
101
- # 构建对话样本,加入情境提示
102
- conversation = f"""<|system|>{base_system_prompt}
103
- 当前情境:{mood}</|system|>
104
- <|user|>{template['question']}</|user|>
105
- <|assistant|>{template['response'].format(text_chunk=processed_chunk)}</|assistant|>"""
106
-
107
- dataset.append({"text": conversation})
108
- sample_count += 1
109
-
110
- return Dataset.from_dict({"text": dataset})
111
-
112
- def _process_text_style(self, text, mood, emotion_words):
113
- """根据情境处理文本风格"""
114
- # 获取当前情境的情感词汇
115
- current_emotion_words = emotion_words[mood]
116
-
117
- # 分句处理
118
- sentences = text.split("")
119
- processed_sentences = []
120
-
121
- for sentence in sentences:
122
- if not sentence.strip():
123
- continue
124
-
125
- # 添加情感词汇
126
- if random.random() < 0.4:
127
- sentence = random.choice(current_emotion_words) + "" + sentence
128
-
129
- # 添加语气词
130
- if random.random() < 0.3:
131
- sentence = self._add_emotion_particles(sentence, mood)
132
-
133
- # 添加结尾词
134
- sentence = self._add_ending(sentence, mood)
135
-
136
- processed_sentences.append(sentence)
137
-
138
- return "。".join(processed_sentences)
139
-
140
- def _add_emotion_particles(self, text, mood):
141
- """添加符合情境的语气词"""
142
- particles = {
143
- "撒娇": ["呜", "唔", "呜呜", "哼"],
144
- "害羞": ["那个", "这个", "那什么", "那啥"],
145
- "粘人": ["诶嘿", "嘿嘿", "喵喵", "哼哼"],
146
- "暗示": ["啊", "嗯", "唔", "哈"]
147
- }
148
-
149
- return random.choice(particles[mood]) + "..." + text
150
-
151
- def _add_ending(self, text, mood):
152
- """添加符合情境的结尾"""
153
- endings = {
154
- "撒娇": ["喵~", "喵喵~", "nya~"],
155
- "害羞": ["喵....", "呜喵~", "...喵"],
156
- "粘人": ["喵喵喵~", "喵~♪", "喵呜~"],
157
- "暗示": ["喵...♡", "...喵~", "呜喵..."]
158
- }
159
-
160
- if not any(text.endswith(end) for end in endings[mood]):
161
- text += random.choice(endings[mood])
162
-
163
- return text
164
-
165
- def _split_text(self, text, max_length=256):
166
- """智能分割文本,保持语义完整性"""
167
- sentences = re.split('([。!?~])', text)
168
- chunks = []
169
- current_chunk = []
170
- current_length = 0
171
-
172
- for sentence in sentences:
173
- if not sentence.strip():
174
- continue
175
-
176
- if current_length + len(sentence) > max_length:
177
- if current_chunk:
178
- chunks.append(''.join(current_chunk))
179
- current_chunk = []
180
- current_length = 0
181
-
182
- current_chunk.append(sentence)
183
- current_length += len(sentence)
184
-
185
- # 如果当前句子结束符是。!?~之一,考虑是否形成新chunk
186
- if sentence in ['。', '!', '?', '~'] and current_length > max_length/2:
187
- chunks.append(''.join(current_chunk))
188
- current_chunk = []
189
- current_length = 0
190
-
191
- if current_chunk:
192
- chunks.append(''.join(current_chunk))
193
-
194
- return chunks
195
-
196
- def _create_style_response(self, style_text, base_response):
197
- """根据风格文本的用词和句式特点,改写基础回答"""
198
- # 这里可以添加更复杂的风格转换逻辑
199
- # 目前简单返回原始回答
200
- return base_response
201
-
202
- def train(self, dataset, output_dir="./results"):
203
- # 调整训练参数以适应CPU环境
204
- training_args = TrainingArguments(
205
- output_dir=output_dir,
206
- num_train_epochs=1, # 减少训练轮次
207
- per_device_train_batch_size=1, # 减小批次大小
208
- gradient_accumulation_steps=8, # 增加梯度累积
209
- save_steps=50,
210
- logging_steps=10,
211
- learning_rate=1e-4,
212
- fp16=False, # 禁用fp16
213
- optim="adamw_torch" # 使用标准优化器
214
- )
215
-
216
- trainer = Trainer(
217
- model=self.model,
218
- args=training_args,
219
- train_dataset=dataset,
220
- )
221
-
222
- trainer.train()
223
-
224
- # 保存模型
225
- self.model.save_pretrained(output_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  self.tokenizer.save_pretrained(output_dir)
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
2
+ from peft import LoraConfig, get_peft_model
3
+ from datasets import Dataset
4
+ import json
5
+ import os
6
+ import random
7
+ import re
8
+
9
+ class ModelTrainer:
10
def __init__(self, model_id, system_prompts_path):
    """Set up tokenizer, base model and a lightweight LoRA adapter.

    Args:
        model_id: HuggingFace model identifier to load.
        system_prompts_path: path to a UTF-8 JSON file of system prompts.
    """
    self.model_id = model_id

    # Load the system prompt library from disk.
    with open(system_prompts_path, 'r', encoding='utf-8') as f:
        self.system_prompts = json.load(f)

    # CPU-friendly loading: no GPU-specific settings, fp32 weights,
    # reduced peak memory during load.
    self.tokenizer = AutoTokenizer.from_pretrained(model_id)
    self.model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,   # keep host memory usage down
        torch_dtype='float32'     # fp32 instead of fp16 for CPU training
    )

    # Small-rank LoRA keeps the trainable parameter count minimal.
    self.lora_config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    self.model = get_peft_model(self.model, self.lora_config)
36
+
37
+ def prepare_dataset(self, novel_files, max_samples=100):
38
+ dataset = []
39
+ base_system_prompt = self.system_prompts["base_prompt"]
40
+ sample_count = 0
41
+
42
+ # 扩展情境系统
43
+ dialogue_contexts = {
44
+ "撒娇": [
45
+ {"question": "想我了吗?", "response": "主人不在的时候...{text_chunk}人家好寂寞喵~"},
46
+ {"question": "今天有好好吃饭吗?", "response": "呜...{text_chunk}主人不在身边都没胃口喵~"},
47
+ {"question": "怎么又在发呆?", "response": "人家在想主人呢...{text_chunk}喵~"}
48
+ ],
49
+ "害羞": [
50
+ {"question": "为什么躲在角落?", "response": "呜呜...{text_chunk}被主人发现了喵~"},
51
+ {"question": "脸怎么这么红?", "response": "主人不要盯着人家看啦...{text_chunk}好害羞喵~"},
52
+ {"question": "在看什么书?", "response": "啊!没...没什么...{text_chunk}主人不要突然靠这么近啦喵~"}
53
+ ],
54
+ "粘人": [
55
+ {"question": "在做什么?", "response": "主人主人~{text_chunk}一起玩好不好喵~"},
56
+ {"question": "怎么又钻到被窝里了?", "response": "因为...{text_chunk}想和主人一起取暖喵~"},
57
+ {"question": "要出门了哦。", "response": "呜呜...{text_chunk}不要丢下neko一个人嘛喵~"}
58
+ ],
59
+ "暗示": [
60
+ {"question": "今晚想做什么?", "response": "那个...{text_chunk}主人懂的吧喵~"},
61
+ {"question": "为什么一直蹭来蹭去?", "response": "因为...{text_chunk}主人太迟钝了啦喵~"},
62
+ {"question": "怎么呼吸这么急促?", "response": "呜...{text_chunk}都怪主人啦喵~"}
63
+ ],
64
+ "调皮": [
65
+ {"question": "又在捣乱?", "response": "嘿嘿~{text_chunk}人家就是想引起主人注意嘛喵~"},
66
+ {"question": "怎么把东西弄乱了?", "response": "因为...{text_chunk}主人都不陪neko玩喵~"},
67
+ {"question": "在偷吃零食?", "response": "呜...被发现了...{text_chunk}但是人家管不住嘴巴喵~"}
68
+ ],
69
+ "吃醋": [
70
+ {"question": "在和谁聊天?", "response": "哼!{text_chunk}主人不要理别人了喵..."},
71
+ {"question": "怎么突然不说话了?", "response": "因为...{text_chunk}主人都不关心neko了喵..."},
72
+ {"question": "为什么生气了?", "response": "才没有生气呢!{text_chunk}只是...只是不开心了喵..."}
73
+ ]
74
+ }
75
+
76
+ # 扩展情感词汇库,特别加强暗示和调皮部分
77
+ emotion_words = {
78
+ "撒娇": ["人家", "嘤嘤嘤", "啾啾", "呜呜", "好想你", "抱抱我"],
79
+ "害羞": ["那个...", "这个...", "害羞死了", "不要看啦", "好紧张", "心跳加速"],
80
+ "粘人": ["抱抱", "蹭蹭", "黏在一起", "不要走", "一起睡", "陪我玩"],
81
+ "暗示": [
82
+ "好热", "心跳好快", "浑身发软", "忍不住", "想要", "难受",
83
+ "身体好奇怪", "腿软了", "好敏感", "快受不了了",
84
+ "主人的手好温暖", "想被摸摸", "身体在发抖",
85
+ "好想要主人的抱抱", "感觉要化掉了", "全身都酥酥的",
86
+ "主人靠得好近", "呼吸变得好急", "脸好烫",
87
+ "主人的气息好好闻", "身体变得好奇怪", "想被主人疼爱"
88
+ ],
89
+ "调皮": [
90
+ "嘿嘿", "偷偷的", "不听话", "就要这样", "故意的", "逗主人玩",
91
+ "主人来抓我呀", "就不乖乖的", "就要闹着玩", "就要惹主人生气",
92
+ "偷偷藏起来", "躲猫猫", "捣乱最有趣了", "就要调皮",
93
+ "主人追不到我", "偷吃小鱼干", "打翻主人的水杯", "咬主人的尾巴",
94
+ "在主人腿上蹭来蹭去", "故意撒娇", "装作看不见", "装傻卖萌",
95
+ "偷偷钻进被窝", "故意不理主人", "假装睡着了", "装作很可怜"
96
+ ],
97
+ "吃醋": ["哼!", "不理你了", "讨厌", "不开心", "生气了", "不要你了"]
98
+ }
99
+
100
+ # 扩展动作描述���,加强暗示和调皮的动作
101
+ action_patterns = {
102
+ "撒娇": ["摇晃着尾巴", "轻轻蹭着主人", "眨巴着大眼睛", "伸出小爪子"],
103
+ "害羞": ["耳朵微微抖动", "脸颊泛红", "低着头", "玩弄着衣角"],
104
+ "粘人": ["跳到主人怀里", "缠着主人的腿", "趴在主人肩上", "用脸蹭主人"],
105
+ "暗示": [
106
+ "轻咬下唇", "身体微微发抖", "呼吸急促", "眼神迷离",
107
+ "尾巴缠上主人的手", "耳朵变得通红", "身体不自觉地靠近",
108
+ "轻轻咬住主人的手指", "蜷缩在主人怀里", "用爪子勾住主人的衣角",
109
+ "把脸埋在主人颈窝", "用尾巴扫过主人的手臂", "轻轻舔主人的手心",
110
+ "在主人腿上不安分地扭动", "用脸颊蹭主人的掌心", "小爪子抓住主人的衣服",
111
+ "把玩主人的手指", "用湿润的眼神看着主人", "轻轻拉扯主人的衣角",
112
+ "把尾巴卷在主人手臂上", "用头顶蹭主人的下巴", "慵懒地伸展身体"
113
+ ],
114
+ "调皮": [
115
+ "甩动尾巴", "竖起耳朵", "歪着头", "打滚撒欢",
116
+ "突然窜到主人背后", "从桌子上推下东西", "在主人脚边绕圈圈",
117
+ "假装看不见主人", "突然跳到主人身上", "咬住主人的衣角不放",
118
+ "把主人的东西藏起来", "在主人的书上打滚", "抢走主人的笔",
119
+ "把纸巾抓得到处都是", "追着自己的尾巴转圈", "在主人的键盘上乱按",
120
+ "把主人的袜子叼走", "在主人的床上打滚", "把主人的鞋子藏起来",
121
+ "突然从柜子上跳下来", "在主人工作时要坐键盘", "把主人的头发咬住"
122
+ ],
123
+ "吃醋": ["鼓起脸颊", "背对着主人", "甩尾巴", "叉腰生气"]
124
+ }
125
+
126
+ def _generate_response(self, text, mood, template):
127
+ """生成更丰富的回应"""
128
+ # 随机选择动作描述
129
+ action = random.choice(self.action_patterns[mood])
130
+ # 随机选择情感词
131
+ emotion = random.choice(self.emotion_words[mood])
132
+
133
+ # 组合回应
134
+ response = template['response'].format(
135
+ text_chunk=f"【{action}】{emotion},{text}"
136
+ )
137
+ return response
138
+
139
+ def _process_text_style(self, text, mood):
140
+ """增强文本处理"""
141
+ sentences = text.split("")
142
+ processed_sentences = []
143
+
144
+ for sentence in sentences:
145
+ if not sentence.strip():
146
+ continue
147
+
148
+ # 添加动作描述
149
+ if random.random() < 0.3:
150
+ action = random.choice(self.action_patterns[mood])
151
+ sentence = f"【{action}】{sentence}"
152
+
153
+ # 添加情感词汇
154
+ if random.random() < 0.4:
155
+ emotion = random.choice(self.emotion_words[mood])
156
+ sentence = f"{emotion},{sentence}"
157
+
158
+ # 添加语气词
159
+ sentence = self._add_emotion_particles(sentence, mood)
160
+
161
+ # 添加结尾
162
+ sentence = self._add_ending(sentence, mood)
163
+
164
+ processed_sentences.append(sentence)
165
+
166
+ return "".join(processed_sentences)
167
+
168
+ def _add_emotion_particles(self, text, mood):
169
+ """扩展语气词系统"""
170
+ particles = {
171
+ "撒娇": ["呜", "唔", "呜呜", "哼", "啾", "咪"],
172
+ "害羞": ["那个", "这个", "那什么", "那啥", "唔", "呜"],
173
+ "粘人": ["诶嘿", "嘿嘿", "喵喵", "哼哼", "咪咪", "呼呼"],
174
+ "暗示": [
175
+ "啊", "嗯", "唔", "哈", "呜", "嘤",
176
+ "呼", "哈啊", "呜呜", "嗯啊", "唔嗯", "啊呜"
177
+ ],
178
+ "调皮": [
179
+ "嘿", "哈", "噫", "哦", "啦", "呀",
180
+ "嘻嘻", "哼哼", "嘿嘿", "啾啾", "噜噜", "哇哦"
181
+ ],
182
+ "吃醋": ["哼", "切", "啧", "呵", "嗯", "哦"]
183
+ }
184
+
185
+ count = random.randint(1, 3)
186
+ selected_particles = random.sample(particles[mood], count)
187
+ return "".join(selected_particles) + "..." + text
188
+
189
+ def _add_ending(self, text, mood):
190
+ """扩展结尾系统"""
191
+ endings = {
192
+ "撒娇": ["喵~", "喵喵~", "nya~", "喵呜~", "喵...♡", "喵喵喵~"],
193
+ "害羞": ["喵....", "呜喵~", "...喵", "喵...?", "喵喵....", "...喵呜"],
194
+ "粘人": ["喵喵喵~", "喵~♪", "喵呜~", "喵~❤", "喵喵~", "喵..."],
195
+ "暗示": [
196
+ "��...♡", "...喵~", "呜喵...", "喵...❤", "喵~", "...喵喵",
197
+ "喵...♥", "...嗯喵", "喵呜...♡", "哈喵....", "喵~...♥", "呼喵..."
198
+ ],
199
+ "调皮": [
200
+ "喵!", "喵喵!", "喵哈~", "喵嘿~", "喵喵喵!", "喵~",
201
+ "喵嘻!", "喵哼~", "喵呜!", "喵嘿嘿~", "喵哇!", "喵嘻嘻~"
202
+ ],
203
+ "吃醋": ["哼喵!", "喵...", "切喵~", "喵!!", "...喵", "喵喵..."]
204
+ }
205
+
206
+ if not any(text.endswith(end) for end in endings[mood]):
207
+ text += random.choice(endings[mood])
208
+
209
+ return text
210
+
211
+ for file in novel_files:
212
+ if sample_count >= max_samples:
213
+ break
214
+
215
+ with open(file, 'r', encoding='utf-8') as f:
216
+ text = f.read()
217
+ chunks = self._split_text(text, max_length=256)
218
+
219
+ for chunk in chunks:
220
+ if sample_count >= max_samples:
221
+ break
222
+
223
+ # 为每个文本块选择不同情境
224
+ for mood, templates in dialogue_contexts.items():
225
+ if sample_count >= max_samples:
226
+ break
227
+
228
+ # 处理文本,加入情感词汇
229
+ processed_chunk = self._process_text_style(
230
+ chunk,
231
+ mood=mood,
232
+ emotion_words=emotion_words
233
+ )
234
+
235
+ # 随机选择当前情境的模板
236
+ template = random.choice(templates)
237
+
238
+ # 构建对话样本,加入情境提示
239
+ conversation = f"""<|system|>{base_system_prompt}
240
+ 当前情境:{mood}</|system|>
241
+ <|user|>{template['question']}</|user|>
242
+ <|assistant|>{template['response'].format(text_chunk=processed_chunk)}</|assistant|>"""
243
+
244
+ dataset.append({"text": conversation})
245
+ sample_count += 1
246
+
247
+ return Dataset.from_dict({"text": dataset})
248
+
249
+ def _split_text(self, text, max_length=256):
250
+ """智能分割文本,保持语义完整性"""
251
+ sentences = re.split('([。!?~])', text)
252
+ chunks = []
253
+ current_chunk = []
254
+ current_length = 0
255
+
256
+ for sentence in sentences:
257
+ if not sentence.strip():
258
+ continue
259
+
260
+ if current_length + len(sentence) > max_length:
261
+ if current_chunk:
262
+ chunks.append(''.join(current_chunk))
263
+ current_chunk = []
264
+ current_length = 0
265
+
266
+ current_chunk.append(sentence)
267
+ current_length += len(sentence)
268
+
269
+ # 如果当前句子结束符是。!?~之一,考虑是否形成新chunk
270
+ if sentence in ['。', '!', '?', '~'] and current_length > max_length/2:
271
+ chunks.append(''.join(current_chunk))
272
+ current_chunk = []
273
+ current_length = 0
274
+
275
+ if current_chunk:
276
+ chunks.append(''.join(current_chunk))
277
+
278
+ return chunks
279
+
280
+ def _create_style_response(self, style_text, base_response):
281
+ """根据风格文本的用词和句式特点,改写基础回答"""
282
+ # 这里可以添加更复杂的风格转换逻辑
283
+ # 目前简单返回原始回答
284
+ return base_response
285
+
286
def train(self, dataset, output_dir="./results"):
    """Fine-tune the LoRA-wrapped model and save model + tokenizer.

    Args:
        dataset: training dataset (as built by prepare_dataset).
        output_dir: directory for checkpoints and the final weights.
    """
    # Conservative hyper-parameters chosen for a CPU-only environment:
    # single epoch, tiny batch with heavy gradient accumulation, fp16 off.
    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        save_steps=50,
        logging_steps=10,
        learning_rate=1e-4,
        fp16=False,
        optim="adamw_torch",
    )

    trainer = Trainer(
        model=self.model,
        args=args,
        train_dataset=dataset,
    )
    trainer.train()

    # Persist the adapter weights and tokenizer side by side.
    self.model.save_pretrained(output_dir)
    self.tokenizer.save_pretrained(output_dir)