File size: 2,967 Bytes
44731b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebde808
 
44731b3
6c8312a
44731b3
 
6101699
44731b3
ebde808
 
fe9b6b3
ebde808
 
 
 
 
 
 
 
 
 
 
 
 
 
6101699
ebde808
 
6101699
ebde808
 
 
6101699
ebde808
 
 
 
 
 
 
 
 
6c8312a
 
ebde808
 
 
4fd5400
ebde808
badcb66
ebde808
 
7fcd041
ebde808
 
 
 
 
 
196c662
6101699
ebde808
6101699
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from flask import request, jsonify

from api.db import LLMType, ParserType
from api.db.services.dialog_service import label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from api.utils.api_utils import validate_request, build_error_result, apikey_required


@manager.route('/dify/retrieval', methods=['POST'])  # noqa: F821
@apikey_required
@validate_request("knowledge_id", "query")
def retrieval(tenant_id):
    """Search a single knowledge base and return the matching chunks.

    The JSON request body is read for:
      * ``knowledge_id`` (required): id of the knowledge base to search.
      * ``query`` (required): the question text.
      * ``retrieval_setting`` (optional): object that may carry
        ``score_threshold`` (parsed with ``float``, default ``0.0``) and
        ``top_k`` (parsed with ``int``, default ``1024``).

    Args:
        tenant_id: Tenant the request is made for; compared against the
            knowledge base's ``tenant_id``.
            NOTE(review): presumably injected by ``apikey_required`` - confirm.

    Returns:
        On success, a JSON response ``{"records": [...]}`` where each record
        has ``content``, ``score``, ``title`` and an empty ``metadata`` dict.
        Otherwise the value of ``build_error_result`` with
        ``RetCode.NOT_FOUND`` (unknown knowledge base, knowledge base owned by
        another tenant, or an error whose text contains ``not_found``) or
        ``RetCode.SERVER_ERROR`` (any other failure, including malformed
        ``retrieval_setting`` values).
    """
    req = request.json
    question = req["query"]
    kb_id = req["knowledge_id"]
    # An explicit JSON null must behave like an absent key, otherwise the
    # `.get` calls below would fail on None.
    retrieval_setting = req.get("retrieval_setting") or {}

    try:
        # Parsed inside the try so malformed client values are reported
        # through the same error result as every other failure.
        similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
        top = int(retrieval_setting.get("top_k", 1024))

        found, kb = KnowledgebaseService.get_by_id(kb_id)
        if not found:
            return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)

        # A knowledge base owned by another tenant is answered exactly like a
        # missing one.
        if kb.tenant_id != tenant_id:
            return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)

        embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

        # Knowledge bases parsed as ParserType.KG use the dedicated retriever.
        retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
        ranks = retr.retrieval(
            question,
            embd_mdl,
            kb.tenant_id,
            [kb_id],
            page=1,
            page_size=top,
            similarity_threshold=similarity_threshold,
            vector_similarity_weight=0.3,
            top=top,
            rank_feature=label_question(question, [kb])
        )
        records = []
        for chunk in ranks["chunks"]:
            # The "vector" entry is not part of the response payload.
            chunk.pop("vector", None)
            records.append({
                "content": chunk["content_with_weight"],
                "score": chunk["similarity"],
                "title": chunk["docnm_kwd"],
                "metadata": {}
            })

        return jsonify({"records": records})
    except Exception as exc:
        # Substring test rather than `find(...) > 0`, which missed messages
        # that begin with "not_found" (index 0).
        if "not_found" in str(exc):
            return build_error_result(
                message='No chunk found! Check the chunk status please!',
                code=settings.RetCode.NOT_FOUND
            )
        return build_error_result(message=str(exc), code=settings.RetCode.SERVER_ERROR)