# -*- coding: utf-8 -*-
"""Gradio chat playground for the OpenAI chat-completions API.

Users authenticate against a MySQL ``credentials`` table (falling back to the
``user_tuple`` environment variable), completions are streamed from
``API_URL`` into a ``gr.Chatbot``, and each finished round is logged to the
``chatlog`` table via ``write_data``.
"""
import ast
import csv
import datetime
import html
import json
import os
import re
import sys
import uuid

import gradio as gr
import pymysql
import pytz
import requests

# init variables
# Not read anywhere: openai_request receives the key as a parameter instead.
user_key = ""

# Chat-completions endpoint. Earlier candidates were the official
# https://api.openai.com/v1/chat/completions and https://yuns.deno.dev/v1/...;
# the hex.im proxy was the effective value and stays the default. Set the
# API_URL environment variable to override it.
API_URL = os.environ.get("API_URL", "https://openai.hex.im/v1/chat/completions")
HISTORY_DIR = "history"
TEMPLATES_DIR = "templates"

_TZ = pytz.timezone("Asia/Shanghai")


def _now():
    """Return the current, timezone-aware time in Asia/Shanghai."""
    return datetime.datetime.now(_TZ)


def gen_sys_prompt():
    """Build the default system prompt, stamped with the current Shanghai time."""
    timestamp = _now().strftime("%Y年%m月%d日 %H:%M")
    return (
        "你的目的是如实解答用户问题,对于不知道的问题,你需要回复你不知道。\n"
        f"现在是{timestamp},有些信息可能已经过时。\n"
        "表格的效果展示直接以HTML格式输出而非markdown格式。"
    )


def get_mmdd_hhmi():
    """Return a ``_MMDD_HHMM`` suffix used to timestamp exported file names."""
    return _now().strftime("_%m%d_%H%M")


def get_current_time():
    """Return the current Shanghai time as ``YYYY-MM-DD HH:MM:SS``."""
    return _now().strftime("%Y-%m-%d %H:%M:%S")


def _connect():
    """Open a new pymysql connection configured by the db_* environment variables."""
    return pymysql.connect(
        host=os.environ.get("db_host"),
        port=3306,
        user=os.environ.get("db_user"),
        password=os.environ.get("db_pass"),
        db=os.environ.get("db_db"),
        charset="utf8mb4",
    )


def fetch_data(query, args=None):
    """Run ``query`` and return all rows from ``cursor.fetchall()``.

    ``args`` (optional) is handed to ``cursor.execute`` so the driver does
    the parameter escaping. The connection is always closed, even on error.
    """
    conn = _connect()
    try:
        with conn.cursor() as cur:
            cur.execute(query, args)
            return cur.fetchall()
    finally:
        conn.close()


# Load the (username, password) pairs used by auth_check.
try:
    user_tuple = fetch_data("SELECT username, password FROM credentials")
    source = "db"
except Exception as err:  # DB unreachable / misconfigured: fall back to env
    print(f"Could not load credentials from DB ({err}); using env var user_tuple")
    # literal_eval instead of eval: the variable only has to hold a literal
    # such as (("alice", "pw1"), ("bob", "pw2")) and must never execute code.
    user_tuple = ast.literal_eval(os.environ.get("user_tuple", "()"))
    source = "env"
# Never print the credentials themselves.
print(f"{source}: {len(user_tuple)} user(s) loaded")

# Set by auth_check on successful login and read when logging requests.
# NOTE(review): this is one process-wide global, so with several concurrent
# users the logs are attributed to whoever logged in last.
logged_in_user = None


def auth_check(username, password):
    """Gradio auth callback: True if (username, password) is in ``user_tuple``."""
    global logged_in_user
    for user in user_tuple:
        if user[0] == username and user[1] == password:
            logged_in_user = username
            print(f"Logged in as {logged_in_user}")
            return True
    # Do not log the attempted password.
    print(f"Login attempt failed:[{username}]")
    return False


