File size: 3,629 Bytes
fe2ff51
 
b02a933
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769970e
b02a933
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769970e
b02a933
 
5a307f6
 
b02a933
769970e
 
 
b02a933
 
 
 
 
 
 
 
fe2ff51
 
769970e
fe2ff51
 
 
 
 
 
 
b02a933
 
 
 
 
 
 
 
 
fe2ff51
b02a933
fe2ff51
 
 
 
769970e
fe2ff51
 
 
 
 
 
769970e
 
 
fe2ff51
769970e
 
fe2ff51
 
 
 
 
 
 
b02a933
5a307f6
 
fe2ff51
 
 
 
 
 
 
 
 
 
 
 
b02a933
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr




# from huggingface_hub import InferenceClient
# """
# For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
# """
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")




from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class ChatClient:
    """Thin wrapper around a local Hugging Face causal LM used for chat replies."""

    def __init__(self, model_path):
        """
        Load the tokenizer and model onto the GPU if one is available, else the CPU.

        Args:
            model_path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
        self.model.eval()  # inference only: switch off training-mode layers

    def chat_completion(self, messages, max_tokens, stream=False, temperature=1.0, top_p=1.0):
        """
        Generate a reply to ``messages`` and yield it as text fragments.

        Concatenating every yielded fragment gives the complete reply. Only
        newly generated text is yielded; the prompt itself is never echoed.

        Args:
            messages: The complete prompt as a single string. The caller is
                responsible for any formatting or concatenation.
            max_tokens: Maximum number of new tokens to generate.
            stream: If True, yield fragments while generation is still running
                (generation happens in a background thread). If False, yield
                the finished reply once.
            temperature: Sampling temperature. Sampling is always enabled, so
                this should be > 0.
            top_p: Nucleus-sampling probability mass.

        Yields:
            str: Consecutive pieces of the generated reply.
        """
        inputs = self.tokenizer(messages, return_tensors="pt").to(self.device)
        prompt_length = inputs["input_ids"].shape[1]

        gen_kwargs = {
            # Limit *new* tokens directly (equivalent to the former
            # max_length = prompt_length + max_tokens). int() guards against
            # UI sliders that deliver floats.
            "max_new_tokens": int(max_tokens),
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
        }

        if not stream:
            output_ids = self.model.generate(**inputs, **gen_kwargs)
            # generate() returns prompt + completion for each row; strip the
            # prompt so callers only receive the model's reply.
            for row in output_ids:
                yield self.tokenizer.decode(row[prompt_length:], skip_special_tokens=True)
            return

        # Local imports keep these streaming-only dependencies out of the
        # non-streaming path.
        from threading import Thread

        from transformers import TextIteratorStreamer

        # skip_prompt=True makes the streamer drop the echoed prompt tokens.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        errors = []

        def _generate():
            """Run generation, forwarding any failure to the consuming thread."""
            try:
                self.model.generate(**inputs, **gen_kwargs, streamer=streamer)
            except Exception as exc:
                # Without end(), the loop below would wait forever for a
                # stop signal that a crashed generate() never sends.
                errors.append(exc)
                streamer.end()

        worker = Thread(target=_generate, daemon=True)
        worker.start()
        for fragment in streamer:
            if fragment:
                yield fragment
        worker.join()
        if errors:
            raise errors[0]

# Build one shared client at import time, pointing it at the local model
# directory; `respond` reuses this instance for every request.
model_path = 'model/v3/'
client = ChatClient(model_path=model_path)






def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Gradio ChatInterface callback that streams the model's reply to ``message``.

    The prompt is ``system_message`` immediately followed by ``message``.
    ``history`` is accepted to match the ChatInterface signature but is not
    used; no role-based message list is built. Each yielded value is the reply
    accumulated so far.
    """
    prompt = system_message + message
    accumulated = ""

    fragments = client.chat_completion(
        prompt,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    for fragment in fragments:
        print(fragment)
        accumulated = accumulated + fragment
        yield accumulated

def _settings_controls():
    """
    Create the extra ChatInterface inputs.

    The order matches the parameters ``respond`` takes after
    (message, history): system_message, max_tokens, temperature, top_p.
    """
    system_box = gr.Textbox(value="Yahoo!ショッピングについての質問を回答してください。", label="System message")
    max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens")
    temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature")
    top_p_slider = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    )
    return [system_box, max_tokens_slider, temperature_slider, top_p_slider]


# Chat UI. Customization options: https://www.gradio.app/docs/chatinterface
demo = gr.ChatInterface(respond, additional_inputs=_settings_controls())


if __name__ == "__main__":
    demo.launch()