# -- coding: utf-8 --
import ast
import csv
import datetime
import html
import json
import os
import re
import sys
import uuid

import gradio as gr
import pymysql
import pytz
import requests
# init variables
# Optional API key placeholder (enter your API key here).
# NOTE(review): no function shown in this file reads this global; openai_request()
# receives the key as its `user_key` parameter instead.
user_key = ""

# Chat-completions endpoint. Set OPENAI_API_URL to override. Endpoints used before:
#   https://api.openai.com/v1/chat/completions  (official)
#   https://yuns.deno.dev/v1/chat/completions
API_URL = os.environ.get("OPENAI_API_URL", "https://openai.hex.im/v1/chat/completions")

# Directories for exported chat histories (.json) and prompt templates (.csv).
HISTORY_DIR = "history"
TEMPLATES_DIR = "templates"
def gen_sys_prompt():
    """Build the default system prompt, stamped with the current Asia/Shanghai time.

    Returns:
        str: three lines joined by "\\n": the answering policy, the current
        time (formatted "%Y年%m月%d日 %H:%M") with a staleness caveat, and an
        instruction to render tables as HTML.
    """
    shanghai_now = datetime.datetime.now(pytz.timezone('Asia/Shanghai'))
    stamp = shanghai_now.strftime("%Y年%m月%d日 %H:%M")
    prompt_lines = (
        "你的目的是如实解答用户问题,对于不知道的问题,你需要回复你不知道。",
        f"现在是{stamp},有些信息可能已经过时。",
        "表格的效果展示直接以HTML格式输出而非markdown格式。",
    )
    return "\n".join(prompt_lines)
# get timestamp
def get_mmdd_hhmi():
    """Return the current Asia/Shanghai time as a file-name suffix like "_0131_2359"."""
    return datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("_%m%d_%H%M")
# get current time
def get_current_time():
    """Return the current Asia/Shanghai time formatted as "YYYY-MM-DD HH:MM:SS"."""
    shanghai = pytz.timezone('Asia/Shanghai')
    moment = datetime.datetime.now(shanghai)
    return moment.strftime("%Y-%m-%d %H:%M:%S")
# db fetch
def fetch_data(query, args=None):
    """Run a read query against the MySQL database and return every row.

    Connection settings come from the environment variables db_host, db_user,
    db_pass and db_db; the port is 3306 unless db_port is set.

    Args:
        query: SQL text; may contain ``%s`` placeholders.
        args: Optional values bound to those placeholders by pymysql. Pass
            untrusted values here instead of formatting them into ``query``.

    Returns:
        The result of ``cursor.fetchall()`` (a tuple of row tuples with
        pymysql's default cursor).
    """
    conn = pymysql.connect(host=os.environ.get('db_host'),
                           port=int(os.environ.get('db_port', 3306)),
                           user=os.environ.get('db_user'),
                           password=os.environ.get('db_pass'),
                           db=os.environ.get('db_db'),
                           charset='utf8mb4')
    try:
        # The cursor and the connection are closed even if execute() raises.
        with conn.cursor() as cur:
            cur.execute(query, args)
            return cur.fetchall()
    finally:
        conn.close()
# get user_tuple
# Login credentials: (username, password) rows from the `credentials` table,
# falling back to the `user_tuple` environment variable when the query fails.
try:
    user_tuple = fetch_data("SELECT username, password FROM credentials")
    source = "db"
except Exception as db_error:
    raw_users = os.environ.get('user_tuple')
    if raw_users is None:
        raise RuntimeError(
            "Cannot load credentials: database query failed and env var 'user_tuple' is not set"
        ) from db_error
    # literal_eval accepts only Python literals such as "(('alice', 'secret'),)";
    # eval() would execute arbitrary code taken from the environment.
    user_tuple = ast.literal_eval(raw_users)
    source = "env"
# Log where the credentials came from and how many there are, never the passwords.
print(f"{source}: {len(user_tuple)} users")
# auth check
def auth_check(username, password):
    """Gradio login callback: accept the pair if it matches an entry of ``user_tuple``.

    On success the module-level ``logged_in_user`` is set to ``username``.
    NOTE(review): that global is shared by all sessions, so concurrent users
    overwrite each other's name in the chat logs; presumably the per-request
    username Gradio provides should be used instead — TODO.

    Args:
        username: Name typed on the login form.
        password: Password typed on the login form.

    Returns:
        bool: True when the credentials match a stored (username, password) entry.
    """
    global logged_in_user
    if_pass = any(user[0] == username and user[1] == password for user in user_tuple)
    if if_pass:
        logged_in_user = username
        print(f"Logged in as {logged_in_user}")
    else:
        # Only the username is logged; attempted passwords must never reach the logs.
        print(f"Login attempt failed:[{username}]")
    return if_pass