def write_data(time, session_id, chat_id, api_key, round, system, user,
               assistant, messages, payload, username):
    """Insert one chat round into the ``chatlog`` table (best effort).

    Every value is stored as ``str(value)`` (e.g. ``payload`` is a dict and
    lands as its repr), matching how the values were previously formatted
    into the SQL text. The query is now parameterized, so quotes or other
    SQL metacharacters in chat text can no longer break or alter the INSERT.
    Failures are printed and rolled back rather than raised, so a logging
    problem never breaks the chat. ``round`` shadows the builtin but is kept
    because callers pass it by keyword.
    """
    sql = (
        "INSERT INTO `chatlog` (`time`,`session_id`,`chat_id`,`api_key`,`round`,"
        "`system`,`user`,`assistant`,`messages`,`payload`,`username`) "
        "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
    )
    params = tuple(
        str(value)
        for value in (time, session_id, chat_id, api_key, round, system, user,
                      assistant, messages, payload, username)
    )
    try:
        conn = _connect()
    except Exception as e:
        print(f"出错了:{e}")
        return
    try:
        with conn.cursor() as cur:
            cur.execute(sql, params)
        conn.commit()
        print('成功写入数据!')
    except Exception as e:
        conn.rollback()
        print(f"出错了:{e}")
    finally:
        conn.close()


def get_empty_state():
    """Return a fresh conversation-state dict."""
    return {"total_tokens": 0, "messages": []}


def gen_uuid():
    """Return a new random UUID4 as a string (used for session / chat ids)."""
    return str(uuid.uuid4())


# Running in Docker? Then the server-side API key must be configured.
dockerflag = os.environ.get('dockerrun') == 'yes'
if dockerflag:
    my_api_key = os.environ.get('OPENAI_API_KEY')
    if my_api_key == "empty":
        print("Please give a api key!")
        sys.exit(1)


# HTML entities applied to lines inside fenced code blocks so the Chatbot's
# markdown renderer shows the characters literally. Using one translate()
# pass means "&" in the produced entities is never escaped a second time.
# NOTE(review): entity spellings were reconstructed from an HTML-decoded copy
# of this file; confirm they match the intended rendering.
_CODE_ESCAPES = str.maketrans({
    "&": "&amp;",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "#": "&#35;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
})


def parse_text(text):
    """Render model markdown for the Chatbot, turning ``` fences into HTML.

    Empty lines are dropped. A line containing ``` opens
    ``<pre><code class="LANG">`` (LANG = text after the last backtick) or
    closes ``</code></pre>``, alternately. Lines inside a code block are
    entity-escaped and end with ``<br>``; other lines end with ``\\n``.
    """
    lines = [line for line in text.split("\n") if line != ""]
    in_code_block = False
    for i, line in enumerate(lines):
        if "```" in line:
            if not in_code_block:
                # Escape the language tag: it is model output placed in an attribute.
                lang = html.escape(line.split("`")[-1], quote=True)
                lines[i] = f'<pre><code class="{lang}">'
            else:
                lines[i] = "</code></pre>"
            in_code_block = not in_code_block
        elif in_code_block:
            lines[i] = line.translate(_CODE_ESCAPES) + "<br>"
        else:
            lines[i] = line + "\n"
    text = "".join(lines)
    # Drop the line break(s) after the last code line of each block.
    return re.sub(r"(?:<br>)+</code></pre>", "</code></pre>", text)


def _mask_key(key):
    """Return a log-safe form of an API key (never the full secret)."""
    if not key:
        return str(key)
    return f"{key[:3]}...{key[-4:]}" if len(key) > 8 else "***"


def _show_error(chatbot, history, detail, drop_answer=False):
    """Replace the pending chatbot row with an error and roll back this turn.

    ``drop_answer`` first discards a partially streamed assistant message so
    the error row is paired with the user message, not with the partial answer.
    """
    if drop_answer:
        history.pop()
    chatbot[-1] = (history[-1], f"☹️发生了错误<br>{detail}")
    history.pop()
    return chatbot, history


