import sys
import json

import fire
import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter

device = "cuda" if torch.cuda.is_available() else "cpu"


class Infer:
    def __init__(
        self,
        load_8bit: bool = False,
        base_model: str = "",
        lora_weights: str = "",
        prompt_template: str = "",  # The prompt template to use; defaults to alpaca.
    ):
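        """Load the LLaMA base model and tokenizer, optionally apply LoRA
        weights, and prepare the model for inference."""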
        prompter = Prompter(prompt_template)
        tokenizer = LlamaTokenizer.from_pretrained(base_model)
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        
        try:
            print(f"Using lora {lora_weights}")
            model = PeftModel.from_pretrained(
                model,
                lora_weights,
                torch_dtype=torch.float16,
            )
        except Exception:
            # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.
            print("*" * 50, "\n Attention! No LoRA weights loaded \n", "*" * 50)
            
        # unwind broken decapoda-research config
        model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
        model.config.bos_token_id = 1
        model.config.eos_token_id = 2
        if not load_8bit:
            model.half()  # seems to fix bugs for some users.

        model.eval()

        # torch.compile requires PyTorch >= 2.0 and is not supported on Windows.
        if torch.__version__ >= "2" and sys.platform != "win32":
            model = torch.compile(model)
        
        self.base_model = base_model
        self.lora_weights = lora_weights
        self.model = model
        self.prompter = prompter
        self.tokenizer = tokenizer

    def generate_output(
        self,
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=256,
        **kwargs,
    ):
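        """Build a prompt from `instruction` (and optional `input`), run
        generation, and return the response text parsed by the Prompter."""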
        prompt = self.prompter.generate_prompt(instruction, input)
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            # repetition_penalty=10.0,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences[0]
        output = self.tokenizer.decode(s)
        return self.prompter.get_response(output)

    def infer_from_file(self, infer_data_path):
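        """Run inference over a JSON-lines file, printing the model output
        alongside the ground truth.

        Each line is expected to be a JSON object with "instruction" and
        "output" (ground truth) keys, e.g.
            {"instruction": "...", "output": "..."}
        """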
        with open(infer_data_path) as f:
            for line in f:
                data = json.loads(line)
                instruction = data["instruction"]
                output = data["output"]
                print('=' * 100)
                print(f"Base Model: {self.base_model}    Lora Weights: {self.lora_weights}")
                print("Instruction:\n", instruction)
                model_output = self.generate_output(instruction)
                print("Model Output:\n", model_output)
                print("Ground Truth:\n", output)
                print('=' * 100)


def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "",
    prompt_template: str = "",  # The prompt template to use; defaults to alpaca.
    infer_data_path: str = "",
):
    infer = Infer(
        load_8bit=load_8bit,
        base_model=base_model,
        lora_weights=lora_weights,
        prompt_template=prompt_template
    )
    
    try:
        infer.infer_from_file(infer_data_path)
    except Exception as e:
        print(e, "\nFailed to read infer_data_path. Entering interactive mode:")
        while True:
            print('=' * 100)
            instruction = input("Please enter your question: ")
            print("LaWGPT:")
            print(infer.generate_output(instruction))
            print('=' * 100)


if __name__ == "__main__":
    fire.Fire(main)
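
# Example invocation via python-fire (assuming this file is saved as infer.py;
# all paths below are illustrative placeholders):
#   python infer.py \
#       --base_model /path/to/llama-base \
#       --lora_weights /path/to/lora-checkpoint \
#       --infer_data_path /path/to/test.jsonl
# Omitting or mistyping --infer_data_path falls back to the interactive prompt.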