# db write
def write_data(time, session_id, chat_id, api_key, round, system, user, assistant, messages, payload, username):
    """Insert one chat round into the `chatlog` table.

    Every value is stored as its ``str()`` form (e.g. the ``payload`` dict and
    the integer ``round``), matching what the former string-built query
    stored. Values are bound as query parameters, so quotes, backslashes and
    newlines in chat text can no longer break, or inject into, the SQL.
    Database errors are rolled back and printed, not raised.

    Connection settings come from the environment variables db_host, db_user,
    db_pass and db_db; the port is 3306 unless db_port is set.
    """
    sql_insert = (
        "INSERT INTO `chatlog` (`time`,`session_id`,`chat_id`,`api_key`,`round`,`system`,"
        "`user`,`assistant`,`messages`,`payload`,`username`) "
        "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
    )
    params = tuple(str(value) for value in (
        time, session_id, chat_id, api_key, round, system,
        user, assistant, messages, payload, username,
    ))
    # The SQL itself is not printed: its values include the API key.
    print(f"chatlog insert: session={session_id} chat={chat_id} round={round} user={username}")
    conn = pymysql.connect(host=os.environ.get('db_host'),
                           port=int(os.environ.get('db_port', 3306)),
                           user=os.environ.get('db_user'),
                           password=os.environ.get('db_pass'),
                           db=os.environ.get('db_db'),
                           charset='utf8mb4')
    try:
        with conn.cursor() as cur:
            try:
                cur.execute(sql_insert, params)
                conn.commit()
                print('成功写入数据!')
            except Exception as e:
                # Logging is best-effort: a failed insert must not break the chat.
                conn.rollback()
                print(f"出错了:{e}")
    finally:
        conn.close()
# clear state
def get_empty_state():
    """Return a new conversation state: zero tokens and an empty message list."""
    return dict(total_tokens=0, messages=[])
# uuid generator
def gen_uuid():
    """Return a fresh random (version 4) UUID in its 36-character string form."""
    return f"{uuid.uuid4()}"
# Running inside Docker is signalled by the environment variable dockerrun=yes.
dockerflag = os.environ.get('dockerrun') == 'yes'
if dockerflag:
    # In Docker the server-side key is read from OPENAI_API_KEY; the literal
    # value "empty" means "not configured" and aborts start-up.
    my_api_key = os.environ.get('OPENAI_API_KEY')
    if my_api_key == "empty":
        print("Please give a api key!")
        sys.exit(1)
# parse text for code
# HTML entities substituted inside fenced code blocks so the chat box shows
# these characters literally instead of treating them as HTML/Markdown syntax.
# str.translate applies the whole table in one pass, so "&" is never double-escaped.
_CODE_CHAR_ENTITIES = str.maketrans({
    "&": "&amp;",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "#": "&#35;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
})


def parse_text(text):
    """Convert raw model output into the markup displayed in the chat box.

    Empty lines are dropped. Each line containing a ``` fence alternately
    opens and closes an HTML code block; on an opening fence, the text after
    the last backtick (the language tag) becomes the code element's class.
    Lines inside a code block are entity-escaped and end with ``<br>``; all
    other lines end with a newline.

    NOTE(review): the HTML literals of this function were unreadable in the
    reviewed copy (tags appeared stripped); they are reconstructed here as
    ``<pre><code class="...">``, ``</code></pre>`` and ``<br>``. Confirm
    against how the chat widget renders messages.

    Args:
        text: Raw reply text (possibly partial while streaming).

    Returns:
        str: The markup string.
    """
    lines = [line for line in text.split("\n") if line != ""]
    in_code_block = False  # flipped by every line containing a ``` fence
    for i, line in enumerate(lines):
        if "```" in line:
            if not in_code_block:
                # Escape the tag: it is inserted into an HTML attribute.
                lang = html.escape(line.split("`")[-1].strip(), quote=True)
                lines[i] = f'<pre><code class="{lang}">'
            else:
                lines[i] = "</code></pre>"
            in_code_block = not in_code_block
        elif in_code_block:
            lines[i] = line.translate(_CODE_CHAR_ENTITIES) + "<br>"
        else:
            # Write the result back into the list; appending to the loop
            # variable alone was lost and glued consecutive lines together.
            lines[i] = line + "\n"
    return "".join(lines)
