ragflow / graphrag /search.py
Kevin Hu
Rebuild graph when it's out of time. (#4607)
62a5517
raw
history blame
14.3 kB
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
from collections import defaultdict
from copy import deepcopy
import json_repair
import pandas as pd
from api.utils import get_uuid
from graphrag.query_analyze_prompt import PROMPTS
from graphrag.utils import get_entity_type2sampels, get_llm_cache, set_llm_cache, get_relation
from rag.utils import num_tokens_from_string
from rag.utils.doc_store_conn import OrderByExpr
from rag.nlp.search import Dealer, index_name
class KGSearch(Dealer):
def _chat(self, llm_bdl, system, history, gen_conf):
response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
if response:
return response
response = llm_bdl.chat(system, history, gen_conf)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
return response
def query_rewrite(self, llm, question, idxnms, kb_ids):
ty2ents = get_entity_type2sampels(idxnms, kb_ids)
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {"temperature": .5})
try:
keywords_data = json_repair.loads(result)
type_keywords = keywords_data.get("answer_type_keywords", [])
entities_from_query = keywords_data.get("entities_from_query", [])[:5]
return type_keywords, entities_from_query
except json_repair.JSONDecodeError:
try:
result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
keywords_data = json_repair.loads(result)
type_keywords = keywords_data.get("answer_type_keywords", [])
entities_from_query = keywords_data.get("entities_from_query", [])[:5]
return type_keywords, entities_from_query
# Handle parsing error
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")
raise e
def _ent_info_from_(self, es_res, sim_thr=0.3):
res = {}
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "entity_kwd", "rank_flt",
"n_hop_with_weight"])
for _, ent in es_res.items():
if float(ent.get("_score", 0)) < sim_thr:
continue
if isinstance(ent["entity_kwd"], list):
ent["entity_kwd"] = ent["entity_kwd"][0]
res[ent["entity_kwd"]] = {
"sim": float(ent.get("_score", 0)),
"pagerank": float(ent.get("rank_flt", 0)),
"n_hop_ents": json.loads(ent.get("n_hop_with_weight", "[]")),
"description": ent.get("content_with_weight", "{}")
}
return res
def _relation_info_from_(self, es_res, sim_thr=0.3):
res = {}
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
"weight_int"])
for _, ent in es_res.items():
if float(ent["_score"]) < sim_thr:
continue
f, t = sorted([ent["from_entity_kwd"], ent["to_entity_kwd"]])
if isinstance(f, list):
f = f[0]
if isinstance(t, list):
t = t[0]
res[(f, t)] = {
"sim": float(ent["_score"]),
"pagerank": float(ent.get("weight_int", 0)),
"description": ent["content_with_weight"]
}
return res
def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
if not keywords:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "entity"
matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr)
es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt"], [], filters, [matchDense],
OrderByExpr(), 0, N,
idxnms, kb_ids)
return self._ent_info_from_(es_res, sim_thr)
def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
if not txt:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "relation"
matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr)
es_res = self.dataStore.search(
["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"],
[], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids)
return self._relation_info_from_(es_res, sim_thr)
def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56):
if not types:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "entity"
filters["entity_type_kwd"] = types
ordr = OrderByExpr()
ordr.desc("rank_flt")
es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N,
idxnms, kb_ids)
return self._ent_info_from_(es_res, 0)
def retrieval(self, question: str,
tenant_ids: str | list[str],
kb_ids: list[str],
emb_mdl,
llm,
max_token: int = 8196,
ent_topn: int = 6,
rel_topn: int = 6,
comm_topn: int = 1,
ent_sim_threshold: float = 0.3,
rel_sim_threshold: float = 0.3,
):
qst = question
filters = self.get_filters({"kb_ids": kb_ids})
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
idxnms = [index_name(tid) for tid in tenant_ids]
ty_kwds = []
ents = []
try:
ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids)
logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
except Exception as e:
logging.exception(e)
ents = [qst]
pass
ents_from_query = self.get_relevant_ents_by_keywords(ents, filters, idxnms, kb_ids, emb_mdl, ent_sim_threshold)
ents_from_types = self.get_relevant_ents_by_types(ty_kwds, filters, idxnms, kb_ids, 10000)
rels_from_txt = self.get_relevant_relations_by_txt(qst, filters, idxnms, kb_ids, emb_mdl, rel_sim_threshold)
nhop_pathes = defaultdict(dict)
for _, ent in ents_from_query.items():
nhops = ent.get("n_hop_ents", [])
for nbr in nhops:
path = nbr["path"]
wts = nbr["weights"]
for i in range(len(path) - 1):
f, t = path[i], path[i + 1]
if (f, t) in nhop_pathes:
nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i)
else:
nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i)
nhop_pathes[(f, t)]["pagerank"] = wts[i]
logging.info("Retrieved entities: {}".format(list(ents_from_query.keys())))
logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys())))
logging.info("Retrieved entities from types({}): {}".format(ty_kwds, list(ents_from_types.keys())))
logging.info("Retrieved N-hops: {}".format(list(nhop_pathes.keys())))
# P(E|Q) => P(E) * P(Q|E) => pagerank * sim
for ent in ents_from_types.keys():
if ent not in ents_from_query:
continue
ents_from_query[ent]["sim"] *= 2
for (f, t) in rels_from_txt.keys():
pair = tuple(sorted([f, t]))
s = 0
if pair in nhop_pathes:
s += nhop_pathes[pair]["sim"]
del nhop_pathes[pair]
if f in ents_from_types:
s += 1
if t in ents_from_types:
s += 1
rels_from_txt[(f, t)]["sim"] *= s + 1
# This is for the relations from n-hop but not by query search
for (f, t) in nhop_pathes.keys():
s = 0
if f in ents_from_types:
s += 1
if t in ents_from_types:
s += 1
rels_from_txt[(f, t)] = {
"sim": nhop_pathes[(f, t)]["sim"] * (s + 1),
"pagerank": nhop_pathes[(f, t)]["pagerank"]
}
ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
:ent_topn]
rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
:rel_topn]
ents = []
relas = []
for n, ent in ents_from_query:
ents.append({
"Entity": n,
"Score": "%.2f" % (ent["sim"] * ent["pagerank"]),
"Description": json.loads(ent["description"]).get("description", "")
})
max_token -= num_tokens_from_string(str(ents[-1]))
if max_token <= 0:
ents = ents[:-1]
break
for (f, t), rel in rels_from_txt:
if not rel.get("description"):
for tid in tenant_ids:
rela = get_relation(tid, kb_ids, f, t)
if rela:
break
else:
continue
rel["description"] = rela["description"]
relas.append({
"From Entity": f,
"To Entity": t,
"Score": "%.2f" % (rel["sim"] * rel["pagerank"]),
"Description": json.loads(ent["description"]).get("description", "")
})
max_token -= num_tokens_from_string(str(relas[-1]))
if max_token <= 0:
relas = relas[:-1]
break
if ents:
ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv())
else:
ents = ""
if relas:
relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv())
else:
relas = ""
return {
"chunk_id": get_uuid(),
"content_ltks": "",
"content_with_weight": ents + relas + self._community_retrival_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
comm_topn, max_token),
"doc_id": "",
"docnm_kwd": "Related content in Knowledge Graph",
"kb_id": kb_ids,
"important_kwd": [],
"image_id": "",
"similarity": 1.,
"vector_similarity": 1.,
"term_similarity": 0,
"vector": [],
"positions": [],
}
def _community_retrival_(self, entities, condition, kb_ids, idxnms, topn, max_token):
## Community retrieval
fields = ["docnm_kwd", "content_with_weight"]
odr = OrderByExpr()
odr.desc("weight_flt")
fltr = deepcopy(condition)
fltr["knowledge_graph_kwd"] = "community_report"
fltr["entities_kwd"] = entities
comm_res = self.dataStore.search(fields, [], fltr, [],
OrderByExpr(), 0, topn, idxnms, kb_ids)
comm_res_fields = self.dataStore.getFields(comm_res, fields)
txts = []
for ii, (_, row) in enumerate(comm_res_fields.items()):
obj = json.loads(row["content_with_weight"])
txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format(
ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"]))
max_token -= num_tokens_from_string(str(txts[-1]))
if not txts:
return ""
return "\n---- Community Report ----\n" + "\n".join(txts)
if __name__ == "__main__":
from api import settings
import argparse
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from rag.nlp import search
settings.init_settings()
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--kb_id', default=False, help="Knowledge base ID", action='store', required=True)
parser.add_argument('-q', '--question', default=False, help="Question", action='store', required=True)
args = parser.parse_args()
kb_id = args.kb_id
_, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
kg = KGSearch(settings.docStoreConn)
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))