# FinBot / app.py
import gradio as gr
import os
import json
import requests
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
# Streaming endpoint
API_URL = "https://api.openai.com/v1/chat/completions"
cohere_key = os.environ.get('COHERE_API_KEY', '')  # read from the environment; never hard-code API keys
faiss_store = './indexer/{}'
docsearch = None
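
# The per-topic FAISS indices under ./indexer/ are assumed to be built offline.
# The sketch below shows one plausible way to do that with the same embedding
# model the app uses at query time; build_topic_index, its source_path argument,
# and the splitter settings are illustrative assumptions, not part of the
# original pipeline.
def build_topic_index(topic, source_path, openai_api_key):
    """Hypothetical offline step: embed one topic's documents into a FAISS index."""
    from langchain.document_loaders import TextLoader
    from langchain.text_splitter import CharacterTextSplitter
    docs = TextLoader(source_path, encoding='utf-8').load()
    chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
    index = FAISS.from_documents(chunks, OpenAIEmbeddings(openai_api_key=openai_api_key))
    # Must land where faiss_store.format(topic) expects it, so predict() can load it
    index.save_local(faiss_store.format(topic))
    return index
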
def gen_conversation(conversations):
    """Flatten [(user, assistant), ...] chat pairs into OpenAI-style message dicts."""
    messages = []
    for user_msg, assistant_msg in conversations:
        messages.append({'role': 'user', 'content': user_msg})
        messages.append({'role': 'assistant', 'content': assistant_msg})
    return messages
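
# Example: gen_conversation([('hi', 'hello!'), ('how are you?', 'fine')]) returns
#   [{'role': 'user', 'content': 'hi'},
#    {'role': 'assistant', 'content': 'hello!'},
#    {'role': 'user', 'content': 'how are you?'},
#    {'role': 'assistant', 'content': 'fine'}]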
def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic,
            chat_counter, chatbot=[], history=[]):
    global docsearch
    # topic comes from a CheckboxGroup (a list); guard against no selection
    topic = topic[0] if topic else None
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {openai_api_key}"
}
print(f"chat_counter - {chat_counter}")
print(f'Histroy - {history}') # History: Original Input and Output in flatten list
print(f'chatbot - {chatbot}') # Chat Bot: 上一轮回复的[[user, AI]]
history.append(inputs)
    if enable_index:
        # Retrieve the nearest chunks from the per-topic FAISS index
store = faiss_store.format(topic)
        if docsearch is None:
            print('Loading FAISS')
            # Note: the index is cached in a module-level global, so switching
            # topics after the first load keeps using the first topic's index
            docsearch = FAISS.load_local(store, OpenAIEmbeddings(openai_api_key=openai_api_key))
        else:
            print('Faiss already loaded')
        # Build the chat prompt templates
llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
messages_combine = [
SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
HumanMessagePromptTemplate.from_template("{question}")
]
p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
messages_reduce = [
SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
HumanMessagePromptTemplate.from_template("{question}")
]
p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
k=4,
chain_type_kwargs={"question_prompt": p_chat_reduce,
"combine_prompt": p_chat_combine}
)
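        # map_reduce: the question prompt (p_chat_reduce) is applied to each of the
        # k=4 retrieved chunks independently, and the combine prompt (p_chat_combine)
        # then merges those partial answers into a single response.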
result = chain({"query": inputs})
print(result)
result = result['result']
        # Build the values returned to the UI
history.append(result)
chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
chat_counter += 1
yield chat, history, chat_counter
else:
if chat_counter == 0:
messages = [{"role": "user", "content": f"{inputs}"}]
else:
            # If there is prior conversation, splice it into the context
            messages = gen_conversation(chatbot)
            messages.append({'role': 'user', 'content': inputs})
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "n": 1,
            "stream": True,
            "presence_penalty": 0,
            "frequency_penalty": 0,
        }
print(f"payload is - {payload}")
chat_counter += 1
        # Call the OpenAI chat completions endpoint with streaming enabled
response = requests.post(API_URL, headers=headers, json=payload, stream=True)
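        # The endpoint replies with server-sent events, one JSON chunk per line:
        #   data: {"choices": [{"delta": {"content": "Hel"}, ...}]}
        #   data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}
        #   data: [DONE]
        # The loop below strips the 6-byte "data: " prefix and stops on an empty
        # delta or the [DONE] sentinel.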
token_counter = 0
partial_words = ""
        # Stream tokens back to the UI as they arrive
counter = 0
        for chunk in response.iter_lines():
            if counter == 0:
                # The first SSE line only carries the role delta; skip it
                counter += 1
                continue
            counter += 1
            # check whether each line is non-empty
            if chunk:
                # decode each line as response data is in bytes; strip the "data: " prefix
                decoded = chunk.decode()[6:]
                if decoded.strip() == '[DONE]':
                    break
                delta = json.loads(decoded)['choices'][0]["delta"]
                if len(delta) == 0:
                    break
                partial_words += delta.get("content", "")
# Keep Updating history
if token_counter == 0:
history.append(" " + partial_words)
else:
history[-1] = partial_words
                chat = [(history[i], history[i + 1]) for i in
                        range(0, len(history) - 1, 2)]  # pair the flat history into (user, AI) tuples
token_counter += 1
yield chat, history, chat_counter
def reset_textbox():
return gr.update(value='')
with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
#chatbot {height: 520px; overflow: auto;}""") as demo:
gr.HTML("""<h1 align="center">🚀Finance ChatBot🚀</h1>""")
with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="Enter your OpenAI API Key")
# inputs, top_p, temperature, top_k, repetition_penalty
with gr.Accordion("Parameters", open=True):
with gr.Row():
                top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, interactive=True,
                                  label="Top-p (nucleus sampling)", )
                # OpenAI caps temperature at 2.0
                temperature = gr.Slider(minimum=0, maximum=2.0, value=0.8, step=0.1, interactive=True,
                                        label="Temperature", )
max_tokens = gr.Slider(minimum=100, maximum=1000, value=200, step=100, interactive=True,
label="Max Tokens", )
            chat_counter = gr.Number(value=0, precision=0, label='Number of chat turns')
        with gr.Row():
            enable_index = gr.Checkbox(label='Enable', info='Use document-QA mode instead of plain chat')
            enable_search = gr.Checkbox(label='Enable', info='Use web search')  # not yet wired into predict()
            # The choice strings double as index directory names under ./indexer/,
            # so they are kept in the original Chinese
            topic = gr.CheckboxGroup(["两会", "数字经济", "硅谷银行"], label='Document index to use')
        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="Ask me anything",
                            label="Ask about the indexed topics: 两会 (Two Sessions), 数字经济 (digital economy), 硅谷银行 (Silicon Valley Bank)")
state = gr.State([])
with gr.Row():
clear = gr.Button("Clear Conversation")
run = gr.Button("Run")
inputs.submit(predict,
[inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic, chat_counter, chatbot,
state],
[chatbot, state, chat_counter], )
run.click(predict,
[inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic, chat_counter, chatbot,
state],
[chatbot, state, chat_counter], )
    # Clear the input textbox (both the Clear button and each submit reset it)
clear.click(reset_textbox, [], [inputs], queue=False)
inputs.submit(reset_textbox, [], [inputs])
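# queue() is required for Gradio to stream results from a generator function
# like predict(); without it the intermediate yields would not reach the UI.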
demo.queue().launch(debug=True)