# Send a request to the API
def openai_request(inputs, top_p, temperature, max_tokens, user_key, session_id, chat_id, chatbot=None, history=None, system_prompt=gen_sys_prompt, regenerate=False, summary=False):  # repetition_penalty, top_k
    """Stream a chat completion into the chat box, then log the round to the DB.

    Generator used as a Gradio event handler: yields ``(chatbot, history)``
    after every streamed chunk so the UI updates live, and finally records
    the round with ``write_data``.

    Args:
        inputs: Text the user just submitted.
        top_p, temperature, max_tokens: Sampling parameters sent in the payload.
        user_key: Key typed by the user; env OPENAI_API_KEY is used when empty.
        session_id, chat_id: Identifiers stored with the log row.
        chatbot: List of (user, assistant) display pairs; mutated in place.
        history: Flat list of alternating user/assistant texts; mutated in
            place, except in summary mode where it is replaced.
        system_prompt: System message text; a callable (the default) is
            called to produce it.
        regenerate: Drop the last assistant reply and ask again.
        summary: Ask the model to summarise the conversation so far.

    Yields:
        tuple: ``(chatbot, history)`` after each update.
    """
    # Fresh lists per call: mutable default arguments would be shared between calls.
    if chatbot is None:
        chatbot = []
    if history is None:
        history = []
    # The default is the prompt factory itself; turn it into text before use.
    if callable(system_prompt):
        system_prompt = system_prompt()

    # user key > env key
    api_key = user_key or os.environ.get("OPENAI_API_KEY")
    time = get_current_time()
    print(f"\n\ntime: {time}")
    print(f"user: {logged_in_user}")
    # Only a prefix of the key is logged so the secret does not end up in stdout logs.
    masked_key = f"{api_key[:6]}..." if api_key else api_key
    print(f"apiKey: {masked_key}")
    sys_prompt_print = system_prompt.replace("\n", "").replace(" ", "")
    print(f"system: {sys_prompt_print}")
    print(f"user: {inputs}")
    # Completed rounds = len(history) // 2; 0 for a new conversation.
    chat_round = len(history) // 2
    print(f"round: {chat_round}")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # Rebuild the message list: system prompt, then earlier rounds from the chat box.
    messages = [compose_system(system_prompt)]
    if chat_round:
        for data in chatbot:
            user_text, assistant_text = data[0], data[1]
            if user_text != "":
                messages.append({"role": "user", "content": user_text})
                messages.append({"role": "assistant", "content": assistant_text})
            else:
                # A row without user text (presumably left by "regenerate")
                # replaces the previous assistant reply instead of adding a round.
                messages[-1]["content"] = assistant_text

    if regenerate and chat_round:
        # Drop the last message (the latest assistant reply) so it is generated again.
        messages.pop()
    elif summary:
        messages.append(compose_user(
            "请帮我总结一下上述对话的内容,实现减少字数的同时,保证对话的质量。在总结中不要加入这一句话。"))
        # Summary mode restarts the history as [question, summary].
        history = ["我们刚刚聊了什么?"]
    else:
        messages.append({"role": "user", "content": inputs})
        chat_round += 1

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "n": 1,
        "stream": True,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }

    if not summary:
        history.append(inputs)

    response = requests.post(API_URL, headers=headers, json=payload, stream=True)
    token_counter = 0
    partial_words = ""
    errored = False
    chatbot.append((history[-1], ""))
    for counter, chunk in enumerate(response.iter_lines()):
        # The first line is skipped (presumably the role-only delta without "content").
        if counter == 0 or not chunk:
            continue
        try:
            # Each event line looks like "data: {...}"; drop the 6-character prefix.
            event = chunk.decode()[6:]
            if event.strip() == "[DONE]":
                break
            delta = json.loads(event)["choices"][0]["delta"]
            finished = len(delta) == 0
        except Exception as e:
            errored = True
            if token_counter:
                history.pop()  # discard the partial assistant reply
            try:
                body = response.text
            except Exception:  # the streamed body may already be consumed
                body = chunk.decode(errors="replace")
            chatbot[-1] = (history[-1], f"☹️发生了错误<br>返回值:{body}<br>异常:{e}")
            history.pop()  # discard the unanswered user message
            yield chatbot, history
            break
        # An empty delta ends the stream.
        if finished:
            break
        partial_words += delta.get("content") or ""
        if token_counter == 0:
            history.append(parse_text(partial_words))
        else:
            history[-1] = parse_text(partial_words)
        chatbot[-1] = (history[-2], history[-1])
        token_counter += 1
        yield chatbot, history
    print(chatbot)

    # Only a fully streamed reply is logged as the answer. On error or an empty
    # stream, history[-1] is not a reply (or is missing), so log nothing for it.
    reply = partial_words if token_counter and not errored else ""
    print(f"assistant: {reply}")
    paras = {key: value for key, value in payload.items() if key != "messages"}
    write_data(
        time=time,
        session_id=session_id,
        chat_id=chat_id,
        api_key=api_key,
        round=chat_round,
        system=sys_prompt_print,
        user=inputs.replace("\n", ""),
        assistant=reply.replace("\n", ""),
        messages="",
        payload=paras,
        username=logged_in_user,
    )