# 请求API (call the API)
def openai_request(inputs, top_p, temperature, max_tokens, user_key, session_id,
                   chat_id, chatbot=None, history=None,
                   system_prompt=gen_sys_prompt, regenerate=False,
                   summary=False, presence_penalty=0, frequency_penalty=0):
    """Stream one chat round from ``API_URL`` into the Gradio chatbot.

    Generator: yields ``(chatbot, history)`` after every received token so
    Gradio re-renders incrementally.

    * ``chatbot`` is a list of ``(user, assistant)`` rows.
    * ``history`` is the flat list of alternating user/assistant strings;
      ``len(history) // 2`` is the number of completed rounds.

    Modes:

    * default: append ``inputs`` as a new user message.
    * ``regenerate``: drop the last assistant message and ask again.
    * ``summary``: ask for a summary of the conversation; ``history`` is
      replaced by a short question/summary pair.

    ``system_prompt`` may be a string or a zero-argument callable returning
    one. The UI-supplied ``user_key`` takes precedence over the
    ``OPENAI_API_KEY`` environment variable. On errors an error row is shown
    and nothing is logged; a successful round is logged via ``write_data``.
    """
    chatbot = [] if chatbot is None else chatbot
    history = [] if history is None else history
    if callable(system_prompt):
        system_prompt = system_prompt()

    # user key > env key
    api_key = user_key or os.environ.get("OPENAI_API_KEY")
    time = get_current_time()
    sys_prompt_print = system_prompt.replace("\n", "").replace(" ", "")
    print(f"\n\ntime: {time}")
    print(f"user: {logged_in_user}")
    print(f"apiKey: {_mask_key(api_key)}")
    print(f"system: {sys_prompt_print}")
    print(f"user: {inputs}")

    # Completed rounds so far (0 for a fresh conversation).
    chat_round = len(history) // 2
    print(f"round: {chat_round}")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # Build the message list from the visible conversation.
    messages = [compose_system(system_prompt)]
    if chat_round:
        for user_msg, assistant_msg in chatbot:
            if user_msg != "":
                messages.append(compose_user(user_msg))
                messages.append({"role": "assistant", "content": assistant_msg})
            else:
                # A row with an empty user message is a regenerated answer:
                # it replaces the previous assistant reply.
                messages[-1]["content"] = assistant_msg

    if regenerate and chat_round:
        messages.pop()  # drop the last answer so the model answers again
    elif summary:
        messages.append(compose_user(
            "请帮我总结一下上述对话的内容,实现减少字数的同时,保证对话的质量。在总结中不要加入这一句话。"))
        history = ["我们刚刚聊了什么?"]
    else:
        messages.append(compose_user(inputs))
        chat_round += 1

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "n": 1,
        "stream": True,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
    }

    if not summary:
        history.append(inputs)
    # Placeholder row, filled in as tokens arrive.
    chatbot.append((history[-1], ""))

    try:
        response = requests.post(API_URL, headers=headers, json=payload, stream=True)
    except requests.RequestException as e:
        yield _show_error(chatbot, history, f"异常:{e}")
        return
    if response.status_code != 200:
        # The body has not been consumed yet, so .text is the full error body.
        yield _show_error(chatbot, history,
                          f"返回值:{response.text}<br>异常:HTTP {response.status_code}")
        return

    partial_words = ""
    received = False
    for index, chunk in enumerate(response.iter_lines()):
        # The first line is skipped; presumably the role-only event of the
        # OpenAI stream. Empty lines are SSE event separators.
        if index == 0 or not chunk:
            continue
        try:
            data = chunk.decode()[6:]  # strip the "data: " prefix
            if data.strip() == "[DONE]":
                break
            delta = json.loads(data)["choices"][0]["delta"]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            yield _show_error(chatbot, history,
                              f"返回值:{chunk!r}<br>异常:{e}", drop_answer=received)
            return
        if not delta:
            break  # an empty delta marks the end of the answer
        partial_words += delta.get("content", "")
        rendered = parse_text(partial_words)
        if received:
            history[-1] = rendered
        else:
            history.append(rendered)
            received = True
        chatbot[-1] = (history[-2], history[-1])
        yield chatbot, history

    if not received:
        yield _show_error(chatbot, history, "异常:未收到任何回复")
        return

    # Undo parse_text()'s entity escaping so the log stores readable text
    # (&nbsp; decodes to U+00A0, which is mapped back to a plain space).
    history_write = html.unescape(history[-1]).replace("\xa0", " ")
    print(f"assistant: {history_write}")

    # Log the request parameters without the (large) message list.
    paras = {k: v for k, v in payload.items() if k != "messages"}
    write_data(
        time=time,
        session_id=session_id,
        chat_id=chat_id,
        api_key=api_key,
        round=chat_round,
        system=system_prompt.replace("\n", "").replace(" ", ""),
        user=inputs.replace("\n", ""),
        assistant=history_write.replace("\n", ""),
        messages="",
        payload=paras,
        username=logged_in_user,
    )


