|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from flask import request
|
|
|
|
from api.db import StatusEnum
|
|
from api.db.db_models import TenantLLM
|
|
from api.db.services.dialog_service import DialogService
|
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
from api.db.services.llm_service import LLMService, TenantLLMService
|
|
from api.db.services.user_service import TenantService
|
|
from api.settings import RetCode
|
|
from api.utils import get_uuid
|
|
from api.utils.api_utils import get_data_error_result, token_required
|
|
from api.utils.api_utils import get_json_result
|
|
|
|
|
|
@manager.route('/save', methods=['POST'])
@token_required
def save(tenant_id):
    """Create or update an assistant (dialog) owned by ``tenant_id``.

    The JSON body uses SDK-facing field names (``knowledgebases``,
    ``llm.model_name``, ``prompt.*``, ``avatar``). They are translated to the
    dialog model's field names (``kb_ids``, ``llm_id``/``llm_setting``,
    ``prompt_config`` plus retrieval settings, ``icon``) before persisting.

    * Without ``id``: create a new assistant. ``name`` and a non-empty
      ``knowledgebases`` list are required. Missing prompt fields are filled
      from built-in defaults. The stored record is returned in SDK-facing form.
    * With ``id``: update the existing assistant. Only fields present in the
      body are changed. Returns ``data=True``.

    Args:
        tenant_id: Tenant resolved from the API token by ``token_required``.

    Returns:
        A JSON response from ``get_json_result``, or ``get_data_error_result``
        on validation failure.
    """
    req = request.json

    # Knowledge bases: validate ownership and translate to ``kb_ids``.
    if req.get("knowledgebases") == []:
        return get_data_error_result(retmsg="knowledgebases can not be empty list")
    kb_list = []
    if req.get("knowledgebases"):
        for kb in req.get("knowledgebases"):
            kb_id = kb.get("id")
            if not kb_id:
                return get_data_error_result(retmsg="knowledgebase needs id")
            if not KnowledgebaseService.query(id=kb_id, tenant_id=tenant_id):
                return get_data_error_result(retmsg="you do not own the knowledgebase")
            kb_list.append(kb_id)
        # Only touch kb_ids when the caller actually sent knowledge bases.
        # Otherwise an update that omits them would overwrite the stored
        # list with [].
        req["kb_ids"] = kb_list

    # LLM: ``llm.model_name`` -> ``llm_id``; the remainder -> ``llm_setting``.
    # Always pop ``llm`` so an empty value never reaches the model layer.
    llm = req.pop("llm", None)
    if llm:
        if "model_name" in llm:
            req["llm_id"] = llm.pop("model_name")
        req["llm_setting"] = llm

    e, tenant = TenantService.get_by_id(tenant_id)
    if not e:
        return get_data_error_result(retmsg="Tenant not found!")

    # Prompt: map SDK names back to model names. The retrieval settings in
    # key_list are stored on the dialog itself, not inside prompt_config.
    key_mapping = {"parameters": "variables",
                   "prologue": "opener",
                   "quote": "show_quote",
                   "system": "prompt",
                   "rerank_id": "rerank_model",
                   "vector_similarity_weight": "keywords_similarity_weight"}
    key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
    prompt = req.pop("prompt", None)
    if prompt:
        for new_key, old_key in key_mapping.items():
            if old_key in prompt:
                prompt[new_key] = prompt.pop(old_key)
        for key in key_list:
            if key in prompt:
                req[key] = prompt.pop(key)
        req["prompt_config"] = prompt

    # Ownership is derived from the token; a client-supplied tenant_id must
    # never be persisted (on update it would reassign the assistant).
    if req.get("tenant_id"):
        return get_data_error_result(retmsg="tenant_id must not be provided.")

    if "id" not in req:
        # ---- create ----
        if not kb_list:
            return get_data_error_result(retmsg="knowledgebases are required!")
        req["id"] = get_uuid()
        req["description"] = req.get("description", "A helpful Assistant")
        req["icon"] = req.get("avatar", "")
        req["top_n"] = req.get("top_n", 6)
        req["top_k"] = req.get("top_k", 1024)
        req["rerank_id"] = req.get("rerank_id", "")
        if req.get("llm_id"):
            # The model must be configured for *this* tenant.
            if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"]):
                return get_data_error_result(retmsg="the model_name does not exist.")
        else:
            req["llm_id"] = tenant.llm_id
        if not req.get("name"):
            return get_data_error_result(retmsg="name is required.")
        if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
            return get_data_error_result(retmsg="Duplicated assistant name in creating dataset.")
        req["tenant_id"] = tenant_id

        default_prompt = {
            "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
以下是知识库:
{knowledge}
以上是知识库。""",
            "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
            "parameters": [
                {"key": "knowledge", "optional": False}
            ],
            "empty_response": "Sorry! 知识库中未找到相关内容!"
        }
        key_list_2 = ["system", "prologue", "parameters", "empty_response"]
        if "prompt_config" not in req:
            req['prompt_config'] = {}
        for key in key_list_2:
            if not req['prompt_config'].get(key):
                req['prompt_config'][key] = default_prompt[key]
        # Every non-optional parameter must appear as "{key}" in the system prompt.
        for p in req['prompt_config']["parameters"]:
            if p["optional"]:
                continue
            if req['prompt_config']["system"].find("{%s}" % p["key"]) < 0:
                return get_data_error_result(
                    retmsg="Parameter '{}' is not used".format(p["key"]))

        if not DialogService.save(**req):
            return get_data_error_result(retmsg="Fail to new an assistant!")

        # Re-read the stored row and convert it to the SDK-facing shape.
        e, res = DialogService.get_by_id(req["id"])
        if not e:
            return get_data_error_result(retmsg="Fail to new an assistant!")
        res = res.to_json()
        res["prompt"] = {key_mapping.get(k, k): v for k, v in res.pop("prompt_config").items()}
        res["prompt"].update({"similarity_threshold": res["similarity_threshold"],
                              "keywords_similarity_weight": res["vector_similarity_weight"],
                              "top_n": res["top_n"],
                              "rerank_model": res['rerank_id']})
        for key in key_list:
            del res[key]
        res["llm"] = res.pop("llm_setting")
        res["llm"]["model_name"] = res.pop("llm_id")
        del res["kb_ids"]
        res["knowledgebases"] = req["knowledgebases"]
        res["avatar"] = res.pop("icon")
        return get_json_result(data=res)

    # ---- update ----
    if not req["id"]:
        return get_data_error_result(retmsg="id can not be empty")
    if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value):
        return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR)
    e, res = DialogService.get_by_id(req["id"])
    if not e:
        return get_data_error_result(retmsg="Assistant not found!")
    res = res.to_json()
    # Drop any falsy tenant_id the client sent so it is not written back.
    req.pop("tenant_id", None)

    if "llm_id" in req:
        if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"]):
            return get_data_error_result(retmsg="the model_name does not exist.")
    if "name" in req:
        if not req.get("name"):
            return get_data_error_result(retmsg="name can not be empty.")
        if req["name"].lower() != res["name"].lower() \
                and len(DialogService.query(name=req["name"], tenant_id=tenant_id,
                                            status=StatusEnum.VALID.value)) > 0:
            return get_data_error_result(retmsg="Duplicated assistant name in updating dataset.")
    # prompt_config and llm_setting are merged into the stored values rather
    # than replaced wholesale.
    if "prompt_config" in req:
        res["prompt_config"].update(req["prompt_config"])
        for p in res["prompt_config"]["parameters"]:
            if p["optional"]:
                continue
            if res["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
                return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
    if "llm_setting" in req:
        res["llm_setting"].update(req["llm_setting"])
    req["prompt_config"] = res["prompt_config"]
    req["llm_setting"] = res["llm_setting"]

    if "avatar" in req:
        req["icon"] = req.pop("avatar")
    assistant_id = req.pop("id")
    # Already translated to kb_ids above; not a dialog field.
    req.pop("knowledgebases", None)
    if not DialogService.update_by_id(assistant_id, req):
        return get_data_error_result(retmsg="Assistant not found!")
    return get_json_result(data=True)
|
|
|
|
|
|
@manager.route('/delete', methods=['DELETE'])
@token_required
def delete(tenant_id):
    """Mark an assistant owned by ``tenant_id`` as invalid.

    The assistant id comes from the ``id`` query parameter. The row is not
    removed; its ``status`` is set to ``StatusEnum.INVALID``.
    """
    args = request.args
    if "id" not in args:
        return get_data_error_result(retmsg="id is required")

    assistant_id = args['id']
    owned = DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value)
    if not owned:
        return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR)

    DialogService.update_by_id(assistant_id, {"status": StatusEnum.INVALID.value})
    return get_json_result(data=True)
|
|
|
|
|
|
@manager.route('/get', methods=['GET'])
@token_required
def get(tenant_id):
    """Return one valid assistant owned by ``tenant_id`` in SDK-facing form.

    The assistant is looked up by the ``id`` query parameter, or by ``name``
    when no ``id`` is given. If both are given, ``name`` must match the
    assistant found by ``id``. Linked knowledge bases that no longer exist
    are omitted from the ``knowledgebases`` list.
    """
    args = request.args
    if "id" in args:
        matches = DialogService.query(tenant_id=tenant_id, id=args["id"], status=StatusEnum.VALID.value)
        if not matches:
            return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
        if "name" in args and matches[0].name != args["name"]:
            return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR)
    elif "name" in args:
        matches = DialogService.query(name=args["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not matches:
            return get_json_result(data=False, retmsg='You do not own the assistant.',
                                   retcode=RetCode.OPERATING_ERROR)
    else:
        return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.")
    res = matches[0].to_json()

    # Model field name -> SDK-facing name.
    key_mapping = {"parameters": "variables",
                   "prologue": "opener",
                   "quote": "show_quote",
                   "system": "prompt",
                   "rerank_id": "rerank_model",
                   "vector_similarity_weight": "keywords_similarity_weight"}
    # Retrieval settings stored on the dialog; exposed under "prompt" instead.
    key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]

    res["prompt"] = {key_mapping.get(k, k): v for k, v in res.pop("prompt_config").items()}
    res["prompt"].update({"similarity_threshold": res["similarity_threshold"],
                          "keywords_similarity_weight": res["vector_similarity_weight"],
                          "top_n": res["top_n"],
                          "rerank_model": res['rerank_id']})
    for key in key_list:
        del res[key]
    res["llm"] = res.pop("llm_setting")
    res["llm"]["model_name"] = res.pop("llm_id")

    kb_list = []
    for kb_id in res.pop("kb_ids"):
        kbs = KnowledgebaseService.query(id=kb_id)
        # A linked knowledge base may have been removed since; skip it rather
        # than failing the whole request with IndexError.
        if kbs:
            kb_list.append(kbs[0].to_json())
    res["knowledgebases"] = kb_list
    res["avatar"] = res.pop("icon")
    return get_json_result(data=res)
|
|
|
|
|
|
@manager.route('/list', methods=['GET'])
@token_required
def list_assistants(tenant_id):
    """List all valid assistants of ``tenant_id`` in SDK-facing form.

    Results are ordered by ``create_time`` with ``reverse=True`` (presumably
    newest first). Linked knowledge bases that no longer exist are omitted.
    """
    dialogs = DialogService.query(
        tenant_id=tenant_id,
        status=StatusEnum.VALID.value,
        reverse=True,
        order_by=DialogService.model.create_time)

    # Model field name -> SDK-facing name.
    key_mapping = {"parameters": "variables",
                   "prologue": "opener",
                   "quote": "show_quote",
                   "system": "prompt",
                   "rerank_id": "rerank_model",
                   "vector_similarity_weight": "keywords_similarity_weight"}
    # Retrieval settings stored on the dialog; exposed under "prompt" instead.
    key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]

    list_assts = []
    for res in (d.to_dict() for d in dialogs):
        # Build a fresh prompt dict per assistant. A single dict shared across
        # iterations made every entry alias the last assistant's prompt and
        # carry stale keys from earlier ones.
        res["prompt"] = {key_mapping.get(k, k): v for k, v in res.pop("prompt_config").items()}
        res["prompt"].update({"similarity_threshold": res["similarity_threshold"],
                              "keywords_similarity_weight": res["vector_similarity_weight"],
                              "top_n": res["top_n"],
                              "rerank_model": res['rerank_id']})
        for key in key_list:
            del res[key]
        res["llm"] = res.pop("llm_setting")
        res["llm"]["model_name"] = res.pop("llm_id")

        kb_list = []
        for kb_id in res.pop("kb_ids"):
            kbs = KnowledgebaseService.query(id=kb_id)
            # Skip knowledge bases removed after being linked instead of
            # raising IndexError for the whole listing.
            if kbs:
                kb_list.append(kbs[0].to_json())
        res["knowledgebases"] = kb_list
        res["avatar"] = res.pop("icon")
        list_assts.append(res)
    return get_json_result(data=list_assts)
|
|
|