Update train.py
Browse files
train.py
CHANGED
@@ -5,6 +5,7 @@ import json
|
|
5 |
import os
|
6 |
import random
|
7 |
import re
|
|
|
8 |
|
9 |
class ModelTrainer:
|
10 |
def __init__(self, model_id, system_prompts_path):
|
@@ -14,16 +15,19 @@ class ModelTrainer:
|
|
14 |
with open(system_prompts_path, 'r', encoding='utf-8') as f:
|
15 |
self.system_prompts = json.load(f)
|
16 |
|
17 |
-
#
|
18 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
19 |
model_id,
|
20 |
-
trust_remote_code=True
|
21 |
)
|
|
|
|
|
22 |
self.model = AutoModelForCausalLM.from_pretrained(
|
23 |
model_id,
|
24 |
-
trust_remote_code=True,
|
25 |
-
|
26 |
-
|
|
|
27 |
)
|
28 |
|
29 |
# 使用更轻量的LoRA配置
|
|
|
5 |
import os
|
6 |
import random
|
7 |
import re
|
8 |
+
import torch
|
9 |
|
10 |
class ModelTrainer:
|
11 |
def __init__(self, model_id, system_prompts_path):
|
|
|
15 |
with open(system_prompts_path, 'r', encoding='utf-8') as f:
|
16 |
self.system_prompts = json.load(f)
|
17 |
|
18 |
+
# 修改模型初始化参数
|
19 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
20 |
model_id,
|
21 |
+
trust_remote_code=True
|
22 |
)
|
23 |
+
|
24 |
+
# 修改这部分的初始化参数
|
25 |
self.model = AutoModelForCausalLM.from_pretrained(
|
26 |
model_id,
|
27 |
+
trust_remote_code=True,
|
28 |
+
torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
|
29 |
+
device_map='auto', # 自动选择设备
|
30 |
+
low_cpu_mem_usage=True
|
31 |
)
|
32 |
|
33 |
# 使用更轻量的LoRA配置
|