def delete_last_conversation(chatbot, history):
    """Remove the most recent round from the chat box and from the history.

    Drops the last (user, assistant) display pair and the last two history
    entries (the user text and its reply). Empty or short lists are handled
    without raising.

    Args:
        chatbot: List of display pairs; mutated in place.
        history: Flat list of alternating user/assistant texts; mutated in place.

    Returns:
        tuple: ``(chatbot, history)``.
    """
    if chatbot:
        chatbot.pop()
    del history[-2:]
    return chatbot, history
def save_chat_history(filename, system, history, chatbot):
    """Export the conversation as JSON into HISTORY_DIR.

    A name not ending in ".json" gets a "_mmdd_hhmi" timestamp plus ".json"
    appended. An existing file with the same name is overwritten. Nothing
    happens when the name is empty.

    Args:
        filename: File name typed by the user.
        system: Current system prompt.
        history: Flat user/assistant history list.
        chatbot: Display pairs.
    """
    # Keep only the last path component so names like "../x" cannot write outside HISTORY_DIR.
    filename = os.path.basename(filename or "")
    if filename == "":
        return
    if not filename.endswith(".json"):
        filename = filename + get_mmdd_hhmi() + ".json"
    os.makedirs(HISTORY_DIR, exist_ok=True)
    json_s = {"system": system, "history": history, "chatbot": chatbot}
    with open(os.path.join(HISTORY_DIR, filename), "w", encoding="UTF-8") as f:
        json.dump(json_s, f, ensure_ascii=False)
# load chat history item
def load_chat_history(filename):
    """Load an exported conversation from HISTORY_DIR.

    Falls back to HISTORY_DIR/default.json when the requested file is missing,
    unreadable, not valid JSON, or the name is not a string (e.g. None).

    Args:
        filename: Name chosen in the dropdown.

    Returns:
        tuple: ``(filename, system, history, chatbot)``. The first element is
        the requested name even when the fallback file was used.
    """
    try:
        # basename() keeps the lookup inside HISTORY_DIR; None raises TypeError.
        safe_name = os.path.basename(filename)
        with open(os.path.join(HISTORY_DIR, safe_name), "r", encoding="UTF-8") as f:
            json_s = json.load(f)
    except (OSError, TypeError, ValueError):
        with open(os.path.join(HISTORY_DIR, "default.json"), "r", encoding="UTF-8") as f:
            json_s = json.load(f)
    return filename, json_s["system"], json_s["history"], json_s["chatbot"]
# update dropdown list with json files list in dir
def get_file_names(dir, plain=False, filetype=".json"):
    """List the files in ``dir`` whose names end with ``filetype``.

    A missing directory yields an empty list.

    Args:
        dir: Directory to scan.
        plain: Return the plain list instead of a dropdown update.
        filetype: Required name suffix.

    Returns:
        list[str] when ``plain`` is true, otherwise
        ``gr.Dropdown.update(choices=<that list>)``.
    """
    try:
        entries = os.listdir(dir)
    except FileNotFoundError:
        entries = []
    matching = list(filter(lambda name: name.endswith(filetype), entries))
    return matching if plain else gr.Dropdown.update(choices=matching)
