test2023h5 committed
Commit ee98572 · verified · 1 Parent(s): 5bba451

Create mymodule.py

Files changed (1)
  1. mymodule.py +62 -0
mymodule.py ADDED
@@ -0,0 +1,62 @@
# Invoke the large language model: Qwen2-0.5B with two PEFT adapters
# for classical-Chinese <-> modern-Chinese translation.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pretrained base model
model_name = "Qwen/Qwen2-0.5B"
base_model = AutoModelForCausalLM.from_pretrained(model_name)

# Adapter repositories: classical -> modern and modern -> classical
adapter_path1 = "test2023h5/wyw2xdw"
adapter_path2 = "test2023h5/xdw2wyw"

# Load both adapters (requires the peft package to be installed)
base_model.load_adapter(adapter_path1, adapter_name='adapter1')
base_model.load_adapter(adapter_path2, adapter_name='adapter2')

# Activate the classical -> modern adapter by default
base_model.set_adapter("adapter1")
#base_model.set_adapter("adapter2")

model = base_model.to(device)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("model loading done")

def format_instruction(task, text):
    # Build the instruction prompt the adapters were trained on; the
    # section headers are Chinese for "Instruction", "Input", "Output".
    string = f"""### 指令:
{task}

### 输入:
{text}

### 输出:
"""
    return string

def generate_response(task, text):
    input_text = format_instruction(task, text)
    encoding = tokenizer(input_text, return_tensors="pt").to(device)
    with torch.no_grad():  # disable gradient computation
        outputs = model.generate(**encoding, max_new_tokens=50)
    # Strip the prompt tokens, keeping only the newly generated ones
    generated_ids = outputs[:, encoding.input_ids.shape[1]:]
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    # Return only the first generated line
    return generated_texts[0].split('\n')[0]

def predict(text, method):
    if method == 0:
        # Task: "translate into modern Chinese" (classical -> modern)
        prompt = ["翻译成现代文", text]
        base_model.set_adapter("adapter1")
    else:
        # Task: "translate into classical Chinese" (modern -> classical)
        prompt = ["翻译成古文", text]
        base_model.set_adapter("adapter2")

    print("debug", text)
    response = generate_response(prompt[0], prompt[1])
    print("debug2", response)
    return response
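
A minimal driver sketch for the module above, assuming it is saved as mymodule.py; the two example sentences are illustrative inputs, not part of the original commit:

# Hypothetical usage; the input sentences below are assumed examples.
from mymodule import predict

# method == 0: classical -> modern Chinese (adapter1, wyw2xdw)
print(predict("学而时习之，不亦说乎？", 0))

# Any other method: modern -> classical Chinese (adapter2, xdw2wyw)
print(predict("学习并时常温习，不也很愉快吗？", 1))

Since set_adapter switches the active adapter on the shared base model, each predict call reconfigures the model globally before generating.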