|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
from copy import deepcopy |
|
|
|
import pandas as pd |
|
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr |
|
|
|
from rag.nlp.search import Dealer |
|
|
|
|
|
class KGSearch(Dealer): |
|
def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False): |
|
def merge_into_first(sres, title="") -> dict[str, str]: |
|
if not sres: |
|
return {} |
|
content_with_weight = "" |
|
df, texts = [],[] |
|
for d in sres.values(): |
|
try: |
|
df.append(json.loads(d["content_with_weight"])) |
|
except Exception: |
|
texts.append(d["content_with_weight"]) |
|
if df: |
|
content_with_weight = title + "\n" + pd.DataFrame(df).to_csv() |
|
else: |
|
content_with_weight = title + "\n" + "\n".join(texts) |
|
first_id = "" |
|
first_source = {} |
|
for k, v in sres.items(): |
|
first_id = id |
|
first_source = deepcopy(v) |
|
break |
|
first_source["content_with_weight"] = content_with_weight |
|
first_id = next(iter(sres)) |
|
return {first_id: first_source} |
|
|
|
qst = req.get("question", "") |
|
matchText, keywords = self.qryr.question(qst, min_match=0.05) |
|
condition = self.get_filters(req) |
|
|
|
|
|
condition.update({"knowledge_graph_kwd": ["entity"]}) |
|
assert emb_mdl, "No embedding model selected" |
|
matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1)) |
|
q_vec = matchDense.embedding_data |
|
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", |
|
"doc_id", f"q_{len(q_vec)}_vec", "position_int", "name_kwd", |
|
"available_int", "content_with_weight", |
|
"weight_int", "weight_flt" |
|
]) |
|
|
|
fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"}) |
|
|
|
ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids) |
|
ent_res_fields = self.dataStore.getFields(ent_res, src) |
|
entities = [d["name_kwd"] for d in ent_res_fields.values() if d.get("name_kwd")] |
|
ent_ids = self.dataStore.getChunkIds(ent_res) |
|
ent_content = merge_into_first(ent_res_fields, "-Entities-") |
|
if ent_content: |
|
ent_ids = list(ent_content.keys()) |
|
|
|
|
|
condition = self.get_filters(req) |
|
condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]}) |
|
comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids) |
|
comm_res_fields = self.dataStore.getFields(comm_res, src) |
|
comm_ids = self.dataStore.getChunkIds(comm_res) |
|
comm_content = merge_into_first(comm_res_fields, "-Community Report-") |
|
if comm_content: |
|
comm_ids = list(comm_content.keys()) |
|
|
|
|
|
condition = self.get_filters(req) |
|
condition.update({"knowledge_graph_kwd": ["text"]}) |
|
txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids) |
|
txt_res_fields = self.dataStore.getFields(txt_res, src) |
|
txt_ids = self.dataStore.getChunkIds(txt_res) |
|
txt_content = merge_into_first(txt_res_fields, "-Original Content-") |
|
if txt_content: |
|
txt_ids = list(txt_content.keys()) |
|
|
|
return self.SearchResult( |
|
total=len(ent_ids) + len(comm_ids) + len(txt_ids), |
|
ids=[*ent_ids, *comm_ids, *txt_ids], |
|
query_vector=q_vec, |
|
highlight=None, |
|
field={**ent_content, **comm_content, **txt_content}, |
|
keywords=[] |
|
) |
|
|