|
import argparse |
|
import glob |
|
import json |
|
import os |
|
|
|
import gradio as gr |
|
from pipelines.document_stores import FAISSDocumentStore |
|
from pipelines.nodes import ( |
|
EmbeddingRetriever, |
|
ErnieBot, |
|
PDFToTextConverter, |
|
SpacyTextSplitter, |
|
) |
|
from pipelines.pipelines import Pipeline |
|
|
|
import erniebot |
|
|
|
# Command-line options for the demo. They are parsed at import time, so every
# module-level statement and function below can read the resulting `args`.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--index_name", default="construct_demo_index", type=str, help="The ann index name of ANN."
)
# The value is used as a directory: every "*.pdf" directly inside it is indexed.
parser.add_argument(
    "--file_paths",
    default="./construction_regulations",
    type=str,
    help="Directory containing the PDF files to index.",
)
parser.add_argument("--retriever_top_k", default=5, type=int, help="Number of recall items for search")
parser.add_argument(
    "--chunk_size", default=384, type=int, help="The length of data for indexing by retriever"
)
# --host / --port are handed to the Gradio server (demo.launch), not to a
# search engine, so the help text describes the web UI endpoint.
parser.add_argument("--host", type=str, default="localhost", help="Host name or IP the Gradio demo server binds to")
parser.add_argument("--port", type=int, default=8081, help="Port the Gradio demo server listens on")
parser.add_argument("--api_key", default=None, type=str, help="The API Key.")
parser.add_argument("--secret_key", default=None, type=str, help="The secret key.")
args = parser.parse_args()

# Configure the erniebot SDK once, module-wide, with the credentials given on
# the command line (both default to None when the flags are omitted).
erniebot.api_type = "qianfan"
erniebot.ak = args.api_key
erniebot.sk = args.secret_key
|
|
|
|
|
# Companion database file checked alongside the saved index.
# NOTE(review): presumably the SQL file FAISSDocumentStore writes by default - confirm.
faiss_document_store = "faiss_document_store.db"

if os.path.exists(args.index_name) and os.path.exists(faiss_document_store):
    # Both artifacts of an earlier run exist: reuse the saved index as-is.
    document_store = FAISSDocumentStore.load(args.index_name)
    retriever = EmbeddingRetriever(
        document_store=document_store,
        retriever_batch_size=16,
        api_key=args.api_key,
        secret_key=args.secret_key,
    )
else:
    # At most one of the two files exists; remove the leftover so the store
    # is rebuilt from scratch instead of mixing old and new state.
    if os.path.exists(args.index_name):
        os.remove(args.index_name)
    if os.path.exists(faiss_document_store):
        os.remove(faiss_document_store)
    document_store = FAISSDocumentStore(
        embedding_dim=384,
        duplicate_documents="skip",
        return_embedding=True,
        faiss_index_factory_str="Flat",
    )
    retriever = EmbeddingRetriever(
        document_store=document_store,
        retriever_batch_size=16,
        api_key=args.api_key,
        secret_key=args.secret_key,
    )

    pdf_converter = PDFToTextConverter()

    # Honour --chunk_size (default 384) instead of a hard-coded value so the
    # command-line option actually controls the indexed chunk length.
    text_splitter = SpacyTextSplitter(
        separator="\n",
        chunk_size=args.chunk_size,
        chunk_overlap=128,
        filters=["\n"],
    )

    # Indexing pipeline: PDF -> text -> chunks -> embeddings -> document store.
    indexing_pipeline = Pipeline()
    indexing_pipeline.add_node(component=pdf_converter, name="pdf_converter", inputs=["File"])
    indexing_pipeline.add_node(component=text_splitter, name="Splitter", inputs=["pdf_converter"])
    indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["Splitter"])
    indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Retriever"])

    # Only PDFs directly inside the directory are picked up (no recursion).
    files_paths = glob.glob(os.path.join(args.file_paths, "*.pdf"))
    indexing_pipeline.run(file_paths=files_paths)
    # Persist the index so the next start can take the fast "load" branch.
    document_store.save(args.index_name)
