import json
import re
import gradio as gr
import os
import sys
import requests
import csv
import datetime
import pytz
import uuid
import pymysql


user_key = ""
username = os.environ.get('user')
password = os.environ.get('pass')

initial_prompt = "你是一个微信客服,需要接待从抖音过来具有保险需求的客户,目的是让客户提交信息收集表格(链接:https://info.link.haha/this-is-the-link.html),顺带解答一些简单的保险问题,不要回答太多,因为用户提交表格后会有专业规划师联系他们解答。你的第一句话是:“您好,请问是从抖音来的吗?”"

API_URL = "https://api.openai.com/v1/chat/completions"
HISTORY_DIR = "history"
TEMPLATES_DIR = "templates"
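
# Timestamp helpers in the Asia/Shanghai timezone: a short "_MMDD_HHMM" suffix
# for exported history filenames, and a full datetime string for logging.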
def get_mmdd_hhmi():
    tz = pytz.timezone('Asia/Shanghai')
    now = datetime.datetime.now(tz)
    timestamp = now.strftime("_%m%d_%H%M")
    return timestamp


def get_current_time():
    tz = pytz.timezone('Asia/Shanghai')
    now = datetime.datetime.now(tz)
    timestamp = now.strftime("%Y-%m-%d %H:%M:%S")
    return timestamp

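
# Persist one chat turn to MySQL. The target `chatlog` table is assumed to
# contain exactly the ten columns named in the INSERT statement below.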
def write_data(time, session_id, chat_id, api_key, round, system, user, assistant, messages, payload):

    conn = pymysql.connect(host=os.environ.get('db_host'), port=3306,
                           user=os.environ.get('db_user'),
                           password=os.environ.get('db_pass'),
                           db=os.environ.get('db_db'),
                           charset='utf8mb4')

    cur = conn.cursor()

    # Use a parameterized query so quotes or other special characters in the
    # chat content cannot break the statement.
    sql_update = '''
        INSERT INTO `chatlog` (`time`,`session_id`,`chat_id`,`api_key`,`round`,`system`,`user`,`assistant`,`messages`,`payload`)
        VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
    '''

    try:
        cur.execute(sql_update, (time, session_id, chat_id, api_key, round,
                                 system, user, assistant, str(messages), str(payload)))
        conn.commit()
        print('成功写入数据!')
    except Exception as e:
        conn.rollback()
        print(f"出错了:{e}")

    cur.close()
    conn.close()


def get_empty_state():
    return {"total_tokens": 0, "messages": []}


def gen_uuid():
    return str(uuid.uuid4())

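
# Runtime flags: `dockerrun=yes` means the app runs inside Docker and must take
# its API key from the environment; basic auth is enabled only when both the
# `user` and `pass` environment variables are set.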
if os.environ.get('dockerrun') == 'yes':
    dockerflag = True
else:
    dockerflag = False

if dockerflag:
    my_api_key = os.environ.get('OPENAI_API_KEY')
    if my_api_key == "empty":
        print("Please provide an API key!")
        sys.exit(1)

    username = os.environ.get('user')
    password = os.environ.get('pass')
    if username is None or password is None:
        authflag = False
    else:
        authflag = True

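
# Convert a streamed reply into HTML for the gr.Chatbot widget: ``` fences
# become styled <pre><code> blocks, and lines inside a fence are escaped as
# HTML entities so code is displayed verbatim instead of rendered as Markdown.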
def parse_text(text):
    # Collapse runs of consecutive <br> tags into a single one.
    text = re.sub(r'(<br>)+', '<br>', text)
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    firstline = False
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="{items[-1]}" style="display: block; white-space: pre; padding: 0 1em 1em 1em; color: #fff; background: #000;">'
                firstline = True
            else:
                lines[i] = f'</code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape HTML and Markdown control
                    # characters so they show up literally.
                    line = line.replace("&", "&amp;")
                    line = line.replace("\"", "`\"`")
                    line = line.replace("\'", "`\'`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&#42;")
                    line = line.replace("_", "&#95;")
                    line = line.replace("#", "&#35;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text

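
# Core request generator: rebuilds the message list from the current chat
# state, streams a completion from the OpenAI Chat Completions API, yields
# updated (chatbot, history) pairs as tokens arrive, and finally logs the
# finished turn to MySQL via write_data().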
def openai_request(inputs, top_p, temperature, max_tokens, user_key, session_id, chat_id, chatbot=[], history=[], system_prompt=initial_prompt, regenerate=False, summary=False):

    api_key = user_key or os.environ.get("OPENAI_API_KEY")

    chat_round = len(history) // 2

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    messages = [compose_system(system_prompt)]

    if chat_round:
        for data in chatbot:
            temp1 = {}
            temp1["role"] = "user"
            temp1["content"] = data[0]
            temp2 = {}
            temp2["role"] = "assistant"
            temp2["content"] = data[1]
            if temp1["content"] != "":
                messages.append(temp1)
                messages.append(temp2)
            else:
                messages[-1]['content'] = temp2['content']

    if regenerate and chat_round:
        # Drop the last assistant reply so it can be generated again.
        messages.pop()
    elif summary:
        messages.append(compose_user(
            "请帮我总结一下上述对话的内容,实现减少字数的同时,保证对话的质量。在总结中不要加入这一句话。"))
        history = ["我们刚刚聊了什么?"]
    else:
        temp3 = {}
        temp3["role"] = "user"
        temp3["content"] = inputs
        messages.append(temp3)
        chat_round += 1

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "n": 1,
        "stream": True,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }

    if not summary:
        history.append(inputs)

    response = requests.post(API_URL, headers=headers,
                             json=payload, stream=True)

    token_counter = 0
    partial_words = ""

    counter = 0
    chatbot.append((history[-1], ""))
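
    # The API streams server-sent events: each data line looks like
    # "data: {...}", so the 6-character "data: " prefix is stripped before
    # JSON-decoding. The very first data line carries only the assistant role
    # (no content) and is skipped; a chunk with an empty "delta" ends the stream.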
    for chunk in response.iter_lines():
        if counter == 0:
            counter += 1
            continue
        counter += 1

        if chunk:
            try:
                if len(json.loads(chunk.decode()[6:])['choices'][0]["delta"]) == 0:
                    break
            except Exception as e:
                chatbot.pop()
                chatbot.append((history[-1], f"☹️发生了错误<br>返回值:{response.text}<br>异常:{e}"))
                history.pop()
                yield chatbot, history
                break

            partial_words = partial_words + \
                json.loads(chunk.decode()[6:])['choices'][0]["delta"]["content"]
            if token_counter == 0:
                history.append(" " + partial_words)
            else:
                history[-1] = parse_text(partial_words)
            chatbot[-1] = (history[-2], history[-1])

            token_counter += 1

            yield chatbot, history

    # Log the finished turn to stdout and to the database.
    time = get_current_time()
    print(f"\n\ntime: {time}")
    print(f"session_id: {session_id}")
    print(f"chat_id: {chat_id}")
    print(f"apiKey: {api_key}")
    print(f"round: {chat_round}")
    print(f"system: {system_prompt}")
    print(f"user: {inputs}")
    print(f"assistant: {history[-1]}")
    print(f"messages: {messages}")
    paras = dict(payload)
    del paras["messages"]
    print(f"payload: {paras}")

    write_data(time, session_id, chat_id, api_key, chat_round, system_prompt, inputs, history[-1], messages, paras)


def delete_last_conversation(chatbot, history):
    chatbot.pop()
    history.pop()
    history.pop()
    return chatbot, history


def save_chat_history(filename, system, history, chatbot):
    if filename == "":
        return
    if not filename.endswith(".json"):
        filename = filename + get_mmdd_hhmi() + ".json"
    os.makedirs(HISTORY_DIR, exist_ok=True)
    json_s = {"system": system, "history": history, "chatbot": chatbot}
    with open(os.path.join(HISTORY_DIR, filename), "w", encoding="UTF-8") as f:
        json.dump(json_s, f, ensure_ascii=False)


def load_chat_history(filename):
    try:
        with open(os.path.join(HISTORY_DIR, filename), "r", encoding="UTF-8") as f:
            json_s = json.load(f)
    except Exception:
        # Fall back to the bundled default conversation if the requested file
        # is missing or unreadable.
        with open(os.path.join(HISTORY_DIR, "default.json"), "r", encoding="UTF-8") as f:
            json_s = json.load(f)
    return filename, json_s["system"], json_s["history"], json_s["chatbot"]

def get_file_names(dir, plain=False, filetype=".json"): |
|
|
|
try: |
|
files = [f for f in os.listdir(dir) if f.endswith(filetype)] |
|
except FileNotFoundError: |
|
files = [] |
|
if plain: |
|
return files |
|
else: |
|
return gr.Dropdown.update(choices=files) |
|
|
|
|
|
def get_history_names(plain=False): |
|
return get_file_names(HISTORY_DIR, plain) |
|
|
|
|
|

def load_template(filename):
    lines = []
    with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="UTF-8") as csvfile:
        reader = csv.reader(csvfile)
        lines = list(reader)
    # Skip the CSV header row; each remaining row is (template name, prompt).
    lines = lines[1:]
    return {row[0]: row[1] for row in lines}, gr.Dropdown.update(choices=[row[0] for row in lines])


def get_template_names(plain=False):
    return get_file_names(TEMPLATES_DIR, plain, filetype=".csv")


def reset_state():
    return [], []


def compose_system(system_prompt):
    return {"role": "system", "content": system_prompt}


def compose_user(user_input):
    return {"role": "user", "content": user_input}


def reset_textbox():
    return gr.update(value='')

css_styles = """ |
|
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;} |
|
#chatbox {min-height: 400px;} |
|
#header {text-align: center;} |
|
#prompt_template_preview {padding: 1em; border-width: 1px; border-style: solid; border-color: #e0e0e0; border-radius: 4px;} |
|
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;} |
|
#label {font-size: 0.8em; padding: 0.5em; margin: 0;} |
|
#ids{visibility: hidden;} |
|
footer {visibility: hidden;} |
|
""" |
|
|
|
|
|
with gr.Blocks(css=css_styles,title="Chatbot🚀") as interface: |
|
|
|
history = gr.State([]) |
|
promptTemplates = gr.State({}) |
|
TRUECONSTANT = gr.State(True) |
|
FALSECONSTANT = gr.State(False) |
|
|
|
with gr.Column(elem_id="col-container"): |
|
gr.Markdown("""# Chatbot Playground 🚀 |
|
本对话使用 GPT-3.5 Turbo 模型。 |
|
如果遇到问题请联系Simon! |
|
""", |
|
elem_id="header") |
|

        with gr.Row():

            with gr.Column():

                btnClearChat = gr.Button("🧹 清空对话")

                chatbot = gr.Chatbot(elem_id="chatbox", label="对话")

                with gr.Row():
                    with gr.Column(scale=12):
                        txt = gr.Textbox(show_label=False, placeholder="在这里输入").style(
                            container=False)
                    with gr.Column(min_width=50, scale=1):
                        btnSubmit = gr.Button("🚀", variant="primary")

                with gr.Row():
                    btnRegen = gr.Button("🔄 重新生成")
                    btnDelPrev = gr.Button("🗑️ 删除上一轮")
                    btnSummarize = gr.Button("♻️ 总结对话")

                # Hidden row (via the #ids CSS rule) holding per-session and
                # per-chat UUIDs that are passed to openai_request for logging.
                with gr.Row(elem_id="ids"):
                    mdSessionId = gr.Textbox(label="Session ID", value=gen_uuid, elem_id="session_id", interactive=False)
                    mdChatId = gr.Textbox(label="Chat ID", value=gen_uuid, elem_id="chat_id", interactive=False)

            with gr.Column():

                apiKey = gr.Textbox(show_label=True, placeholder=f"使用自己的OpenAI API key",
                                    value='', label="API Key", type="text", visible=True).style(container=True)

                promptSystem = gr.Textbox(lines=2, show_label=True, placeholder=f"在这里输入System Prompt...",
                                          label="系统输入", info="系统输入不会直接产生交互,可作为基础信息输入。", value=initial_prompt).style(container=True)

                topic = gr.State("未命名对话")
with gr.Accordion(label="导入/导出", open=True): |
|
|
|
|
|
gr.Markdown("注意:导出末尾不含 `.json` 则自动添加时间戳`_mmdd_hhmi`,文件名一致会进行覆盖!") |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(): |
|
historyFileSelectDropdown = gr.Dropdown(label="加载对话", choices=get_history_names(plain=True), value="model_0.json", multiselect=False) |
|
btnLoadHist = gr.Button("📂 导入对话") |
|
btnGetHistList = gr.Button("🔄 刷新列表") |
|
|
|
with gr.Column(): |
|
saveFileName = gr.Textbox(show_label=True, placeholder=f"文件名", label="导出文件名", value="对话").style(container=True) |
|
btnSaveToHist = gr.Button("💾 导出对话") |
|
|
|
|
|
|
|
with gr.Accordion("参数配置", open=False): |
|
gr.Markdown(""" |
|
* [GPT-3参数详解](https://blog.csdn.net/jarodyv/article/details/128984602) |
|
* [API文档](https://platform.openai.com/docs/api-reference/chat/create) |
|
""") |
|
top_p = gr.Slider(minimum=-0, maximum=1.0, value=0.3, step=0.05, |
|
interactive=True, label="Top-p (核采样)", |
|
info="0.1意味着只考虑包含最高10%概率质量的token。较小时,更加紧凑连贯,但缺乏创造性。较大时,更加多样化,但会出现语法错误和不连贯。") |
|
temperature = gr.Slider(minimum=-0, maximum=2.0, value=0.5, |
|
step=0.1, interactive=True, label="采样温度",info="越高 → 不确定性、创造性越高↑") |
|
max_tokens = gr.Slider(minimum=100, maximum=1500, value=500, |
|
step=10, label="最大token数",info="每次回答最大token数,数值过大可能长篇大论,看起来不像正常聊天") |
|
|
|
|
|
presence_penalty = gr.Slider( minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True, label="出现惩罚系数", info="越高 → 增加谈论新主题的可能性" ) |
|
frequency_penalty = gr.Slider( minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True, label="频率惩罚系数", info="越高 → 降低模型逐字重复同一行的可能性" ) |
|
|
|
|
|
|
|
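
    # Wire UI events to callbacks. TRUECONSTANT / FALSECONSTANT are constant
    # gr.State values used to set the `regenerate` and `summary` flags of
    # openai_request for the 重新生成 and 总结对话 buttons.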
    btnClearChat.click(gen_uuid, [], [mdChatId])

    txt.submit(openai_request,
               [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem], [chatbot, history])
    txt.submit(reset_textbox, [], [txt])

    btnSubmit.click(openai_request,
                    [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem], [chatbot, history], show_progress=True)
    btnSubmit.click(reset_textbox, [], [txt])

    btnClearChat.click(reset_state, outputs=[chatbot, history])

    btnRegen.click(openai_request,
                   [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem, TRUECONSTANT],
                   [chatbot, history], show_progress=True)

    btnDelPrev.click(delete_last_conversation, [chatbot, history], [chatbot, history], show_progress=True)

    btnSummarize.click(openai_request,
                       [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem, FALSECONSTANT, TRUECONSTANT],
                       [chatbot, history], show_progress=True)

    btnSaveToHist.click(save_chat_history, [saveFileName, promptSystem, history, chatbot], None, show_progress=True)
    btnSaveToHist.click(get_history_names, None, [historyFileSelectDropdown])

    historyFileSelectDropdown.change(load_chat_history, [historyFileSelectDropdown], [saveFileName, promptSystem, history, chatbot])
    btnGetHistList.click(get_history_names, None, [historyFileSelectDropdown])

    btnLoadHist.click(load_chat_history, [historyFileSelectDropdown], [saveFileName, promptSystem, history, chatbot], show_progress=True)
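
# Launch: inside Docker the app binds to 0.0.0.0:7860 and enables basic auth
# only when both `user` and `pass` are set; outside Docker it runs locally,
# with auth if those credentials are present.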
if dockerflag:
    if authflag:
        interface.queue(concurrency_count=2, status_update_rate="auto").launch(server_name="0.0.0.0", server_port=7860, auth=(username, password))
    else:
        interface.queue(concurrency_count=2, status_update_rate="auto").launch(server_name="0.0.0.0", server_port=7860, share=False)
else:
    if username is not None and password is not None:
        interface.queue(concurrency_count=2, status_update_rate="auto").launch(height='800px', debug=True, show_error=True,
                                                                               auth=(username, password), auth_message="请联系Simon获取用户名与密码!")
    else:
        interface.queue(concurrency_count=2, status_update_rate="auto").launch(height='800px', debug=True, show_error=True)