def delete_last_conversation(chatbot, history):
    """Remove the last chatbot row and the last user/assistant pair, if any."""
    if chatbot:
        chatbot.pop()
    del history[-2:]  # safe even when fewer than two entries remain
    return chatbot, history


def save_chat_history(filename, system, history, chatbot):
    """Export the conversation as JSON into ``HISTORY_DIR``.

    Names not ending in ``.json`` get a ``_MMDD_HHMM.json`` suffix; an existing
    file with the same name is overwritten. Empty names are ignored.
    """
    if not filename:
        return
    # basename() keeps the write inside HISTORY_DIR (no "../" traversal).
    filename = os.path.basename(filename)
    if not filename.endswith(".json"):
        filename = filename + get_mmdd_hhmi() + ".json"
    os.makedirs(HISTORY_DIR, exist_ok=True)
    json_s = {"system": system, "history": history, "chatbot": chatbot}
    with open(os.path.join(HISTORY_DIR, filename), "w", encoding="UTF-8") as f:
        json.dump(json_s, f, ensure_ascii=False)


def load_chat_history(filename):
    """Load an exported conversation, falling back to ``default.json``.

    Returns ``(filename, system, history, chatbot)`` for the UI components.
    """
    try:
        path = os.path.join(HISTORY_DIR, os.path.basename(filename))
        with open(path, "r", encoding="UTF-8") as f:
            json_s = json.load(f)
    except (OSError, TypeError, ValueError):  # missing/unreadable/invalid file
        with open(os.path.join(HISTORY_DIR, "default.json"), "r", encoding="UTF-8") as f:
            json_s = json.load(f)
    return filename, json_s["system"], json_s["history"], json_s["chatbot"]


def get_file_names(dir, plain=False, filetype=".json"):
    """List files in ``dir`` ending with ``filetype`` (sorted; [] if missing).

    Returns the plain list when ``plain`` is true, else a Dropdown update.
    """
    try:
        files = sorted(f for f in os.listdir(dir) if f.endswith(filetype))
    except FileNotFoundError:
        files = []
    if plain:
        return files
    return gr.Dropdown.update(choices=files)


def get_history_names(plain=False):
    """List exported conversations in ``HISTORY_DIR``."""
    return get_file_names(HISTORY_DIR, plain)


def load_template(filename):
    """Deprecated: read a prompt-template CSV from ``TEMPLATES_DIR``.

    The header row is skipped. Returns a ``{name: prompt}`` dict and a
    Dropdown update listing the template names.
    """
    with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="UTF-8") as csvfile:
        rows = list(csv.reader(csvfile))[1:]
    return {row[0]: row[1] for row in rows}, gr.Dropdown.update(choices=[row[0] for row in rows])


def get_template_names(plain=False):
    """List template CSV files in ``TEMPLATES_DIR``."""
    return get_file_names(TEMPLATES_DIR, plain, filetype=".csv")


