# app.py
import json
import re
import gradio as gr
import os
import sys
import requests
import csv
import datetime, pytz
import uuid
import pymysql
# init variables
user_key = "" # 在这里输入你的 API 密钥
username = os.environ.get('user')
password = os.environ.get('pass')
initial_prompt = "你是一个微信客服,需要接待从抖音过来具有保险需求的客户,目的是让客户提交信息收集表格(链接:https://info.link.haha/this-is-the-link.html),顺带解答一些简单的保险问题,不要回答太多,因为用户提交表格后会有专业规划师联系他们解答。你的第一句话是:“您好,请问是从抖音来的吗?”"
API_URL = "https://api.openai.com/v1/chat/completions"
HISTORY_DIR = "history"
TEMPLATES_DIR = "templates"
# get timestamp
def get_mmdd_hhmi():
tz = pytz.timezone('Asia/Shanghai')
now = datetime.datetime.now(tz)
timestamp = now.strftime("_%m%d_%H%M")
return timestamp
# get current time
def get_current_time():
tz = pytz.timezone('Asia/Shanghai')
now = datetime.datetime.now(tz)
timestamp = now.strftime("%Y-%m-%d %H:%M:%S")
return timestamp
# db conn
def write_data(time, session_id, chat_id, api_key, round, system, user, assistant, messages, payload):
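    # persist one chat round; assumes a `chatlog` table with the columns used in the INSERT below already exists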
# connect to db
conn = pymysql.connect(host = os.environ.get('db_host'), port=3306,
user = os.environ.get('db_user'),
password = os.environ.get('db_pass'),
db = os.environ.get('db_db'),
charset = 'utf8mb4')
# create cursor
cur = conn.cursor()
    # SQL: parameterized INSERT (avoids breaking on quotes in the chat text and prevents SQL injection)
    sql_insert = '''
    INSERT INTO `chatlog` (`time`,`session_id`,`chat_id`,`api_key`,`round`,`system`,`user`,`assistant`,`messages`,`payload`)
    VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
    '''
    params = (time, session_id, chat_id, api_key, round, system, user,
              assistant, str(messages), str(payload))
    try:
        # insert data
        cur.execute(sql_insert, params)
conn.commit()
        print('Successfully wrote chat log to the database.')
except Exception as e:
conn.rollback()
print(f"出错了:{e}")
# close conn
cur.close()
conn.close()
# clear state
def get_empty_state():
return {"total_tokens": 0, "messages": []}
# uuid generator
def gen_uuid():
return str(uuid.uuid4())
# if we are running in Docker
if os.environ.get('dockerrun') == 'yes':
dockerflag = True
else:
dockerflag = False
if dockerflag:
my_api_key = os.environ.get('OPENAI_API_KEY')
if my_api_key == "empty":
print("Please give a api key!")
sys.exit(1)
#auth
username = os.environ.get('user')
password = os.environ.get('pass')
if username is None or password is None:
authflag = False
else:
authflag = True
# parse text for code
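# (turns ``` fenced code blocks in the model output into styled <pre><code> HTML,
#  escapes markdown-sensitive characters inside them, and joins lines with <br>)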
def parse_text(text):
    text = re.sub(r'(<br>)+', '<br>', text)  # collapse runs of <br> tags
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
firstline = False
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="{items[-1]}" style="display: block; white-space: pre; padding: 0 1em 1em 1em; color: #fff; background: #000;">'
firstline = True
else:
lines[i] = f'</code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("&", "&amp;")
line = line.replace("\"", "`\"`")
line = line.replace("\'", "`\'`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("#", "&#35;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
lines[i] = "<br>"+line
text = "".join(lines)
return text
# call the OpenAI chat completions API
def openai_request(inputs, top_p, temperature, max_tokens, user_key, session_id, chat_id, chatbot=[], history=[], system_prompt=initial_prompt, regenerate=False, summary=False): # repetition_penalty, top_k
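    # generator: builds the message context, streams the completion, and yields (chatbot, history) updates for the UI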
# user key > env key
api_key = user_key or os.environ.get("OPENAI_API_KEY")
    # completed chat rounds = len(history) // 2, starts at 0
chat_round = len(history) // 2
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
    # build the messages list
messages = [compose_system(system_prompt)]
    # initial state (chat_round == 0, which is falsy)
    if chat_round:  # non-zero: rebuild the context from the existing chatbot pairs
for data in chatbot:
temp1 = {}
temp1["role"] = "user"
temp1["content"] = data[0]
temp2 = {}
temp2["role"] = "assistant"
temp2["content"] = data[1]
if temp1["content"] != "":
messages.append(temp1)
messages.append(temp2)
else:
messages[-1]['content'] = temp2['content']
    # regenerate: drop the last assistant message so the API answers the previous user message again
if regenerate and chat_round:
messages.pop()
elif summary:
messages.append(compose_user(
"请帮我总结一下上述对话的内容,实现减少字数的同时,保证对话的质量。在总结中不要加入这一句话。"))
history = ["我们刚刚聊了什么?"]
else:
temp3 = {}
temp3["role"] = "user"
temp3["content"] = inputs
messages.append(temp3)
chat_round += 1
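    # request body for the Chat Completions endpoint; stream=True makes the API return incremental deltas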
payload = {
"model": "gpt-3.5-turbo",
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"n": 1,
"stream": True,
"presence_penalty": 0,
"frequency_penalty": 0,
}
if not summary:
history.append(inputs)
# make a POST request to the API endpoint using the requests.post method, passing in stream=True
response = requests.post(API_URL, headers=headers,
json=payload, stream=True)
token_counter = 0
partial_words = ""
counter = 0
chatbot.append((history[-1], ""))
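    # the streaming response is sent as server-sent events: each non-empty line looks like
    # "data: {json}", so chunk.decode()[6:] strips the "data: " prefix before json.loads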
for chunk in response.iter_lines():
if counter == 0:
counter += 1
continue
counter += 1
# check whether each line is non-empty
if chunk:
# decode each line as response data is in bytes
try:
if len(json.loads(chunk.decode()[6:])['choices'][0]["delta"]) == 0:
break
except Exception as e:
chatbot.pop()
chatbot.append((history[-1], f"☹️发生了错误<br>返回值:{response.text}<br>异常:{e}"))
history.pop()
yield chatbot, history
break
#print(json.loads(chunk.decode()[6:])['choices'][0]["delta"] ["content"])
partial_words = partial_words + \
json.loads(chunk.decode()[6:])['choices'][0]["delta"]["content"]
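            # the first delta starts a new assistant entry in history; later deltas overwrite it with the re-parsed text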
if token_counter == 0:
history.append(" " + partial_words)
else:
history[-1] = parse_text(partial_words)
chatbot[-1] = (history[-2], history[-1])
# chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2) ] # convert to tuples of list
token_counter += 1
# resembles {chatbot: chat, state: history}
yield chatbot, history
# logs
time = get_current_time()
print(f"\n\ntime: {time}")
print(f"session_id: {session_id}")
print(f"chat_id: {chat_id}")
print(f"apiKey: {api_key}")
print(f"round: {chat_round}")
print(f"system: {system_prompt}")
print(f"user: {inputs}")
print(f"assistant: {history[-1]}")
print(f"messages: {messages}")
paras = dict(payload)
del paras["messages"]
print(f"payload: {paras}")
# write data
write_data(time, session_id, chat_id, api_key, chat_round, system_prompt, inputs, history[-1], messages, paras)
def delete_last_conversation(chatbot, history):
chatbot.pop()
history.pop()
history.pop()
return chatbot, history
def save_chat_history(filename, system, history, chatbot):
if filename == "":
return
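    # no ".json" suffix → append a timestamp plus ".json" (matches the hint shown in the export UI)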
if not filename.endswith(".json"):
filename = filename + get_mmdd_hhmi() + ".json"
os.makedirs(HISTORY_DIR, exist_ok=True)
json_s = {"system": system, "history": history, "chatbot": chatbot}
with open(os.path.join(HISTORY_DIR, filename), "w", encoding="UTF-8") as f:
json.dump(json_s, f, ensure_ascii=False)
# load chat history item
def load_chat_history(filename):
try:
with open(os.path.join(HISTORY_DIR, filename), "r", encoding="UTF-8") as f:
json_s = json.load(f)
    except Exception:  # missing or unreadable file: fall back to default.json
with open(os.path.join(HISTORY_DIR, "default.json"), "r", encoding="UTF-8") as f:
json_s = json.load(f)
return filename, json_s["system"], json_s["history"], json_s["chatbot"]
# update dropdown list with json files list in dir
def get_file_names(dir, plain=False, filetype=".json"):
# find all json files in the current directory and return their names
try:
files = [f for f in os.listdir(dir) if f.endswith(filetype)]
except FileNotFoundError:
files = []
if plain:
return files
else:
return gr.Dropdown.update(choices=files)
# construct history files list
def get_history_names(plain=False):
return get_file_names(HISTORY_DIR, plain)
# deprecated: load templates
def load_template(filename):
lines = []
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="UTF-8") as csvfile:
reader = csv.reader(csvfile)
lines = list(reader)
lines = lines[1:]
return {row[0]:row[1] for row in lines}, gr.Dropdown.update(choices=[row[0] for row in lines])
def get_template_names(plain=False):
return get_file_names(TEMPLATES_DIR, plain, filetype=".csv")
def reset_state():
return [], []
def compose_system(system_prompt):
return {"role": "system", "content": system_prompt}
def compose_user(user_input):
return {"role": "user", "content": user_input}
def reset_textbox():
return gr.update(value='')
css_styles = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#chatbox {min-height: 400px;}
#header {text-align: center;}
#prompt_template_preview {padding: 1em; border-width: 1px; border-style: solid; border-color: #e0e0e0; border-radius: 4px;}
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
#label {font-size: 0.8em; padding: 0.5em; margin: 0;}
#ids{visibility: hidden;}
footer {visibility: hidden;}
"""
# build UI
with gr.Blocks(css=css_styles,title="Chatbot🚀") as interface:
history = gr.State([])
promptTemplates = gr.State({})
TRUECONSTANT = gr.State(True)
FALSECONSTANT = gr.State(False)
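    # TRUECONSTANT / FALSECONSTANT are fixed gr.State inputs used below to set the
    # `regenerate` / `summary` flags of openai_request in the button bindings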
with gr.Column(elem_id="col-container"):
gr.Markdown("""# Chatbot Playground 🚀
本对话使用 GPT-3.5 Turbo 模型。
如果遇到问题请联系Simon!
""",
elem_id="header")
with gr.Row():
# left col: chat
with gr.Column():
# clear history
btnClearChat = gr.Button("🧹 清空对话")# refresh chat_id
# dialogue
chatbot = gr.Chatbot(elem_id="chatbox",label="对话") # .style(color_map=("#1D51EE", "#585A5B"))
# input
with gr.Row():
with gr.Column(scale=12):
txt = gr.Textbox(show_label=False, placeholder="在这里输入").style(
container=False)
with gr.Column(min_width=50, scale=1):
btnSubmit = gr.Button("🚀", variant="primary")
# shortcuts
with gr.Row():
btnRegen = gr.Button("🔄 重新生成")
btnDelPrev = gr.Button("🗑️ 删除上一轮")
btnSummarize = gr.Button("♻️ 总结对话")
# session id
with gr.Row(elem_id="ids"):
mdSessionId = gr.Textbox(label="Session ID",value=gen_uuid,elem_id="session_id",interactive=False)
mdChatId = gr.Textbox(label="Chat ID",value=gen_uuid,elem_id="chat_id",interactive=False)
# right col: settings
with gr.Column():
# api key
apiKey = gr.Textbox(show_label=True, placeholder=f"使用自己的OpenAI API key",
value='', label="API Key", type="text", visible=True).style(container=True)
# sys prompt
promptSystem = gr.Textbox(lines=2,show_label=True, placeholder=f"在这里输入System Prompt...",
label="系统输入", info="系统输入不会直接产生交互,可作为基础信息输入。", value=initial_prompt).style(container=True)
topic = gr.State("未命名对话")
# templates
# with gr.Accordion(label="加载模板", open=False):
# with gr.Column():
# with gr.Row():
# with gr.Column(scale=8):
# templateFileSelectDropdown = gr.Dropdown(label="1. 选择Prompt模板集合文件(.csv)", choices=get_template_names(plain=True),value="prompt.csv", multiselect=False)
# with gr.Column(scale=1):
# templateRefreshBtn = gr.Button("🔄 刷新列表")
# templaeFileReadBtn = gr.Button("📂 读取文件")
# with gr.Row():
# with gr.Column(scale=6):
# templateSelectDropdown = gr.Dropdown(label="2. 选择prompt", choices=[], multiselect=False)
# with gr.Column(scale=1):
# templateApplyBtn = gr.Button("⬇️ 引用prompt")
# load / save chat
with gr.Accordion(label="导入/导出", open=True):
# hint
gr.Markdown("注意:导出末尾不含 `.json` 则自动添加时间戳`_mmdd_hhmi`,文件名一致会进行覆盖!")
with gr.Row():
# load
with gr.Column():
historyFileSelectDropdown = gr.Dropdown(label="加载对话", choices=get_history_names(plain=True), value="model_0.json", multiselect=False)
btnLoadHist = gr.Button("📂 导入对话")
btnGetHistList = gr.Button("🔄 刷新列表")
# save
with gr.Column():
saveFileName = gr.Textbox(show_label=True, placeholder=f"文件名", label="导出文件名", value="对话").style(container=True)
btnSaveToHist = gr.Button("💾 导出对话")
# parameters: inputs, top_p, temperature, top_k, repetition_penalty
with gr.Accordion("参数配置", open=False):
gr.Markdown("""
* [GPT-3参数详解](https://blog.csdn.net/jarodyv/article/details/128984602)
* [API文档](https://platform.openai.com/docs/api-reference/chat/create)
""")
top_p = gr.Slider(minimum=-0, maximum=1.0, value=0.3, step=0.05,
interactive=True, label="Top-p (核采样)",
info="0.1意味着只考虑包含最高10%概率质量的token。较小时,更加紧凑连贯,但缺乏创造性。较大时,更加多样化,但会出现语法错误和不连贯。")
temperature = gr.Slider(minimum=-0, maximum=2.0, value=0.5,
step=0.1, interactive=True, label="采样温度",info="越高 → 不确定性、创造性越高↑")
max_tokens = gr.Slider(minimum=100, maximum=1500, value=500,
step=10, label="最大token数",info="每次回答最大token数,数值过大可能长篇大论,看起来不像正常聊天")
# context_length = gr.Slider(minimum=1, maximum=5, value=2,
# step=1, label="记忆轮数", info="每次用于记忆的对话轮数。注意:数值过大可能会导致token数巨大!")
presence_penalty = gr.Slider( minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True, label="出现惩罚系数", info="越高 → 增加谈论新主题的可能性" )
frequency_penalty = gr.Slider( minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True, label="频率惩罚系数", info="越高 → 降低模型逐字重复同一行的可能性" )
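                    # note: these two penalty sliders are not yet passed to openai_request; the payload currently hardcodes both to 0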
# gr.Markdown(description)
    # clearing the chat also assigns a fresh chat_id
    btnClearChat.click(gen_uuid, [], [mdChatId])
    # event listener: enter
txt.submit(openai_request,
[txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem], [chatbot, history])
txt.submit(reset_textbox, [], [txt])
# button: submit
btnSubmit.click(openai_request,
[txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem], [chatbot, history], show_progress=True)
btnSubmit.click(reset_textbox, [], [txt])
# button: clear history
btnClearChat.click(reset_state, outputs=[chatbot, history])
# button: regenerate
btnRegen.click(openai_request, [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history,
promptSystem, TRUECONSTANT], [chatbot, history], show_progress=True)
# button: delete previous round
btnDelPrev.click(delete_last_conversation, [chatbot, history], [
chatbot, history], show_progress=True)
# button: summarize
btnSummarize.click(openai_request, [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history,
promptSystem, FALSECONSTANT, TRUECONSTANT], [chatbot, history], show_progress=True)
# button: save dialogue (save hist & get hist name)
btnSaveToHist.click(save_chat_history, [
saveFileName, promptSystem, history, chatbot], None, show_progress=True)
btnSaveToHist.click(get_history_names, None, [historyFileSelectDropdown])
# button: get history list
historyFileSelectDropdown.change(load_chat_history, [historyFileSelectDropdown], [saveFileName, promptSystem, history, chatbot])
btnGetHistList.click(get_history_names, None, [historyFileSelectDropdown])
# button: load history
btnLoadHist.click(load_chat_history, [historyFileSelectDropdown], [saveFileName, promptSystem, history, chatbot], show_progress=True)
# templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
# templaeFileReadBtn.click(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown], show_progress=True)
# templateApplyBtn.click(lambda x, y: x[y], [promptTemplates, templateSelectDropdown], [promptSystem], show_progress=True)
# print("访问:http://localhost:7860")
# by default: start a local server reachable from the host IP, and do not create a public share link
# if running in Docker
if dockerflag:
if authflag:
interface.queue(concurrency_count=2,status_update_rate="auto").launch(server_name="0.0.0.0", server_port=7860,auth=(username, password))
else:
interface.queue(concurrency_count=2,status_update_rate="auto").launch(server_name="0.0.0.0", server_port=7860, share=False)
# if not running in Docker
else:
    if username is not None and password is not None:
interface.queue(concurrency_count=2,status_update_rate="auto").launch(height='800px',debug=True,show_error=True,
auth=(username, password),auth_message="请联系Simon获取用户名与密码!")
else:
interface.queue(concurrency_count=2,status_update_rate="auto").launch(height='800px',debug=True,show_error=True)