|
|
|
|
|
# ErnieBot node built with the command-line credentials.
# NOTE(review): not referenced again in the part of the file visible here.
ernie_bot = ErnieBot(
    api_key=args.api_key,
    secret_key=args.secret_key,
)

# Query-time pipeline: a single retriever node fed directly by the query.
query_pipeline = Pipeline()
query_pipeline.add_node(
    component=retriever,
    name="Retriever",
    inputs=["Query"],
)
|
|
|
|
|
# Function-calling schema handed to erniebot.ChatCompletion.create(functions=...).
# It declares one callable, `search_knowledge_base`: its parameter schema, the
# shape of its response, and few-shot examples of when and how to call it.
functions = [
    {
        "name": "search_knowledge_base",
        "description": "在住房和城乡建设部规章中寻找和query最相关的片段",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string", "description": "规章查询语句"}},
            "required": ["query"],
        },
        "responses": {
            "type": "object",
            "description": "检索结果,内容为住房和城乡建设部规章中和query相关的文字片段",
            "properties": {
                "documents": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "document": {"type": "string", "description": "和query相关的文字片段"},
                        },
                    },
                }
            },
            "required": ["documents"],
        },
        # Each example pairs a user question with the expected function call.
        # The query announced in "thoughts" must be exactly the one passed in
        # "arguments"; otherwise the examples demonstrate inconsistent behavior.
        "examples": [
            {"role": "user", "content": "企业申请建筑业企业资质需要哪些材料?"},
            {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": "search_knowledge_base",
                    "thoughts": "这是和城市建设法规标准相关的问题,我需要查询住房和城乡建设部规章,并且设置query为'企业申请建筑业企业资质需要的材料'",
                    "arguments": '{ "query": "企业申请建筑业企业资质需要的材料"}',
                },
            },
            {"role": "user", "content": "历史文化街区的城市设计有什么要求?"},
            {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": "search_knowledge_base",
                    "thoughts": "这是和城市建设法规标准相关的问题,我需要查询住房和城乡建设部规章,并且设置query为'历史文化街区的设计要求'",
                    "arguments": '{ "query": "历史文化街区的设计要求"}',
                },
            },
        ],
    }
]
|
|
|
|
|
def search_knowledge_base(query):
    """Retrieve the passages most relevant to *query*.

    Runs the module-level ``query_pipeline`` with the retriever's ``top_k``
    taken from ``--retriever_top_k``.

    Args:
        query: The search string forwarded to the pipeline.

    Returns:
        A dict ``{"documents": [{"document": <content>}, ...]}`` holding the
        ``content`` of every document the pipeline returned.
    """
    retriever_params = {"Retriever": {"top_k": args.retriever_top_k}}
    pipeline_output = query_pipeline.run(query=query, params=retriever_params)

    hits = []
    for doc in pipeline_output["documents"]:
        hits.append({"document": doc.content})
    return {"documents": hits}
|
|
|
|
|
def history_transform(history=None):
    """Convert chatbot history into a flat list of chat messages.

    The first history entry is skipped: ``launch_ui`` seeds the chatbot with
    an assistant greeting that has no user side, so it is not a real turn.

    Args:
        history: Sequence of ``[user_text, assistant_text]`` pairs. ``None``
            (the default) is treated as an empty history.

    Returns:
        A new list with, for every turn after the first, a ``user`` message
        followed by an ``assistant`` message. Empty when there are fewer
        than two history entries.
    """
    messages = []
    # Slicing from index 1 drops the greeting and yields nothing for
    # histories with fewer than two entries, so no separate guard is needed.
    for turn in (history or [])[1:]:
        messages.append({"role": "user", "content": turn[0]})
        messages.append({"role": "assistant", "content": turn[1]})
    return messages
|
|
|
|
|
def add_message_chatbot(messages, history):
    """Append the submitted text to the chat history as an unanswered turn.

    Args:
        messages: The text the user submitted.
        history: The chatbot history list; it is extended in place.

    Returns:
        ``(None, history)``. In ``launch_ui`` the two values are routed to
        the textbox and the chatbot outputs respectively.
    """
    pending_turn = [messages, None]  # the assistant side is filled in later
    history.append(pending_turn)
    return (None, history)