# construct history files list
def get_history_names(plain=False):
    """List exported chat histories (*.json in HISTORY_DIR), as a list or dropdown update."""
    history_files = get_file_names(HISTORY_DIR, plain=plain)
    return history_files
# deprecated: load templates
def load_template(filename):
    """Deprecated: read a prompt-template CSV from TEMPLATES_DIR.

    The first row is skipped (presumably a header). Each remaining row
    supplies column 0 as the template name and column 1 as its prompt.

    Returns:
        tuple: ``({name: prompt, ...}, gr.Dropdown.update(choices=[names]))``.
    """
    template_path = os.path.join(TEMPLATES_DIR, filename)
    with open(template_path, "r", encoding="UTF-8") as csvfile:
        rows = list(csv.reader(csvfile))
    data_rows = rows[1:]
    prompts = dict((row[0], row[1]) for row in data_rows)
    names = [row[0] for row in data_rows]
    return prompts, gr.Dropdown.update(choices=names)
def get_template_names(plain=False):
    """List template CSV files (*.csv in TEMPLATES_DIR), as a list or dropdown update."""
    csv_files = get_file_names(TEMPLATES_DIR, plain, filetype=".csv")
    return csv_files
def reset_state():
    """Return empty (chatbot, history) lists to clear the conversation."""
    cleared_chatbot, cleared_history = [], []
    return cleared_chatbot, cleared_history
def compose_system(system_prompt):
    """Wrap ``system_prompt`` as a chat message with role "system"."""
    return dict(role="system", content=system_prompt)
def compose_user(user_input):
    """Wrap ``user_input`` as a chat message with role "user"."""
    return dict(role="user", content=user_input)
def reset_textbox():
    """Return a Gradio update that clears the input textbox."""
    cleared = gr.update(value="")
    return cleared
# Custom stylesheet passed to gr.Blocks(css=...).
with open("styles.css", "r", encoding="utf-8") as css_file:
    css_styles = css_file.read()
# build UI
with gr.Blocks(css=css_styles, title="Chatbot🚀", theme=gr.themes.Soft()) as interface:
    history = gr.State([])
    promptTemplates = gr.State({})
    # Constant booleans fed to openai_request's regenerate / summary arguments.
    TRUECONSTANT = gr.State(True)
    FALSECONSTANT = gr.State(False)
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# Chatbot Playground 🚀
本对话使用 GPT-3.5 Turbo 模型。如果遇到问题请联系Simon!
""",
                    elem_id="header")
        with gr.Row():
            # left column: chat
            with gr.Column():
                # clears the history and issues a new chat_id
                btnClearChat = gr.Button("🧹 清空对话")
                chatbot = gr.Chatbot(elem_id="chatbox", label="对话")
                # input
                with gr.Row():
                    with gr.Column(scale=12):
                        txt = gr.Textbox(show_label=False, placeholder="在这里输入,Shift+Enter换行").style(
                            container=False)
                    with gr.Column(min_width=50, scale=1):
                        btnSubmit = gr.Button("🚀", variant="primary")
                # shortcuts
                with gr.Row():
                    btnRegen = gr.Button("🔄 重新生成")
                    btnDelPrev = gr.Button("🗑️ 删除上一轮")
                    btnSummarize = gr.Button("♻️ 总结对话")
                # session / chat ids (new UUIDs per page load)
                with gr.Row(elem_id="ids"):
                    mdSessionId = gr.Textbox(label="Session ID", value=gen_uuid, elem_id="session_id", interactive=False)
                    mdChatId = gr.Textbox(label="Chat ID", value=gen_uuid, elem_id="chat_id", interactive=False)
            # right column: settings
            with gr.Column():
                # api key
                with gr.Accordion(label="API Key(帮助节省额度)", open=False):
                    # Empty by default: openai_request falls back to env OPENAI_API_KEY.
                    # Never hard-code a key here; component values are part of the
                    # page config served to every browser.
                    apiKey = gr.Textbox(show_label=True, placeholder="使用自己的OpenAI API key",
                                        value='', label="默认的API额度可能很快用完,你可以使用自己的API Key,以帮忙节省额度。", type="password", visible=True).style(container=True)
                # system prompt (default regenerated with the current time on each page load)
                promptSystem = gr.Textbox(lines=2, show_label=True, placeholder="在这里输入System Prompt...",
                                          label="系统输入", info="系统输入不会直接产生交互,可作为基础信息输入。", value=gen_sys_prompt).style(container=True)
                topic = gr.State("未命名对话")
                # The prompt-template loader UI is disabled; load_template() remains defined.
                # load / save chat
                with gr.Accordion(label="导入/导出", open=False):
                    with gr.Row():
                        # load
                        with gr.Column():
                            historyFileSelectDropdown = gr.Dropdown(label="加载对话", choices=get_history_names(plain=True), value="model_抖音.json", multiselect=False)
                            btnLoadHist = gr.Button("📂 导入对话")
                            btnGetHistList = gr.Button("🔄 刷新列表")
                        # save
                        with gr.Column():
                            saveFileName = gr.Textbox(show_label=True, placeholder="文件名", label="导出文件名", value="对话").style(container=True)
                            btnSaveToHist = gr.Button("💾 导出对话")
                # sampling parameters
                with gr.Accordion("参数配置", open=True):
                    gr.Markdown("""
