import argparse
import glob
import json
import os
import gradio as gr
from pipelines.document_stores import FAISSDocumentStore
from pipelines.nodes import (
    EmbeddingRetriever,
    ErnieBot,
    PDFToTextConverter,
    SpacyTextSplitter,
)
from pipelines.pipelines import Pipeline
import erniebot
parser = argparse.ArgumentParser()
parser.add_argument(
    "--index_name", default="construct_demo_index", type=str, help="Name of the ANN index."
)
parser.add_argument(
    "--file_paths", default="./construction_regulations", type=str, help="Directory containing the PDF files to index."
)
parser.add_argument("--retriever_top_k", default=5, type=int, help="Number of documents to recall per query.")
parser.add_argument(
    "--chunk_size", default=384, type=int, help="Chunk size used when splitting documents for indexing."
)
parser.add_argument("--host", type=str, default="localhost", help="Host IP of the Gradio demo server.")
parser.add_argument("--port", type=int, default=8081, help="Port of the Gradio demo server.")
parser.add_argument("--api_key", default=None, type=str, help="The Qianfan API key.")
parser.add_argument("--secret_key", default=None, type=str, help="The Qianfan secret key.")
args = parser.parse_args()
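
# Example invocation (the script name, PDF directory, and credentials are placeholders):
#   python chat_demo.py --api_key <QIANFAN_AK> --secret_key <QIANFAN_SK> \
#       --file_paths ./construction_regulations --port 8081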
erniebot.api_type = "qianfan"
erniebot.ak = args.api_key
erniebot.sk = args.secret_key
# Build a local semantic retrieval service with Paddle-Pipelines
faiss_document_store = "faiss_document_store.db"
# If a local semantic index already exists, load the prebuilt index directly
if os.path.exists(args.index_name) and os.path.exists(faiss_document_store):
    document_store = FAISSDocumentStore.load(args.index_name)
    retriever = EmbeddingRetriever(
        document_store=document_store,
        retriever_batch_size=16,  # 16 is the max batch size allowed by ErnieBot Embedding
        api_key=args.api_key,
        secret_key=args.secret_key,
    )
# Otherwise, build a new index and save it locally
else:
    if os.path.exists(args.index_name):
        os.remove(args.index_name)
    if os.path.exists(faiss_document_store):
        os.remove(faiss_document_store)
    document_store = FAISSDocumentStore(
        embedding_dim=384,  # hardcode the embedding dim to 384 for ErnieBot Embedding
        duplicate_documents="skip",
        return_embedding=True,
        faiss_index_factory_str="Flat",
    )
    retriever = EmbeddingRetriever(
        document_store=document_store,
        retriever_batch_size=16,  # 16 is the max batch size allowed by ErnieBot Embedding
        api_key=args.api_key,
        secret_key=args.secret_key,
    )
    # Convert the PDF documents into plain text
    pdf_converter = PDFToTextConverter()
    # Split the text into chunks
    text_splitter = SpacyTextSplitter(
        separator="\n", chunk_size=args.chunk_size, chunk_overlap=128, filters=["\n"]
    )
    indexing_pipeline = Pipeline()
    indexing_pipeline.add_node(component=pdf_converter, name="pdf_converter", inputs=["File"])
    indexing_pipeline.add_node(component=text_splitter, name="Splitter", inputs=["pdf_converter"])
    indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["Splitter"])
    indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Retriever"])
    files_paths = glob.glob(args.file_paths + "/*.pdf")
    indexing_pipeline.run(file_paths=files_paths)
    document_store.save(args.index_name)
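    # Persisting the FAISS index and faiss_document_store.db lets subsequent launches
    # take the fast load path above instead of re-converting and re-embedding the PDFs.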
# Build the pipeline used for retrieval at query time
ernie_bot = ErnieBot(api_key=args.api_key, secret_key=args.secret_key)
query_pipeline = Pipeline()
query_pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
# Define the function-calling API: one function plus two usage examples
functions = [
    {
        "name": "search_knowledge_base",
        "description": "在住房和城乡建设部规章中寻找和query最相关的片段",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string", "description": "规章查询语句"}},
            "required": ["query"],
        },
        "responses": {
            "type": "object",
            "description": "检索结果,内容为住房和城乡建设部规章中和query相关的文字片段",
            "properties": {
                "documents": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "document": {"type": "string", "description": "和query相关的文字片段"},
                        },
                    },
                }
            },
            "required": ["documents"],
        },
        "examples": [
            {"role": "user", "content": "企业申请建筑业企业资质需要哪些材料?"},
            {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": "search_knowledge_base",
                    "thoughts": "这是和城市建设法规标准相关的问题,我需要查询住房和城乡建设部规章,并且设置query为'企业申请建筑业企业资质需要的材料'",
                    "arguments": '{ "query": "企业申请建筑业企业资质需要哪些材料?"}',
                },
            },
            {"role": "user", "content": "历史文化街区的城市设计有什么要求?"},
            {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": "search_knowledge_base",
                    "thoughts": "这是和城市建设法规标准相关的问题,我需要查询住房和城乡建设部规章,并且设置query为'历史文化街区的设计要求'",
                    "arguments": '{ "query": "历史文化街区的设计要求"}',
                },
            },
        ],
    }
]
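

# Local implementation of the "search_knowledge_base" function declared above:
# run the retrieval pipeline and return documents in the shape promised by "responses".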
def search_knowledge_base(query):
    prediction = query_pipeline.run(
        query=query,
        params={
            "Retriever": {
                "top_k": args.retriever_top_k,
            },
        },
    )
    documents = [{"document": doc.content} for doc in prediction["documents"]]
    return {"documents": documents}
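

# Convert Gradio chat history into the ErnieBot message format,
# skipping the initial greeting turn (which has no user message).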
def history_transform(history=[]):
    messages = []
    if len(history) < 2:
        return messages
    for turn_idx in range(1, len(history)):
        messages.extend(
            [
                {"role": "user", "content": history[turn_idx][0]},
                {"role": "assistant", "content": history[turn_idx][1]},
            ]
        )
    return messages
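

# Append the submitted message to the chat history and clear the input box.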
def add_message_chatbot(messages, history):
    history.append([messages, None])
    return None, history
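

# Answer the latest user query: call ErnieBot with the function definitions; if the model
# triggers a function call, run the local retrieval and ask ErnieBot to compose the final reply.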
def prediction(history):
    logs = []
    query = history.pop()[0]
    if query == "":
        return history, "注意:问题不能为空"
    # Clean up potential formatting issues
    for turn_idx in range(len(history)):
        if history[turn_idx][0] is not None:
            history[turn_idx][0] = history[turn_idx][0].replace("<br>", "")
        if history[turn_idx][1] is not None:
            history[turn_idx][1] = history[turn_idx][1].replace("<br>", "")
    # Convert the chat history from the Gradio format to the ErnieBot SDK format
    messages = history_transform(history)
    # Append the current user query to the conversation context
    messages.append({"role": "user", "content": query})
    # Call the ErnieBot chat completion API, passing in the function definitions
    response = erniebot.ChatCompletion.create(
        model="ernie-bot",
        messages=messages,
        functions=functions,
    )
    # If no function call was triggered, the model answered directly; return its result
    if "function_call" not in response:
        logs.append({"function_call结果": "未触发"})
        result = response["result"]
    # If a function call was triggered
    else:
        function_call = response.function_call
        logs.append({"function_call结果": function_call})
        # Parse the function-call arguments returned by the model
        func_args = json.loads(function_call["arguments"])
        # Call the function
        res = search_knowledge_base(**func_args)
        logs.append({"检索结果": res})
        # Use another chat completion call to polish the function result into a reply for the user
        messages.append({"role": "assistant", "content": None, "function_call": function_call})
        messages.append(
            {
                "role": "function",
                "name": function_call["name"],
                "content": json.dumps(res, ensure_ascii=False),
            }
        )
        response = erniebot.ChatCompletion.create(model="ernie-bot", messages=messages)
        result = response["result"]
    history.append([query, result])
    return history, logs
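

# Build the Gradio UI: a chatbot panel, a text box, submit/clear buttons, and a JSON log view.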
def launch_ui():
    with gr.Blocks(title="ERNIE Bot 城市建设法规标准小助手", theme=gr.themes.Base()) as demo:
        gr.HTML("""<h1 align="center">ERNIE Bot 城市建设法规标准小助手</h1>""")
        with gr.Column():
            chatbot = gr.Chatbot(
                value=[[None, "您好, 我是 ERNIE Bot 城市建设法规标准小助手。除了普通的大模型能力以外,还特别了解住房和城乡建设部规章哦"]],
                scale=35,
                height=500,
            )
            message = gr.Textbox(placeholder="哪些建筑企业资质需要国务院住房城乡建设主管部门许可?", lines=1, max_lines=20)
            with gr.Row():
                submit = gr.Button("🚀 提交", variant="primary", scale=1)
                clear = gr.Button("清除", variant="primary", scale=1)
            log = gr.JSON()
            message.submit(add_message_chatbot, inputs=[message, chatbot], outputs=[message, chatbot]).then(
                prediction, inputs=[chatbot], outputs=[chatbot, log]
            )
            submit.click(add_message_chatbot, inputs=[message, chatbot], outputs=[message, chatbot]).then(
                prediction, inputs=[chatbot], outputs=[chatbot, log]
            )
            clear.click(
                lambda _: ([[None, "您好, 我是 ERNIE Bot 城市建设法规标准小助手。除了普通的大模型能力以外,还特别了解住房和城乡建设部规章哦"]]),
                inputs=[clear],
                outputs=[chatbot],
            )
    demo.launch(server_name=args.host, server_port=args.port, debug=True)


if __name__ == "__main__":
    launch_ui()