def reset_state():
    """Return an empty chatbot and an empty history."""
    return [], []


def compose_system(system_prompt):
    """Wrap text as a system message."""
    return {"role": "system", "content": system_prompt}


def compose_user(user_input):
    """Wrap text as a user message."""
    return {"role": "user", "content": user_input}


def reset_textbox():
    """Clear the input textbox."""
    return gr.update(value='')


with open("styles.css", "r", encoding="utf-8") as f:
    css_styles = f.read()

# build UI
with gr.Blocks(css=css_styles, title="Chatbot🚀", theme=gr.themes.Soft()) as interface:
    history = gr.State([])
    promptTemplates = gr.State({})
    TRUECONSTANT = gr.State(True)
    FALSECONSTANT = gr.State(False)

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# Chatbot Playground 🚀
本对话使用 GPT-3.5 Turbo 模型。如果遇到问题请联系Simon!
""", elem_id="header")

        with gr.Row():
            # left col: chat
            with gr.Column():
                # clear history (also refreshes chat_id)
                btnClearChat = gr.Button("🧹 清空对话")
                # dialogue
                chatbot = gr.Chatbot(elem_id="chatbox", label="对话")
                # input
                with gr.Row():
                    with gr.Column(scale=12):
                        txt = gr.Textbox(show_label=False, placeholder="在这里输入,Shift+Enter换行").style(
                            container=False)
                    with gr.Column(min_width=50, scale=1):
                        btnSubmit = gr.Button("🚀", variant="primary")
                # shortcuts
                with gr.Row():
                    btnRegen = gr.Button("🔄 重新生成")
                    btnDelPrev = gr.Button("🗑️ 删除上一轮")
                    btnSummarize = gr.Button("♻️ 总结对话")
                # session id
                with gr.Row(elem_id="ids"):
                    mdSessionId = gr.Textbox(label="Session ID", value=gen_uuid, elem_id="session_id", interactive=False)
                    mdChatId = gr.Textbox(label="Chat ID", value=gen_uuid, elem_id="chat_id", interactive=False)

            # right col: settings
            with gr.Column():
                # api key
                with gr.Accordion(label="API Key(帮助节省额度)", open=False):
                    # SECURITY: no default key here. Component defaults are sent
                    # to the browser, so a hard-coded key is readable by every
                    # visitor. The previously embedded key must be revoked; the
                    # server key comes from OPENAI_API_KEY when this is empty.
                    apiKey = gr.Textbox(show_label=True, placeholder="使用自己的OpenAI API key",
                                        value='', label="默认的API额度可能很快用完,你可以使用自己的API Key,以帮忙节省额度。",
                                        type="password", visible=True).style(container=True)
                # sys prompt
                promptSystem = gr.Textbox(lines=2, show_label=True, placeholder="在这里输入System Prompt...",
                                          label="系统输入", info="系统输入不会直接产生交互,可作为基础信息输入。",
                                          value=gen_sys_prompt).style(container=True)
                topic = gr.State("未命名对话")

                # Template loading UI (load_template / get_template_names) is disabled.

                # load / save chat
                with gr.Accordion(label="导入/导出", open=False):
                    # Exported names without ".json" get a _mmdd_hhmi timestamp;
                    # identical names are overwritten.
                    with gr.Row():
                        # load
                        with gr.Column():
                            historyFileSelectDropdown = gr.Dropdown(label="加载对话", choices=get_history_names(plain=True),
                                                                    value="model_抖音.json", multiselect=False)
                            btnLoadHist = gr.Button("📂 导入对话")
                            btnGetHistList = gr.Button("🔄 刷新列表")
                        # save
                        with gr.Column():
                            saveFileName = gr.Textbox(show_label=True, placeholder="文件名", label="导出文件名",
                                                      value="对话").style(container=True)
                            btnSaveToHist = gr.Button("💾 导出对话")

                # parameters: top_p, temperature, max_tokens, penalties
                with gr.Accordion("参数配置", open=True):
                    gr.Markdown("""
