|
|
|
import json |
|
import re |
|
import gradio as gr |
|
import os |
|
import sys |
|
import requests |
|
import csv |
|
import datetime
import pytz
|
import uuid |
|
import pymysql |
|
|
|
|
|
|
|
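# Module-level state; logged_in_user is filled in by auth_check() after a successful login.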
user_key = ""
logged_in_user = ""
|
|
|
|
|
|
|
API_URL = "https://api.openai.com/v1/chat/completions" |
|
API_URL = "https://yuns.deno.dev/v1/chat/completions" |
|
API_URL = "https://openai.hex.im/v1/chat/completions" |
|
HISTORY_DIR = "history" |
|
TEMPLATES_DIR = "templates" |
|
|
|
|
|
|
|
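# Build the default system prompt, stamped with the current time in Asia/Shanghai.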
def gen_sys_prompt(): |
|
tz = pytz.timezone('Asia/Shanghai') |
|
now = datetime.datetime.now(tz) |
|
timestamp = now.strftime("%Y年%m月%d日 %H:%M") |
|
sys_prompt = f"""你的目的是如实解答用户问题,对于不知道的问题,你需要回复你不知道。 |
|
现在是{timestamp},有些信息可能已经过时。 |
|
表格的效果展示直接以HTML格式输出而非markdown格式。""" |
|
return sys_prompt |
|
|
|
|
|
def get_mmdd_hhmi(): |
|
tz = pytz.timezone('Asia/Shanghai') |
|
now = datetime.datetime.now(tz) |
|
timestamp = now.strftime("_%m%d_%H%M") |
|
return timestamp |
|
|
|
|
|
def get_current_time(): |
|
tz = pytz.timezone('Asia/Shanghai') |
|
now = datetime.datetime.now(tz) |
|
timestamp = now.strftime("%Y-%m-%d %H:%M:%S") |
|
return timestamp |
|
|
|
|
|
|
|
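# Run a read-only query against MySQL; connection settings come from the
# db_host / db_user / db_pass / db_db environment variables.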
def fetch_data(query): |
|
|
|
conn = pymysql.connect(host = os.environ.get('db_host'), port=3306, |
|
user = os.environ.get('db_user'), |
|
password = os.environ.get('db_pass'), |
|
db = os.environ.get('db_db'), |
|
charset = 'utf8mb4') |
|
cur = conn.cursor() |
|
cur.execute(query) |
|
result = cur.fetchall() |
|
cur.close() |
|
conn.close() |
|
return result |
|
|
|
|
|
try: |
|
user_tuple = fetch_data("SELECT username, password FROM credentials") |
|
source = "db" |
|
except Exception:
|
user_tuple = eval(os.environ.get('user_tuple')) |
|
source = "env" |
|
|
|
print(f"{source}: {user_tuple}") |
|
|
|
|
|
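# Gradio auth callback: the login passes if (username, password) matches any
# entry in user_tuple; the username is remembered for request logging.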
def auth_check(username, password): |
|
global logged_in_user |
|
if_pass = False |
|
for user in user_tuple: |
|
|
|
if user[0] == username and user[1] == password: |
|
if_pass = True |
|
logged_in_user = username |
|
break |
|
if if_pass: |
|
print(f"Logged in as {logged_in_user}") |
|
else: |
|
print(f"Login attempt failed:[{username},{password}]") |
|
return if_pass |
|
|
|
|
|
|
|
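# Persist one chat round into the `chatlog` table for auditing/statistics.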
def write_data(time, session_id, chat_id, api_key, round, system, user, assistant, messages, payload, username): |
|
|
|
|
|
conn = pymysql.connect(host = os.environ.get('db_host'), port=3306, |
|
user = os.environ.get('db_user'), |
|
password = os.environ.get('db_pass'), |
|
db = os.environ.get('db_db'), |
|
charset = 'utf8mb4') |
|
|
|
cur = conn.cursor() |
|
|
|
    # Parameterized insert: quotes/newlines in the content can no longer break the SQL.
    sql_insert = '''
    INSERT INTO `chatlog` (`time`,`session_id`,`chat_id`,`api_key`,`round`,`system`,`user`,`assistant`,`messages`,`payload`,`username`)
    VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
    '''
    params = (time, session_id, chat_id, api_key, round, system, user, assistant, str(messages), str(payload), username)
    print(sql_insert, params)
    try:
        cur.execute(sql_insert, params)
|
conn.commit() |
|
print('成功写入数据!') |
|
except Exception as e: |
|
conn.rollback() |
|
print(f"出错了:{e}") |
|
|
|
cur.close() |
|
conn.close() |
|
|
|
|
|
def get_empty_state(): |
|
return {"total_tokens": 0, "messages": []} |
|
|
|
|
|
def gen_uuid(): |
|
return str(uuid.uuid4()) |
|
|
|
|
|
if os.environ.get('dockerrun') == 'yes': |
|
dockerflag = True |
|
else: |
|
dockerflag = False |
|
|
|
if dockerflag: |
|
my_api_key = os.environ.get('OPENAI_API_KEY') |
|
if my_api_key == "empty": |
|
print("Please give a api key!") |
|
sys.exit(1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
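# Convert a streamed reply into HTML for gr.Chatbot: ``` fences become
# <pre><code> blocks, characters inside code blocks are escaped as HTML
# entities and each code line ends with <br>; other lines keep a newline.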
def parse_text(text): |
|
lines = text.split("\n") |
|
lines = [line for line in lines if line != ""] |
|
count = 0 |
|
in_code_block = False |
|
|
|
for i, line in enumerate(lines): |
|
|
|
|
|
if "```" in line: |
|
count += 1 |
|
items = line.split('`') |
|
if count % 2 == 1: |
|
lines[i] = f"<pre><code class='language-{items[-1]}'>" |
|
in_code_block = True |
|
else: |
|
lines[i] = f'</code></pre>' |
|
in_code_block = False |
|
|
|
|
|
elif in_code_block: |
|
line = line.replace("&", "&") |
|
|
|
|
|
line = line.replace("'","\'") |
|
line = line.replace('"','\"') |
|
line = line.replace("<", "<") |
|
line = line.replace(">", ">") |
|
line = line.replace(" ", " ") |
|
line = line.replace("*", "*") |
|
line = line.replace("_", "_") |
|
line = line.replace("#", "#") |
|
line = line.replace("-", "-") |
|
line = line.replace(".", ".") |
|
line = line.replace("!", "!") |
|
line = line.replace("(", "(") |
|
line = line.replace(")", ")") |
|
lines[i] = line + "<br>" |
|
|
|
|
|
        else:
            lines[i] = line + "\n"
|
|
|
|
|
text = "".join(lines) |
|
    text = re.sub(r'(<br>)+', '<br>', text)
|
return text |
|
|
|
|
|
|
|
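# Main chat handler: rebuild the message list from the visible chatbot state,
# POST it to API_URL with stream=True, yield (chatbot, history) as tokens
# arrive, then log the finished round to the database.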
def openai_request(inputs, top_p, temperature, max_tokens, user_key, session_id, chat_id, chatbot=[], history=[], system_prompt=gen_sys_prompt, regenerate=False, summary=False): |
|
|
|
|
|
    api_key = user_key or os.environ.get("OPENAI_API_KEY")

    # The default for system_prompt is the gen_sys_prompt function itself;
    # call it here if no prompt string was passed in.
    if callable(system_prompt):
        system_prompt = system_prompt()
|
|
|
time = get_current_time() |
|
print(f"\n\ntime: {time}") |
|
|
|
|
|
print(f"user: {logged_in_user}") |
|
print(f"apiKey: {api_key}") |
|
sys_prompt_print = system_prompt.replace("\n","").replace(" ","") |
|
print(f"system: {sys_prompt_print}") |
|
print(f"user: {inputs}") |
|
|
|
|
|
|
|
chat_round = len(history) // 2 |
|
print(f"round: {chat_round}") |
|
|
|
headers = { |
|
"Content-Type": "application/json", |
|
"Authorization": f"Bearer {api_key}" |
|
} |
|
|
|
|
|
messages = [compose_system(system_prompt)] |
|
|
|
|
|
if chat_round: |
|
for data in chatbot: |
|
temp1 = {} |
|
temp1["role"] = "user" |
|
temp1["content"] = data[0] |
|
temp2 = {} |
|
temp2["role"] = "assistant" |
|
temp2["content"] = data[1] |
|
if temp1["content"] != "": |
|
messages.append(temp1) |
|
messages.append(temp2) |
|
else: |
|
messages[-1]['content'] = temp2['content'] |
|
|
|
|
|
if regenerate and chat_round: |
|
messages.pop() |
|
elif summary: |
|
messages.append(compose_user( |
|
"请帮我总结一下上述对话的内容,实现减少字数的同时,保证对话的质量。在总结中不要加入这一句话。")) |
|
history = ["我们刚刚聊了什么?"] |
|
else: |
|
temp3 = {} |
|
temp3["role"] = "user" |
|
temp3["content"] = inputs |
|
messages.append(temp3) |
|
chat_round += 1 |
|
|
|
payload = { |
|
"model": "gpt-3.5-turbo", |
|
"messages": messages, |
|
"temperature": temperature, |
|
"max_tokens": max_tokens, |
|
"top_p": top_p, |
|
"n": 1, |
|
"stream": True, |
|
"presence_penalty": 0, |
|
"frequency_penalty": 0, |
|
} |
|
|
|
if not summary: |
|
history.append(inputs) |
|
|
|
response = requests.post(API_URL, headers=headers, |
|
json=payload, stream=True) |
|
|
|
|
|
token_counter = 0 |
|
partial_words = "" |
|
|
|
counter = 0 |
|
chatbot.append((history[-1], "")) |
|
|
|
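    # The API streams Server-Sent Events: every line looks like "data: {json}",
    # so chunk.decode()[6:] strips the "data: " prefix. The first line (a role-only
    # delta) is skipped, and an empty delta marks the end of the reply.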
for chunk in response.iter_lines(): |
|
if counter == 0: |
|
counter += 1 |
|
continue |
|
counter += 1 |
|
|
|
if chunk: |
|
|
|
try: |
|
if len(json.loads(chunk.decode()[6:])['choices'][0]["delta"]) == 0: |
|
break |
|
except Exception as e: |
|
chatbot.pop() |
|
chatbot.append((history[-1], f"☹️发生了错误<br>返回值:{response.text}<br>异常:{e}")) |
|
history.pop() |
|
yield chatbot, history |
|
break |
|
|
|
partial_words = partial_words + \ |
|
json.loads(chunk.decode()[6:])['choices'][0]["delta"]["content"] |
|
if token_counter == 0: |
|
history.append("" + partial_words) |
|
else: |
|
history[-1] = parse_text(partial_words) |
|
chatbot[-1] = (history[-2], history[-1]) |
|
|
|
token_counter += 1 |
|
|
|
yield chatbot, history |
|
|
|
print(chatbot) |
|
|
|
|
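    # Map some of the HTML entities added by parse_text() back to plain
    # characters before writing the reply to the database.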
    history_write = history[-1].replace("&nbsp;", " ").replace("&#41;", ")").replace("&#40;", "(").replace("&#33;", "!").replace("\\'", "'")
|
print(f"assistant: {history_write}") |
|
|
|
paras = dict(payload) |
|
del paras["messages"] |
|
|
|
|
|
|
|
write_data( time=time, |
|
session_id=session_id, |
|
chat_id=chat_id, |
|
api_key=api_key, |
|
round=chat_round, |
|
                system=system_prompt.replace('\n', '').replace(' ', ''),
                user=inputs.replace('\n', ''),
                assistant=history_write.replace('\n', ''),
|
messages='', |
|
payload=paras, |
|
username=logged_in_user |
|
) |
|
|
|
|
|
|
|
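# Remove the latest round: one (user, assistant) pair from chatbot and the
# matching two entries from history.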
def delete_last_conversation(chatbot, history):
    # Guard against an empty conversation so the button cannot raise IndexError.
    if chatbot:
        chatbot.pop()
    if len(history) >= 2:
        history.pop()
        history.pop()
    return chatbot, history
|
|
|
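# Import/export a conversation as JSON under HISTORY_DIR
# (system prompt + history + chatbot are saved together).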
def save_chat_history(filename, system, history, chatbot): |
|
if filename == "": |
|
return |
|
if not filename.endswith(".json"): |
|
filename = filename + get_mmdd_hhmi() + ".json" |
|
os.makedirs(HISTORY_DIR, exist_ok=True) |
|
json_s = {"system": system, "history": history, "chatbot": chatbot} |
|
with open(os.path.join(HISTORY_DIR, filename), "w", encoding="UTF-8") as f: |
|
json.dump(json_s, f, ensure_ascii=False) |
|
|
|
|
|
def load_chat_history(filename): |
|
try: |
|
with open(os.path.join(HISTORY_DIR, filename), "r", encoding="UTF-8") as f: |
|
json_s = json.load(f) |
|
    except Exception:
|
with open(os.path.join(HISTORY_DIR, "default.json"), "r", encoding="UTF-8") as f: |
|
json_s = json.load(f) |
|
return filename, json_s["system"], json_s["history"], json_s["chatbot"] |
|
|
|
|
|
def get_file_names(dir, plain=False, filetype=".json"): |
|
|
|
try: |
|
files = [f for f in os.listdir(dir) if f.endswith(filetype)] |
|
except FileNotFoundError: |
|
files = [] |
|
if plain: |
|
return files |
|
else: |
|
return gr.Dropdown.update(choices=files) |
|
|
|
|
|
def get_history_names(plain=False): |
|
return get_file_names(HISTORY_DIR, plain) |
|
|
|
|
|
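# Read a prompt-template CSV (header row skipped, each row "name,prompt") and
# return a name->prompt dict plus a Dropdown update listing the names.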
def load_template(filename): |
|
lines = [] |
|
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="UTF-8") as csvfile: |
|
reader = csv.reader(csvfile) |
|
lines = list(reader) |
|
lines = lines[1:] |
|
return {row[0]:row[1] for row in lines}, gr.Dropdown.update(choices=[row[0] for row in lines]) |
|
|
|
def get_template_names(plain=False): |
|
return get_file_names(TEMPLATES_DIR, plain, filetype=".csv") |
|
|
|
def reset_state(): |
|
return [], [] |
|
|
|
|
|
def compose_system(system_prompt): |
|
return {"role": "system", "content": system_prompt} |
|
|
|
|
|
def compose_user(user_input): |
|
return {"role": "user", "content": user_input} |
|
|
|
|
|
def reset_textbox(): |
|
return gr.update(value='') |
|
|
|
|
|
with open("styles.css", "r", encoding="utf-8") as f: |
|
css_styles = f.read() |
|
|
|
|
|
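# Gradio UI: the left column holds the chat window, input box and action
# buttons; the right column holds the API key, system prompt, import/export
# and generation parameters.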
with gr.Blocks(css=css_styles,title="Chatbot🚀" |
|
,theme=gr.themes.Soft() |
|
) as interface: |
|
|
|
history = gr.State([]) |
|
promptTemplates = gr.State({}) |
|
TRUECONSTANT = gr.State(True) |
|
FALSECONSTANT = gr.State(False) |
|
|
|
with gr.Column(elem_id="col-container"): |
|
gr.Markdown("""# Chatbot Playground 🚀 |
|
本对话使用 GPT-3.5 Turbo 模型。如果遇到问题请联系Simon! |
|
""", |
|
elem_id="header") |
|
with gr.Row(): |
|
|
|
|
|
with gr.Column(): |
|
|
|
|
|
btnClearChat = gr.Button("🧹 清空对话") |
|
|
|
|
|
chatbot = gr.Chatbot(elem_id="chatbox",label="对话") |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(scale=12): |
|
txt = gr.Textbox(show_label=False, placeholder="在这里输入,Shift+Enter换行").style( |
|
container=False) |
|
with gr.Column(min_width=50, scale=1): |
|
btnSubmit = gr.Button("🚀", variant="primary") |
|
|
|
|
|
with gr.Row(): |
|
btnRegen = gr.Button("🔄 重新生成") |
|
btnDelPrev = gr.Button("🗑️ 删除上一轮") |
|
btnSummarize = gr.Button("♻️ 总结对话") |
|
|
|
|
|
with gr.Row(elem_id="ids"): |
|
mdSessionId = gr.Textbox(label="Session ID",value=gen_uuid,elem_id="session_id",interactive=False) |
|
mdChatId = gr.Textbox(label="Chat ID",value=gen_uuid,elem_id="chat_id",interactive=False) |
|
|
|
|
|
with gr.Column(): |
|
|
|
|
|
with gr.Accordion(label="API Key(帮助节省额度)", open=False): |
|
apiKey = gr.Textbox(show_label=True, placeholder=f"使用自己的OpenAI API key", |
|
                        value='', label="默认的API额度可能很快用完,你可以使用自己的API Key,以帮忙节省额度。", type="password", visible=True).style(container=True)
|
|
|
|
|
promptSystem = gr.Textbox(lines=2,show_label=True, placeholder=f"在这里输入System Prompt...", |
|
label="系统输入", info="系统输入不会直接产生交互,可作为基础信息输入。", value=gen_sys_prompt).style(container=True) |
|
|
|
topic = gr.State("未命名对话") |
|
|
with gr.Accordion(label="导入/导出", open=False): |
|
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(): |
|
historyFileSelectDropdown = gr.Dropdown(label="加载对话", choices=get_history_names(plain=True), value="model_抖音.json", multiselect=False) |
|
btnLoadHist = gr.Button("📂 导入对话") |
|
btnGetHistList = gr.Button("🔄 刷新列表") |
|
|
|
with gr.Column(): |
|
saveFileName = gr.Textbox(show_label=True, placeholder=f"文件名", label="导出文件名", value="对话").style(container=True) |
|
btnSaveToHist = gr.Button("💾 导出对话") |
|
|
|
|
|
|
|
with gr.Accordion("参数配置", open=True): |
|
gr.Markdown(""" |
|
* [GPT-3参数详解](https://blog.csdn.net/jarodyv/article/details/128984602) |
|
* [API文档](https://platform.openai.com/docs/api-reference/chat/create) |
|
""") |
|
max_tokens = gr.Slider(minimum=100, maximum=2500, value=1500, |
|
step=10, label="最大token数",info="每次回答最大token数,数值过大可能长篇大论,看起来不像正常聊天") |
|
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=0.5, step=0.05,
|
interactive=True, label="Top-p (核采样)", |
|
info="0.1意味着只考虑包含最高10%概率质量的token。较小时,更加紧凑连贯,但缺乏创造性。较大时,更加多样化,但会出现语法错误和不连贯。") |
|
                    temperature = gr.Slider(minimum=0, maximum=2.0, value=1,
|
step=0.1, interactive=True, label="采样温度",info="越高 → 不确定性、创造性越高↑") |
|
|
|
|
|
presence_penalty = gr.Slider( minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True, label="出现惩罚系数", info="越高 → 增加谈论新主题的可能性" ) |
|
frequency_penalty = gr.Slider( minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True, label="频率惩罚系数", info="越高 → 降低模型逐字重复同一行的可能性" ) |
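                    # Note: these two sliders are not wired into openai_request();
                    # the request payload currently sends presence/frequency penalty = 0.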
|
|
|
|
|
|
|
btnClearChat.click(gen_uuid, [], [mdChatId]) |
|
|
|
txt.submit(openai_request, |
|
[txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem], [chatbot, history]) |
|
txt.submit(reset_textbox, [], [txt]) |
|
|
|
|
|
btnSubmit.click(openai_request, |
|
[txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, promptSystem], [chatbot, history], show_progress=True) |
|
btnSubmit.click(reset_textbox, [], [txt]) |
|
|
|
|
|
btnClearChat.click(reset_state, outputs=[chatbot, history]) |
|
|
|
|
|
btnRegen.click(openai_request, [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, |
|
promptSystem, TRUECONSTANT], [chatbot, history], show_progress=True) |
|
|
|
|
|
btnDelPrev.click(delete_last_conversation, [chatbot, history], [ |
|
chatbot, history], show_progress=True) |
|
|
|
|
|
btnSummarize.click(openai_request, [txt, top_p, temperature, max_tokens, apiKey, mdSessionId, mdChatId, chatbot, history, |
|
promptSystem, FALSECONSTANT, TRUECONSTANT], [chatbot, history], show_progress=True) |
|
|
|
|
|
btnSaveToHist.click(save_chat_history, [ |
|
saveFileName, promptSystem, history, chatbot], None, show_progress=True) |
|
btnSaveToHist.click(get_history_names, None, [historyFileSelectDropdown]) |
|
|
|
|
|
historyFileSelectDropdown.change(load_chat_history, [historyFileSelectDropdown], [saveFileName, promptSystem, history, chatbot]) |
|
btnGetHistList.click(get_history_names, None, [historyFileSelectDropdown]) |
|
|
|
|
|
btnLoadHist.click(load_chat_history, [historyFileSelectDropdown], [saveFileName, promptSystem, history, chatbot], show_progress=True) |
|
|
|
|
|
|
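# Launch: in Docker listen on 0.0.0.0:7860; locally run with debug output.
# Both modes require login via auth_check.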
if dockerflag: |
|
interface.queue(concurrency_count=2,status_update_rate="auto").launch(server_name="0.0.0.0", server_port=7860,auth=auth_check) |
|
|
|
else: |
|
interface.queue(concurrency_count=2,status_update_rate="auto").launch( |
|
height='800px', |
|
debug=True, |
|
show_error=True, |
|
auth=auth_check, |
|
auth_message="请联系Simon获取用户名与密码!" |
|
) |
|
|