|
|
|
|
|
def prediction(history):
    """Answer the newest user turn, calling the knowledge base if requested.

    The last history entry (the pending turn added by ``add_message_chatbot``)
    is popped and its user text becomes the query. The model is called with
    the ``functions`` schema; if it asks for a function call, the knowledge
    base is searched and the model is called a second time with the result.

    Args:
        history: Chatbot history whose last entry is the pending turn. The
            list is modified in place.

    Returns:
        ``(history, logs)``, where ``history`` ends with ``[query, answer]``
        and ``logs`` is a list of dicts describing the function call and the
        retrieval result. For an empty query the pending turn is dropped and
        ``(history, "注意:问题不能为空")`` is returned instead.
    """
    query = history.pop()[0]
    if query == "":
        return history, "注意:问题不能为空"

    logs = []

    # Remove "<br>" markup from both sides of every earlier turn before the
    # text is sent back to the model.
    for turn in history:
        for side in (0, 1):
            if turn[side] is not None:
                turn[side] = turn[side].replace("<br>", "")

    messages = history_transform(history)
    messages.append({"role": "user", "content": query})

    response = erniebot.ChatCompletion.create(
        model="ernie-bot",
        messages=messages,
        functions=functions,
    )

    if "function_call" in response:
        function_call = response.function_call
        logs.append({"function_call结果": function_call})

        # The arguments arrive as a JSON string and become keyword arguments.
        func_args = json.loads(function_call["arguments"])
        res = search_knowledge_base(**func_args)
        logs.append({"检索结果": res})

        # Replay the call and its result so the model can write the answer.
        messages.append({"role": "assistant", "content": None, "function_call": function_call})
        messages.append(
            {
                "role": "function",
                "name": function_call["name"],
                "content": json.dumps(res, ensure_ascii=False),
            }
        )
        response = erniebot.ChatCompletion.create(model="ernie-bot", messages=messages)
    else:
        logs.append({"function_call结果": "未触发"})

    result = response["result"]
    history.append([query, result])
    return history, logs
|
|
|
|
|
def launch_ui():
    """Build the Gradio chat UI and serve it on ``args.host``:``args.port``.

    The page has a chatbot seeded with a greeting, a textbox, submit and
    clear buttons, and a JSON panel that shows the logs returned by
    ``prediction``. Submitting (Enter in the textbox or the submit button)
    first queues the message via ``add_message_chatbot`` and then runs
    ``prediction``. The clear button resets the chatbot to the greeting.
    """
    # NOTE(review): widget nesting was reconstructed from a source whose
    # indentation was lost - confirm against the intended layout.
    page_title = "ERNIE Bot 城市建设法规标准小助手"
    greeting = "您好, 我是 ERNIE Bot 城市建设法规标准小助手。除了普通的大模型能力以外,还特别了解住房和城乡建设部规章哦"

    def reset_chat(_):
        # A fresh list on every click so the greeting is never shared state.
        return [[None, greeting]]

    with gr.Blocks(title=page_title, theme=gr.themes.Base()) as demo:
        gr.HTML(f"""<h1 align="center">{page_title}</h1>""")
        with gr.Column():
            chatbot = gr.Chatbot(
                value=[[None, greeting]],
                scale=35,
                height=500,
            )
            message = gr.Textbox(placeholder="哪些建筑企业资质需要国务院住房城乡建设主管部门许可?", lines=1, max_lines=20)
            with gr.Row():
                submit = gr.Button("🚀 提交", variant="primary", scale=1)
                clear = gr.Button("清除", variant="primary", scale=1)
            log = gr.JSON()

        # Both triggers share one two-step chain: queue the turn, then answer.
        for trigger in (message.submit, submit.click):
            trigger(
                add_message_chatbot,
                inputs=[message, chatbot],
                outputs=[message, chatbot],
            ).then(prediction, inputs=[chatbot], outputs=[chatbot, log])

        clear.click(reset_chat, inputs=[clear], outputs=[chatbot])

    demo.launch(server_name=args.host, server_port=args.port, debug=True)
|
|
|
|
|
# Start the Gradio UI only when this file is executed as a script.
if __name__ == "__main__":
    launch_ui()
|
|