jljiu committed on
Commit
616370c
·
verified ·
1 Parent(s): e0d0ec8

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +9 -5
train.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  import os
6
  import random
7
  import re
 
8
 
9
  class ModelTrainer:
10
  def __init__(self, model_id, system_prompts_path):
@@ -14,16 +15,19 @@ class ModelTrainer:
14
  with open(system_prompts_path, 'r', encoding='utf-8') as f:
15
  self.system_prompts = json.load(f)
16
 
17
- # 初始化tokenizer和model - 添加trust_remote_code=True
18
  self.tokenizer = AutoTokenizer.from_pretrained(
19
  model_id,
20
- trust_remote_code=True # 添加此参数
21
  )
 
 
22
  self.model = AutoModelForCausalLM.from_pretrained(
23
  model_id,
24
- trust_remote_code=True, # 添加此参数
25
- low_cpu_mem_usage=True,
26
- torch_dtype='float32'
 
27
  )
28
 
29
  # 使用更轻量的LoRA配置
 
5
  import os
6
  import random
7
  import re
8
+ import torch
9
 
10
  class ModelTrainer:
11
  def __init__(self, model_id, system_prompts_path):
 
15
  with open(system_prompts_path, 'r', encoding='utf-8') as f:
16
  self.system_prompts = json.load(f)
17
 
18
+ # 修改模型初始化参数
19
  self.tokenizer = AutoTokenizer.from_pretrained(
20
  model_id,
21
+ trust_remote_code=True
22
  )
23
+
24
+ # 修改这部分的初始化参数
25
  self.model = AutoModelForCausalLM.from_pretrained(
26
  model_id,
27
+ trust_remote_code=True,
28
+ torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
29
+ device_map='auto', # 自动选择设备
30
+ low_cpu_mem_usage=True
31
  )
32
 
33
  # 使用更轻量的LoRA配置