* [GPT-3参数详解](https://blog.csdn.net/jarodyv/article/details/128984602)
* [API文档](https://platform.openai.com/docs/api-reference/chat/create)
""")
                    max_tokens = gr.Slider(minimum=100, maximum=2500, value=1500, step=10, label="最大token数",
                                           info="每次回答最大token数,数值过大可能长篇大论,看起来不像正常聊天")
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=0.5, step=0.05, interactive=True,
                                      label="Top-p (核采样)",
                                      info="0.1意味着只考虑包含最高10%概率质量的token。较小时,更加紧凑连贯,但缺乏创造性。较大时,更加多样化,但会出现语法错误和不连贯。")
                    temperature = gr.Slider(minimum=0, maximum=2.0, value=1, step=0.1, interactive=True,
                                            label="采样温度", info="越高 → 不确定性、创造性越高↑")
                    presence_penalty = gr.Slider(
                        minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True,
                        label="出现惩罚系数", info="越高 → 增加谈论新主题的可能性"
                    )
                    frequency_penalty = gr.Slider(
                        minimum=-2.0, maximum=2.0, value=0, step=0.1, interactive=True,
                        label="频率惩罚系数", info="越高 → 降低模型逐字重复同一行的可能性"
                    )

    # Positional inputs of openai_request up to system_prompt; the mode flags
    # (regenerate, summary) and the penalty sliders are appended per event.
    request_inputs = [txt, top_p, temperature, max_tokens, apiKey, mdSessionId,
                      mdChatId, chatbot, history, promptSystem]
    penalty_inputs = [presence_penalty, frequency_penalty]

    # event listener: enter
    btnClearChat.click(gen_uuid, [], [mdChatId])
    txt.submit(openai_request, request_inputs + [FALSECONSTANT, FALSECONSTANT] + penalty_inputs,
               [chatbot, history])
    txt.submit(reset_textbox, [], [txt])
    # button: submit
    btnSubmit.click(openai_request, request_inputs + [FALSECONSTANT, FALSECONSTANT] + penalty_inputs,
                    [chatbot, history], show_progress=True)
    btnSubmit.click(reset_textbox, [], [txt])
    # button: clear history
    btnClearChat.click(reset_state, outputs=[chatbot, history])
    # button: regenerate
    btnRegen.click(openai_request, request_inputs + [TRUECONSTANT, FALSECONSTANT] + penalty_inputs,
                   [chatbot, history], show_progress=True)
    # button: delete previous round
    btnDelPrev.click(delete_last_conversation, [chatbot, history], [chatbot, history], show_progress=True)
    # button: summarize
    btnSummarize.click(openai_request, request_inputs + [FALSECONSTANT, TRUECONSTANT] + penalty_inputs,
                       [chatbot, history], show_progress=True)
    # button: save dialogue (save history, then refresh the file list)
    btnSaveToHist.click(save_chat_history, [saveFileName, promptSystem, history, chatbot], None, show_progress=True)
    btnSaveToHist.click(get_history_names, None, [historyFileSelectDropdown])
    # dropdown change loads the selected conversation
    historyFileSelectDropdown.change(load_chat_history, [historyFileSelectDropdown],
                                     [saveFileName, promptSystem, history, chatbot])
    # button: refresh history list
    btnGetHistList.click(get_history_names, None, [historyFileSelectDropdown])
    # button: load history
    btnLoadHist.click(load_chat_history, [historyFileSelectDropdown],
                      [saveFileName, promptSystem, history, chatbot], show_progress=True)

# Local server by default, reachable by IP, no public share link.
if dockerflag:
    interface.queue(concurrency_count=2, status_update_rate="auto").launch(
        server_name="0.0.0.0", server_port=7860, auth=auth_check)
else:
    interface.queue(concurrency_count=2, status_update_rate="auto").launch(
        height='800px', debug=True, show_error=True, auth=auth_check,
        auth_message="请联系Simon获取用户名与密码!"
    )