# Apply the int8 patch from fix_int8.py before the model is loaded.
from fix_int8 import fix_pytorch_int8
fix_pytorch_int8()


# Credit:
# https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py


import torch
import gradio as gr
from threading import Thread
from model import model, tokenizer
from session import db, logger, log_sys_info
from transformers import GenerationConfig


max_length = 224
default_start = ["你是Kuma,请和我聊天,每句话以两个竖杠分隔。", "好的,你想聊什么?"]
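# Prompt seed for the chat UI, in Chinese: a persona instruction and its reply
# ("You are Kuma, please chat with me; separate each sentence with two vertical bars." /
#  "OK, what do you want to talk about?"). Turns are joined with the '||' separator
# throughout this file, and max_length caps both the tokenized context and the
# generated sequence.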

gr_title = """<h1 align="center">KumaGLM</h1>
<h3 align='center'>这是一个 AI Kuma,你可以与他聊天,或者直接在文本框按下Enter</h3>
<p align='center'>采样范围 2020/06/13 - 2023/04/15</p>
<p align='center'>GitHub Repo: <a class="github-button" href="https://github.com/KumaTea/ChatGLM" aria-label="Star KumaTea/ChatGLM on GitHub">KumaTea/ChatGLM</a></p>
<script async defer src="https://buttons.github.io/buttons.js"></script>
"""
gr_footer = """<p align='center'>
本项目基于
<a href='https://github.com/ljsabc/Fujisaki' target='_blank'>ljsabc/Fujisaki</a>
,模型采用
<a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>

</p>
<p align='center'>
<em>每天起床第一句!</em>
</p>"""


def evaluate(context, temperature, top_p):
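    """Single, non-streaming generation pass used by the API path.

    Ensures the context ends with the '||' turn separator, tokenizes it,
    runs model.generate on CPU with the given sampling settings, and
    returns the decoded output text.
    """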
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # top_k=top_k,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )
    with torch.no_grad():
        # input_text = f"Context: {context}Answer: " 
        # input_text = '||'.join(default_start) + '||'
        # No need for starting prompt in API
        if not context.endswith('||'):
            context += '||'
        # logger.info('[API] Request: ' + context)
        ids = tokenizer([context], return_tensors="pt")
        inputs = ids.to("cpu")
        out = model.generate(
            **inputs,
            max_length=max_length,
            generation_config=generation_config
        )
        out = out.tolist()[0]
        decoder_output = tokenizer.decode(out)
        # out_text = decoder_output.split("Answer: ")[1]
        out_text = decoder_output
        logger.info('[API] Results: ' + out_text.replace('\n', '<br>'))
        return out_text


def evaluate_wrapper(context, temperature, top_p):
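    """Serialize generation through the shared db lock and record the
    prompt and result under the current db index."""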
    db.lock()
    index = db.index
    db.set(index, prompt=context)
    try:
        result = evaluate(context, temperature, top_p)
        db.set(index, result=result)
    finally:
        # Always release the lock so the API does not stay "busy" if generation fails.
        db.unlock()
    return result


def api_wrapper(context='', temperature=0.5, top_p=0.8, query=0):
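    """API entry point exposed via the hidden Gradio button (api_name='chat').

    Either starts a new generation for `context` (202 while a background
    thread runs, 503 if another request holds the lock, 200 if the same
    prompt was already processed) or, when `query` is given, polls for the
    result at that index (200 done, 202 still processing, 404 unknown index).
    """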
    query = int(query)
    assert context or query
    
    return_json = {
        'status': '',
        'code': 0,
        'message': '',
        'index': 0,
        'result': ''
    }

    if context:
        if db.islocked():
            logger.info(f'[API] Request: {context}, Status: busy')
            return_json['status'] = 'busy'
            return_json['code'] = 503
            return_json['message'] = '[context] Server is busy, please try again later.'
            return return_json
        else:
            for index in db.prompts:
                if db.prompts[index] == context:
                    return_json['status'] = 'done'
                    return_json['code'] = 200
                    return_json['message'] = '[context] Request cached.'
                    return_json['index'] = index
                    return_json['result'] = db.results[index]
                    return return_json
            # new
            index = db.index
            t = Thread(target=evaluate_wrapper, args=(context, temperature, top_p))
            t.start()
            logger.info(f'[API] Request: {context}, Status: processing, Index: {index}')
            return_json['status'] = 'processing'
            return_json['code'] = 202
            return_json['message'] = '[context] Request accepted, please check back later.'
            return_json['index'] = index
            return return_json
    else:  # query
        if query in db.prompts and query in db.results:
            logger.info(f'[API] Query: {query}, Status: hit')
            return_json['status'] = 'done'
            return_json['code'] = 200
            return_json['message'] = '[query] Request processed.'
            return_json['index'] = query
            return_json['result'] = db.results[query]
            return return_json
        else:
            if db.islocked():
                logger.info(f'[API] Query: {query}, Status: processing')
                return_json['status'] = 'processing'
                return_json['code'] = 202
                return_json['message'] = '[query] Request in processing, please check back later.'
                return_json['index'] = query
                return return_json
            else:
                logger.info(f'[API] Query: {query}, Status: error')
                return_json['status'] = 'error'
                return_json['code'] = 404
                return_json['message'] = '[query] Index not found.'
                return_json['index'] = query
                return return_json


def evaluate_stream(msg, history, temperature, top_p):
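    """Streaming handler for the chat UI.

    Rebuilds the context from default_start plus the most recent turns
    (joined with '||'), trims it to fit within max_length tokens, then
    yields the updated chat history as model.stream_chat produces
    partial responses.
    """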
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )
    if not msg:
        msg = '……'

    history.append([msg, ""])

    context = '||'.join(default_start) + '||'
    if len(history) > 4:
        history.pop(0)

    for j in range(len(history)):
        history[j][0] = history[j][0].replace("<br>", "")

    # concatenate context
    for h in history[:-1]:
        context += h[0] + "||" + h[1] + "||"

    context += history[-1][0] + "||"
    context = context.replace(r'<br>', '')

    # TODO: handle over-long inputs more gracefully than character trimming.
    # CUTOFF = 224
    while len(tokenizer.encode(context)) > max_length:
        # drop characters from the front to leave room for the answer
        context = context[15:]

    h = []
    logger.info('[UI] Request: ' + context)
    for response, h in model.stream_chat(tokenizer, context, h, max_length=max_length, top_p=top_p, temperature=temperature):
        history[-1][1] = response
        yield history, ""
    logger.info('[UI] Results: ' + response.replace('\n', '<br>'))


with gr.Blocks() as demo:
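    # UI layout: sampling sliders and a chat panel, plus hidden components
    # that expose the /chat (api_wrapper) and /info (log_sys_info) endpoints.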
    gr.HTML(gr_title)
    # state = gr.State()
    with gr.Row():
        with gr.Column(scale=2):
            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.5, label="Temperature",
                info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。小的温度也能帮助生成更相关的回答。")
            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.8, label="Top-p",
                info="top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
            #code = gr.Textbox(label="temp_output", info="解码器输出")
            #top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
            #    info="top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
            
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="聊天框", info="")
            msg = gr.Textbox(label="输入框", placeholder="最近过得怎么样?",
                info="输入你的内容,按 [Enter] 发送。什么都不填经常会出错。")
            clear = gr.Button("清除聊天")

        api_handler = gr.Button("API", visible=False)
        api_index = gr.Number(visible=False)
        api_result = gr.JSON(visible=False)
        info_handler = gr.Button("Info", visible=False)
        info_text = gr.Textbox('System info logged. Check it in the log viewer.', visible=False)


    msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
    clear.click(lambda: None, None, chatbot, queue=False)
    api_handler.click(api_wrapper, [msg, temp, top_p, api_index], api_result, api_name='chat')
    info_handler.click(log_sys_info, None, info_text, api_name='info')
    gr.HTML(gr_footer)

demo.queue()
demo.launch(debug=False)
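
# Example client usage (not part of the app; assumes the gradio_client package
# and that the demo is reachable at the URL below):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   resp = client.predict("你好", 0.5, 0.8, 0, api_name="/chat")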