* [GPT-3参数详解](https://blog.csdn.net/jarodyv/article/details/128984602)
* [API文档](https://platform.openai.com/docs/api-reference/chat/create)
""")
                    max_tokens = gr.Slider(minimum=100, maximum=2500, value=1500,
                                           step=10, label="最大token数", info="每次回答最大token数,数值过大可能长篇大论,看起来不像正常聊天")
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=0.5, step=0.05,
                                      interactive=True, label="Top-p (核采样)",
                                      info="0.1意味着只考虑包含最高10%概率质量的token。较小时,更加紧凑连贯,但缺乏创造性。较大时,更加多样化,但会出现语法错误和不连贯。")
                    temperature = gr.Slider(minimum=0, maximum=2.0, value=1,
                                            step=0.1, interactive=True, label="采样温度", info="越高 → 不确定性、创造性越高↑")
                    # NOTE(review): the two penalty sliders are not passed to any event
                    # handler; openai_request always sends 0 for both.
                    presence_penalty = gr.Slider(minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True, label="出现惩罚系数", info="越高 → 增加谈论新主题的可能性")
                    frequency_penalty = gr.Slider(minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True, label="频率惩罚系数", info="越高 → 降低模型逐字重复同一行的可能性")

    # event listener: enter
    btnClearChat.click(gen_uuid, [], [mdChatId])
    txt.submit(openai_request,
               [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem], [chatbot, history])
    txt.submit(reset_textbox, [], [txt])
    # button: submit
    btnSubmit.click(openai_request,
                    [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem], [chatbot, history], show_progress=True)
    btnSubmit.click(reset_textbox, [], [txt])
    # button: clear history
    btnClearChat.click(reset_state, outputs=[chatbot, history])
    # button: regenerate (regenerate=True)
    btnRegen.click(openai_request, [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history,
                                    promptSystem, TRUECONSTANT], [chatbot, history], show_progress=True)
    # button: delete previous round
    btnDelPrev.click(delete_last_conversation, [chatbot, history], [
        chatbot, history], show_progress=True)
    # button: summarize (regenerate=False, summary=True)
    btnSummarize.click(openai_request, [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history,
                                        promptSystem, FALSECONSTANT, TRUECONSTANT], [chatbot, history], show_progress=True)
    # button: save dialogue, then refresh the history file list
    btnSaveToHist.click(save_chat_history, [
        saveFileName, promptSystem, history, chatbot], None, show_progress=True)
    btnSaveToHist.click(get_history_names, None, [historyFileSelectDropdown])
    # selecting a history file loads it immediately
    historyFileSelectDropdown.change(load_chat_history, [historyFileSelectDropdown], [saveFileName, promptSystem, history, chatbot])
    # button: get history list
    btnGetHistList.click(get_history_names, None, [historyFileSelectDropdown])
    # button: load history
    btnLoadHist.click(load_chat_history, [historyFileSelectDropdown], [saveFileName, promptSystem, history, chatbot], show_progress=True)
# Start the app behind a request queue (concurrency_count=2) with auth_check as login.
# In Docker: listen on 0.0.0.0:7860. Otherwise: launch with debug output and
# visible errors. Neither mode requests a public share link.
queued_app = interface.queue(concurrency_count=2, status_update_rate="auto")
if dockerflag:
    queued_app.launch(server_name="0.0.0.0", server_port=7860, auth=auth_check)
else:
    queued_app.launch(
        height='800px',
        debug=True,
        show_error=True,
        auth=auth_check,
        auth_message="请联系Simon获取用户名与密码!",
    )