StarRing2022 committed on
Commit
dd875b0
·
1 Parent(s): c2117ba

Upload generate_hf.py

Browse files
Files changed (1) hide show
  1. generate_hf.py +72 -0
generate_hf.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from peft import PeftModel
3
+ import transformers
4
+ import gradio as gr
5
+
6
+ from ringrwkv.configuration_rwkv_world import RwkvConfig
7
+ from ringrwkv.rwkv_tokenizer import TRIE_TOKENIZER
8
+ from ringrwkv.modehf_world import RwkvForCausalLM
9
+
10
+
11
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The model folder is expected to sit in the local project root directory.
model = RwkvForCausalLM.from_pretrained("RWKV-4-World-1.5B")
tokenizer = TRIE_TOKENIZER('./ringrwkv/rwkv_vocab_v20230424.txt')

# Optional LoRA adapter wrapping (disabled):
#model= PeftModel.from_pretrained(model, "./lora-out")

# Move the loaded model onto the selected device.
model = model.to(device)
25
+
26
+
27
def evaluate(
    instruction,
    temperature=1,
    top_p=0.7,
    top_k=0.1,
    penalty_alpha=0.1,
    max_new_tokens=128,
):
    """Generate a completion for ``instruction`` with the module-level model.

    The stripped instruction is wrapped in a "Question: ... Answer:" prompt,
    encoded with the module-level ``tokenizer``, and passed to
    ``model.generate`` on ``device``.

    Args:
        instruction: User text; surrounding whitespace is stripped.
        temperature: Forwarded unchanged to ``model.generate``.
        top_p: Forwarded unchanged to ``model.generate``.
        top_k: Forwarded unchanged to ``model.generate``.
            NOTE(review): the default is a float, while generation APIs
            usually expect an integer count here -- confirm intent.
        penalty_alpha: Forwarded unchanged to ``model.generate``.
        max_new_tokens: Forwarded unchanged to ``model.generate``.

    Returns:
        The decoded first sequence returned by ``model.generate`` with every
        token id 0 removed and surrounding whitespace stripped.
        NOTE(review): the returned sequence presumably still contains the
        prompt tokens, so the text likely starts with the prompt -- confirm.
    """
    prompt = f'Question: {instruction.strip()}\n\nAnswer:'
    # Add a leading batch dimension of size 1 before handing off to generate().
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    out = model.generate(
        input_ids=input_ids.to(device),
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        penalty_alpha=penalty_alpha,
        max_new_tokens=max_new_tokens,
    )
    # Filter into a new list instead of calling remove() while iterating:
    # mutating the list mid-loop skips the element after each removal, so
    # runs of consecutive 0 tokens were only partially removed.
    # Token id 0 is presumably padding / end-of-text -- TODO confirm.
    token_ids = [token_id for token_id in out[0].tolist() if token_id != 0]
    answer = tokenizer.decode(token_ids)
    return answer.strip()
48
+
49
+
50
# Build the Gradio web UI around evaluate() and start serving it.
# The input components are listed in the same order as evaluate()'s parameters.
gr.Interface(
    fn=evaluate,  # interface callback
    inputs=[
        gr.components.Textbox(
            lines=2, label="Instruction", placeholder="Tell me about alpacas."
        ),
        gr.components.Slider(minimum=0, maximum=2, value=1, label="Temperature"),
        gr.components.Slider(minimum=0, maximum=1, value=0.7, label="Top p"),
        # NOTE(review): top_k is normally an integer count; this 0-1 range with
        # step=1 and an off-grid default of 0.1 looks unintended. Left unchanged
        # because the correct range is a design decision -- confirm.
        gr.components.Slider(minimum=0, maximum=1, step=1, value=0.1, label="top_k"),
        # A fractional step keeps the 0.1 default on the slider grid; with
        # step=1 the control could only select 0 or 1.
        gr.components.Slider(
            minimum=0, maximum=1, step=0.01, value=0.1, label="penalty_alpha"
        ),
        gr.components.Slider(
            minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
        ),
    ],
    outputs=[
        # Use the generic component class (as the inputs above do) rather than
        # the deprecated gr.inputs namespace for an output widget.
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="RWKV-World-Alpaca",
    description="RWKV,Easy In HF.",
).launch()