import json |
|
import logging |
|
from collections import defaultdict |
|
from copy import deepcopy |
|
import json_repair |
|
import pandas as pd |
|
|
|
from api.utils import get_uuid |
|
from graphrag.query_analyze_prompt import PROMPTS |
|
from graphrag.utils import get_entity_type2sampels, get_llm_cache, set_llm_cache, get_relation |
|
from rag.utils import num_tokens_from_string |
|
from rag.utils.doc_store_conn import OrderByExpr |
|
|
|
from rag.nlp.search import Dealer, index_name |
|
|
|
|
|
class KGSearch(Dealer): |
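    """Knowledge-graph retrieval built on top of the generic search Dealer.

    The flow: rewrite the question into answer-type keywords and entities,
    fetch matching entities, relations and community reports from the doc
    store, re-rank them by similarity and PageRank, and pack everything into
    a single pseudo-chunk that downstream prompting can consume.

    A minimal usage sketch, mirroring the __main__ block below (assumes an
    initialized doc store connection plus chat/embedding bundles):

        kg = KGSearch(settings.docStoreConn)
        chunk = kg.retrieval("Who founded the company?", tenant_id, [kb_id], embed_bdl, llm_bdl)
    """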
|
def _chat(self, llm_bdl, system, history, gen_conf): |
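        """Chat with the LLM through a response cache.

        Returns the cached answer when one exists for this prompt/history/config,
        otherwise calls the model, raises if the reply contains an **ERROR**
        marker, and stores the fresh answer in the cache.
        """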
|
response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf) |
|
if response: |
|
return response |
|
response = llm_bdl.chat(system, history, gen_conf) |
|
if response.find("**ERROR**") >= 0: |
|
raise Exception(response) |
|
set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf) |
|
return response |
|
|
|
def query_rewrite(self, llm, question, idxnms, kb_ids): |
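        """Rewrite the question into (answer_type_keywords, entities_from_query).

        Uses the MiniRAG keyword-extraction prompt with sampled entities per type
        as the TYPE_POOL. If the model wraps its JSON in extra text, fall back to
        parsing the first {...} block; extracted entities are capped at five.
        """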
|
ty2ents = get_entity_type2sampels(idxnms, kb_ids) |
|
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, |
|
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) |
|
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {"temperature": .5}) |
|
try: |
|
keywords_data = json_repair.loads(result) |
|
type_keywords = keywords_data.get("answer_type_keywords", []) |
|
entities_from_query = keywords_data.get("entities_from_query", [])[:5] |
|
return type_keywords, entities_from_query |
|
except json_repair.JSONDecodeError: |
|
try: |
|
result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip() |
|
result = '{' + result.split('{')[1].split('}')[0] + '}' |
|
keywords_data = json_repair.loads(result) |
|
type_keywords = keywords_data.get("answer_type_keywords", []) |
|
entities_from_query = keywords_data.get("entities_from_query", [])[:5] |
|
return type_keywords, entities_from_query |
|
|
|
except Exception as e: |
|
logging.exception(f"JSON parsing error: {result} -> {e}") |
|
raise e |
|
|
|
def _ent_info_from_(self, es_res, sim_thr=0.3): |
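        """Turn an entity search result into {entity: {sim, pagerank, n_hop_ents, description}}.

        Hits scoring below sim_thr are dropped; list-valued entity keywords are
        reduced to their first element.
        """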
|
res = {} |
|
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "entity_kwd", "rank_flt", |
|
"n_hop_with_weight"]) |
|
for _, ent in es_res.items(): |
|
if float(ent.get("_score", 0)) < sim_thr: |
|
continue |
|
if isinstance(ent["entity_kwd"], list): |
|
ent["entity_kwd"] = ent["entity_kwd"][0] |
|
res[ent["entity_kwd"]] = { |
|
"sim": float(ent.get("_score", 0)), |
|
"pagerank": float(ent.get("rank_flt", 0)), |
|
"n_hop_ents": json.loads(ent.get("n_hop_with_weight", "[]")), |
|
"description": ent.get("content_with_weight", "{}") |
|
} |
|
return res |
|
|
|
def _relation_info_from_(self, es_res, sim_thr=0.3): |
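        """Turn a relation search result into {(from, to): {sim, pagerank, description}}.

        Keys are sorted entity pairs so the same edge always maps to the same
        key; hits scoring below sim_thr are dropped.
        """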
|
res = {} |
|
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", |
|
"weight_int"]) |
|
for _, ent in es_res.items(): |
|
if float(ent["_score"]) < sim_thr: |
|
continue |
|
f, t = sorted([ent["from_entity_kwd"], ent["to_entity_kwd"]]) |
|
if isinstance(f, list): |
|
f = f[0] |
|
if isinstance(t, list): |
|
t = t[0] |
|
res[(f, t)] = { |
|
"sim": float(ent["_score"]), |
|
"pagerank": float(ent.get("weight_int", 0)), |
|
"description": ent["content_with_weight"] |
|
} |
|
return res |
|
|
|
def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56): |
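        """Embed the joined keywords, vector-search entity chunks, and return parsed entity info above sim_thr."""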
|
if not keywords: |
|
return {} |
|
filters = deepcopy(filters) |
|
filters["knowledge_graph_kwd"] = "entity" |
|
matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr) |
|
es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt"], [], filters, [matchDense], |
|
OrderByExpr(), 0, N, |
|
idxnms, kb_ids) |
|
return self._ent_info_from_(es_res, sim_thr) |
|
|
|
def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56): |
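        """Embed the raw question text, vector-search relation chunks, and return parsed relation info above sim_thr."""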
|
if not txt: |
|
return {} |
|
filters = deepcopy(filters) |
|
filters["knowledge_graph_kwd"] = "relation" |
|
matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr) |
|
es_res = self.dataStore.search( |
|
["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"], |
|
[], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids) |
|
return self._relation_info_from_(es_res, sim_thr) |
|
|
|
def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56): |
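        """Fetch up to N entities of the given types, ordered by PageRank (rank_flt) descending, with no similarity cutoff."""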
|
if not types: |
|
return {} |
|
filters = deepcopy(filters) |
|
filters["knowledge_graph_kwd"] = "entity" |
|
filters["entity_type_kwd"] = types |
|
ordr = OrderByExpr() |
|
ordr.desc("rank_flt") |
|
es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N, |
|
idxnms, kb_ids) |
|
return self._ent_info_from_(es_res, 0) |
|
|
|
def retrieval(self, question: str, |
|
tenant_ids: str | list[str], |
|
kb_ids: list[str], |
|
emb_mdl, |
|
llm, |
|
max_token: int = 8196, |
|
ent_topn: int = 6, |
|
rel_topn: int = 6, |
|
comm_topn: int = 1, |
|
ent_sim_threshold: float = 0.3, |
|
rel_sim_threshold: float = 0.3, |
|
): |
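        """Retrieve knowledge-graph context for a question and wrap it as one pseudo-chunk.

        Steps: rewrite the query into answer types and entities; retrieve
        entities by keyword embedding, entities by type, and relations by the
        question embedding; boost scores along n-hop paths and type matches;
        keep the top-n of each by similarity * PageRank; render them as CSV
        tables plus community reports, trimmed to the max_token budget.
        """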
|
qst = question |
|
filters = self.get_filters({"kb_ids": kb_ids}) |
|
if isinstance(tenant_ids, str): |
|
tenant_ids = tenant_ids.split(",") |
|
idxnms = [index_name(tid) for tid in tenant_ids] |
|
ty_kwds = [] |
|
ents = [] |
|
try: |
|
            ty_kwds, ents = self.query_rewrite(llm, qst, idxnms, kb_ids)
|
logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}") |
|
except Exception as e: |
|
logging.exception(e) |
|
            ents = [qst]
|
|
|
ents_from_query = self.get_relevant_ents_by_keywords(ents, filters, idxnms, kb_ids, emb_mdl, ent_sim_threshold) |
|
ents_from_types = self.get_relevant_ents_by_types(ty_kwds, filters, idxnms, kb_ids, 10000) |
|
rels_from_txt = self.get_relevant_relations_by_txt(qst, filters, idxnms, kb_ids, emb_mdl, rel_sim_threshold) |
|
nhop_pathes = defaultdict(dict) |
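        # Walk the stored n-hop neighbourhoods of the retrieved entities: every edge
        # (f, t) at hop i inherits sim / (2 + i) from its source entity, so nearer
        # hops contribute more, and the stored edge weight becomes its pagerank proxy.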
|
for _, ent in ents_from_query.items(): |
|
nhops = ent.get("n_hop_ents", []) |
|
for nbr in nhops: |
|
path = nbr["path"] |
|
wts = nbr["weights"] |
|
for i in range(len(path) - 1): |
|
f, t = path[i], path[i + 1] |
|
if (f, t) in nhop_pathes: |
|
nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i) |
|
else: |
|
nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i) |
|
nhop_pathes[(f, t)]["pagerank"] = wts[i] |
|
|
|
logging.info("Retrieved entities: {}".format(list(ents_from_query.keys()))) |
|
logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys()))) |
|
logging.info("Retrieved entities from types({}): {}".format(ty_kwds, list(ents_from_types.keys()))) |
|
logging.info("Retrieved N-hops: {}".format(list(nhop_pathes.keys()))) |
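        # Score boosting: entities that also match a predicted answer type count double;
        # relations are scaled by accumulated n-hop path similarity and by how many of
        # their endpoints are type matches.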
|
|
|
|
|
for ent in ents_from_types.keys(): |
|
if ent not in ents_from_query: |
|
continue |
|
ents_from_query[ent]["sim"] *= 2 |
|
|
|
for (f, t) in rels_from_txt.keys(): |
|
pair = tuple(sorted([f, t])) |
|
s = 0 |
|
if pair in nhop_pathes: |
|
s += nhop_pathes[pair]["sim"] |
|
del nhop_pathes[pair] |
|
if f in ents_from_types: |
|
s += 1 |
|
if t in ents_from_types: |
|
s += 1 |
|
rels_from_txt[(f, t)]["sim"] *= s + 1 |
|
|
|
|
|
for (f, t) in nhop_pathes.keys(): |
|
s = 0 |
|
if f in ents_from_types: |
|
s += 1 |
|
if t in ents_from_types: |
|
s += 1 |
|
rels_from_txt[(f, t)] = { |
|
"sim": nhop_pathes[(f, t)]["sim"] * (s + 1), |
|
"pagerank": nhop_pathes[(f, t)]["pagerank"] |
|
} |
|
|
|
        # Keep only the top-n entities and relations, ranked by similarity weighted by PageRank.
        ents_from_query = sorted(ents_from_query.items(),
                                 key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[:ent_topn]
        rels_from_txt = sorted(rels_from_txt.items(),
                               key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[:rel_topn]
|
|
|
ents = [] |
|
relas = [] |
|
for n, ent in ents_from_query: |
|
ents.append({ |
|
"Entity": n, |
|
"Score": "%.2f" % (ent["sim"] * ent["pagerank"]), |
|
"Description": json.loads(ent["description"]).get("description", "") |
|
}) |
|
max_token -= num_tokens_from_string(str(ents[-1])) |
|
if max_token <= 0: |
|
ents = ents[:-1] |
|
break |
|
|
|
for (f, t), rel in rels_from_txt: |
|
if not rel.get("description"): |
|
for tid in tenant_ids: |
|
rela = get_relation(tid, kb_ids, f, t) |
|
if rela: |
|
break |
|
else: |
|
continue |
|
rel["description"] = rela["description"] |
|
            relas.append({
                "From Entity": f,
                "To Entity": t,
                "Score": "%.2f" % (rel["sim"] * rel["pagerank"]),
                "Description": rel["description"]
            })
|
max_token -= num_tokens_from_string(str(relas[-1])) |
|
if max_token <= 0: |
|
relas = relas[:-1] |
|
break |
|
|
|
if ents: |
|
ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv()) |
|
else: |
|
ents = "" |
|
if relas: |
|
relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv()) |
|
else: |
|
relas = "" |
|
|
|
return { |
|
"chunk_id": get_uuid(), |
|
"content_ltks": "", |
|
"content_with_weight": ents + relas + self._community_retrival_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, |
|
comm_topn, max_token), |
|
"doc_id": "", |
|
"docnm_kwd": "Related content in Knowledge Graph", |
|
"kb_id": kb_ids, |
|
"important_kwd": [], |
|
"image_id": "", |
|
"similarity": 1., |
|
"vector_similarity": 1., |
|
"term_similarity": 0, |
|
"vector": [], |
|
"positions": [], |
|
} |
|
|
|
def _community_retrival_(self, entities, condition, kb_ids, idxnms, topn, max_token): |
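        """Fetch the top community reports that mention the given entities and format them as text.

        Returns an empty string when no report is found; max_token is decremented
        per report so the overall chunk stays within budget.
        """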
|
|
|
fields = ["docnm_kwd", "content_with_weight"] |
|
odr = OrderByExpr() |
|
odr.desc("weight_flt") |
|
fltr = deepcopy(condition) |
|
fltr["knowledge_graph_kwd"] = "community_report" |
|
fltr["entities_kwd"] = entities |
|
        comm_res = self.dataStore.search(fields, [], fltr, [],
                                         odr, 0, topn, idxnms, kb_ids)
|
comm_res_fields = self.dataStore.getFields(comm_res, fields) |
|
txts = [] |
|
for ii, (_, row) in enumerate(comm_res_fields.items()): |
|
obj = json.loads(row["content_with_weight"]) |
|
txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format( |
|
ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"])) |
|
            max_token -= num_tokens_from_string(str(txts[-1]))
            if max_token <= 0:
                txts = txts[:-1]
                break
|
|
|
if not txts: |
|
return "" |
|
return "\n---- Community Report ----\n" + "\n".join(txts) |
|
|
|
|
|
if __name__ == "__main__": |
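    # Manual smoke test: retrieve knowledge-graph context for one question against one knowledge base.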
|
from api import settings |
|
import argparse |
|
from api.db import LLMType |
|
from api.db.services.knowledgebase_service import KnowledgebaseService |
|
from api.db.services.llm_service import LLMBundle |
|
from api.db.services.user_service import TenantService |
|
from rag.nlp import search |
|
|
|
settings.init_settings() |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True) |
|
parser.add_argument('-d', '--kb_id', default=False, help="Knowledge base ID", action='store', required=True) |
|
parser.add_argument('-q', '--question', default=False, help="Question", action='store', required=True) |
|
args = parser.parse_args() |
|
|
|
kb_id = args.kb_id |
|
_, tenant = TenantService.get_by_id(args.tenant_id) |
|
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id) |
|
_, kb = KnowledgebaseService.get_by_id(kb_id) |
|
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) |
|
|
|
kg = KGSearch(settings.docStoreConn) |
|
    print(kg.retrieval(args.question, kb.tenant_id, [kb_id], embed_bdl, llm_bdl))
|
|