|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from flask import request |
|
from flask_login import login_required, current_user |
|
from api.db.services.dialog_service import DialogService |
|
from api.db import StatusEnum |
|
from api.db.services.knowledgebase_service import KnowledgebaseService |
|
from api.db.services.user_service import TenantService, UserTenantService |
|
from api.settings import RetCode |
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request |
|
from api.utils import get_uuid |
|
from api.utils.api_utils import get_json_result |
|
|
|
|
|
@manager.route('/set', methods=['POST'])
@login_required
def set_dialog():
    """Create a new dialog or update an existing one.

    The JSON body is read from ``request.json``. Without a ``dialog_id`` a new
    dialog is created for the current user's tenant (``kb_ids`` is then
    mandatory). With a ``dialog_id`` the remaining request fields are written
    onto that dialog via ``DialogService.update_by_id``. The dialog must belong
    to a tenant the current user is a member of (same rule as ``/rm``).

    Before any database access ``prompt_config`` is validated:
    - an empty or missing ``system`` prompt is replaced by the default prompt;
    - every non-optional entry of ``parameters`` must appear as ``{key}``
      inside the system prompt.

    Returns:
        A JSON response from ``get_json_result`` on success, or from
        ``get_data_error_result`` / ``server_error_response`` on failure.
    """
    req = request.json
    dialog_id = req.get("dialog_id")
    name = req.get("name", "New Dialog")
    description = req.get("description", "A helpful Dialog")
    icon = req.get("icon", "")
    top_n = req.get("top_n", 6)
    top_k = req.get("top_k", 1024)
    # A missing/null/empty rerank_id means "no rerank model". Normalise it to
    # "" both for the create path (local variable) and for the update path,
    # which persists ``req`` as-is.
    rerank_id = req.get("rerank_id") or ""
    if not rerank_id:
        req["rerank_id"] = ""
    similarity_threshold = req.get("similarity_threshold", 0.1)
    vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
    if vector_similarity_weight is None:
        # An explicit null falls back to the default. Write it back so the
        # update path does not persist None.
        vector_similarity_weight = 0.3
        req["vector_similarity_weight"] = vector_similarity_weight
    llm_setting = req.get("llm_setting", {})
    default_prompt = {
        "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。

以下是知识库:

{knowledge}

以上是知识库。""",
        "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
        "parameters": [
            {"key": "knowledge", "optional": False}
        ],
        "empty_response": "Sorry! 知识库中未找到相关内容!"
    }
    prompt_config = req.get("prompt_config", default_prompt)

    # Fill in an empty system prompt. When the client supplied prompt_config
    # this mutates req["prompt_config"] in place, so the update path persists it.
    if not prompt_config.get("system"):
        prompt_config["system"] = default_prompt["system"]

    # Every required parameter must be referenced as "{key}" in the prompt.
    for param in prompt_config.get("parameters") or []:
        if param.get("optional"):
            continue
        if "{%s}" % param["key"] not in prompt_config["system"]:
            return get_data_error_result(
                retmsg="Parameter '{}' is not used".format(param["key"]))

    try:
        found, tenant = TenantService.get_by_id(current_user.id)
        if not found:
            return get_data_error_result(retmsg="Tenant not found!")
        llm_id = req.get("llm_id", tenant.llm_id)

        if not dialog_id:
            # Create path.
            if not req.get("kb_ids"):
                return get_data_error_result(
                    retmsg="Fail! Please select knowledgebase!")
            dia = {
                "id": get_uuid(),
                "tenant_id": current_user.id,
                "name": name,
                "kb_ids": req["kb_ids"],
                "description": description,
                "llm_id": llm_id,
                "llm_setting": llm_setting,
                "prompt_config": prompt_config,
                "top_n": top_n,
                "top_k": top_k,
                "rerank_id": rerank_id,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "icon": icon
            }
            if not DialogService.save(**dia):
                return get_data_error_result(retmsg="Fail to new a dialog!")
            found, dia = DialogService.get_by_id(dia["id"])
            if not found:
                return get_data_error_result(retmsg="Fail to new a dialog!")
            return get_json_result(data=dia.to_json())

        # Update path: refuse to modify dialogs outside the user's tenants.
        if not _user_can_access_dialog(dialog_id):
            return get_json_result(
                data=False, retmsg='Only owner of dialog authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)
        req.pop("dialog_id")
        # kb_names is a derived, read-only field echoed back by the client.
        req.pop("kb_names", None)
        # NOTE(review): all other request fields are written verbatim; consider
        # whitelisting updatable columns (e.g. to block changing tenant_id).
        if not DialogService.update_by_id(dialog_id, req):
            return get_data_error_result(retmsg="Dialog not found!")
        found, dia = DialogService.get_by_id(dialog_id)
        if not found:
            return get_data_error_result(retmsg="Fail to update a dialog!")
        dia = dia.to_dict()
        dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
        return get_json_result(data=dia)
    except Exception as e:
        return server_error_response(e)


def _user_can_access_dialog(dialog_id):
    """Return True if ``dialog_id`` belongs to any tenant the current user is in."""
    tenants = UserTenantService.query(user_id=current_user.id)
    return any(DialogService.query(tenant_id=t.tenant_id, id=dialog_id) for t in tenants)
|
|
|
|
|
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Return a single dialog identified by the ``dialog_id`` query argument.

    The stored ``kb_ids`` are passed through ``get_kb_names``. The filtered ids
    and their matching ``kb_names`` are placed on the returned payload.
    """
    requested_id = request.args["dialog_id"]
    try:
        found, dialog = DialogService.get_by_id(requested_id)
        if not found:
            return get_data_error_result(retmsg="Dialog not found!")
        payload = dialog.to_dict()
        kb_ids, kb_names = get_kb_names(payload["kb_ids"])
        payload["kb_ids"] = kb_ids
        payload["kb_names"] = kb_names
        return get_json_result(data=payload)
    except Exception as e:
        return server_error_response(e)
|
|
|
|
|
def get_kb_names(kb_ids):
    """Resolve knowledge-base ids to ``(valid_ids, names)``.

    Ids that ``KnowledgebaseService.get_by_id`` cannot find are dropped. Ids
    whose status is not ``StatusEnum.VALID.value`` are also dropped. The two
    returned lists are index-aligned and keep the input order.
    """
    valid_ids, valid_names = [], []
    for kb_id in kb_ids:
        found, kb = KnowledgebaseService.get_by_id(kb_id)
        if found and kb.status == StatusEnum.VALID.value:
            valid_ids.append(kb_id)
            valid_names.append(kb.name)
    return valid_ids, valid_names
|
|
|
|
|
@manager.route('/list', methods=['GET'])
@login_required
def list_dialogs():
    """List the current user's valid dialogs, newest ``create_time`` first.

    Each dialog's ``kb_ids`` are filtered through ``get_kb_names``, and the
    matching ``kb_names`` are attached.
    """
    try:
        rows = DialogService.query(
            tenant_id=current_user.id,
            status=StatusEnum.VALID.value,
            reverse=True,
            order_by=DialogService.model.create_time)
        dialogs = [row.to_dict() for row in rows]
        for dialog in dialogs:
            kb_ids, kb_names = get_kb_names(dialog["kb_ids"])
            dialog["kb_ids"] = kb_ids
            dialog["kb_names"] = kb_names
        return get_json_result(data=dialogs)
    except Exception as e:
        return server_error_response(e)
|
|
|
|
|
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("dialog_ids")
def rm():
    """Soft-delete dialogs by setting their status to ``StatusEnum.INVALID``.

    Expects a JSON body with a ``dialog_ids`` list. Every id must belong to one
    of the tenants the current user is a member of. If any id does not, the
    request is rejected with ``RetCode.OPERATING_ERROR`` and nothing is updated.
    """
    req = request.json
    try:
        # Looked up inside the try so a DB failure is reported through
        # server_error_response like every other failure in this endpoint.
        tenant_ids = [t.tenant_id for t in UserTenantService.query(user_id=current_user.id)]
        dialog_list = []
        for dialog_id in req["dialog_ids"]:
            if not any(DialogService.query(tenant_id=tenant_id, id=dialog_id)
                       for tenant_id in tenant_ids):
                return get_json_result(
                    data=False, retmsg='Only owner of dialog authorized for this operation.',
                    retcode=RetCode.OPERATING_ERROR)
            dialog_list.append({"id": dialog_id, "status": StatusEnum.INVALID.value})
        DialogService.update_many_by_id(dialog_list)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
|
|