Nelsonlin0321 committed · Commit 3b98af5 · 1 Parent(s): 5724f4c

Upload handler.py
handler.py ADDED
@@ -0,0 +1,48 @@
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from peft import PeftModel


class EndpointHandler:
    def __init__(self):
        # Load the base LLaMA-7B tokenizer and the 8-bit quantized model.
        self.tokenizer = LlamaTokenizer.from_pretrained(
            "decapoda-research/llama-7b-hf")
        self.tokenizer.pad_token_id = 0  # 0 = <unk>; keep padding distinct from eos
        self.model = LlamaForCausalLM.from_pretrained(
            "decapoda-research/llama-7b-hf",
            load_in_8bit=True,
            device_map="auto",
        )

        # Wrap the base model with the LoRA adapter fine-tuned on the
        # Alpaca-format HK CSV FQA dataset.
        self.model = PeftModel.from_pretrained(
            self.model, "Nelsonlin0321/alpaca-lora-7b-tuned-on-hk-csv-fqa_causal_lm")

        # Low-temperature beam search for mostly deterministic answers.
        self.eval_generation_config = GenerationConfig(
            temperature=0.1,
            top_p=0.75,
            num_beams=4)

    @staticmethod
    def generate_prompt_eval(instruction):
        # Alpaca-style prompt template used during fine-tuning.
        template = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:"""
        return template

    def __call__(self, instruction: str) -> str:
        prompt = self.generate_prompt_eval(instruction)
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].cuda()
        generation_output = self.model.generate(
            input_ids=input_ids,
            generation_config=self.eval_generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
        # Decode the generated sequence and keep only the text after the
        # "### Response:" marker.
        for s in generation_output.sequences:
            output = self.tokenizer.decode(s)
            output = output.split("### Response:")[1].strip()
        return output
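
For reference, a minimal local smoke test of this handler might look like the sketch below. It assumes a CUDA-capable GPU (the handler moves input IDs to CUDA) and that the transformers, peft, and bitsandbytes packages are installed; the question string is only a hypothetical example.

# Hypothetical local smoke test (not part of the commit): requires a CUDA GPU
# and the transformers, peft, and bitsandbytes packages.
from handler import EndpointHandler

handler = EndpointHandler()

# Example instruction; any plain-text question works.
print(handler("What documents do I need to open a bank account in Hong Kong?"))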