Added pagerank support to infinity (#4059)
Browse files### What problem does this PR solve?
Added pagerank support to infinity
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/apps/kb_app.py +1 -0
- rag/utils/infinity_conn.py +17 -9
api/apps/kb_app.py
CHANGED
@@ -107,6 +107,7 @@ def update():
|
|
107 |
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
|
108 |
search.index_name(kb.tenant_id), kb.id)
|
109 |
else:
|
|
|
110 |
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
|
111 |
search.index_name(kb.tenant_id), kb.id)
|
112 |
|
|
|
107 |
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
|
108 |
search.index_name(kb.tenant_id), kb.id)
|
109 |
else:
|
110 |
+
# Elasticsearch requires pagerank_fea be non-zero!
|
111 |
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
|
112 |
search.index_name(kb.tenant_id), kb.id)
|
113 |
|
rag/utils/infinity_conn.py
CHANGED
@@ -46,13 +46,14 @@ def equivalent_condition_to_str(condition: dict) -> str|None:
|
|
46 |
cond.append(f"{k}='{v}'")
|
47 |
else:
|
48 |
cond.append(f"{k}={str(v)}")
|
49 |
-
return " AND ".join(cond) if cond else
|
50 |
|
51 |
|
52 |
def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
|
53 |
"""
|
54 |
Concatenate multiple dataframes into one.
|
55 |
"""
|
|
|
56 |
if df_list:
|
57 |
return pl.concat(df_list)
|
58 |
schema = dict()
|
@@ -246,8 +247,9 @@ class InfinityConnection(DocStoreConnection):
|
|
246 |
db_instance = inf_conn.get_database(self.dbName)
|
247 |
df_list = list()
|
248 |
table_list = list()
|
249 |
-
|
250 |
-
selectFields
|
|
|
251 |
|
252 |
# Prepare expressions common to all tables
|
253 |
filter_cond = None
|
@@ -331,10 +333,13 @@ class InfinityConnection(DocStoreConnection):
|
|
331 |
kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
|
332 |
if extra_result:
|
333 |
total_hits_count += int(extra_result["total_hits_count"])
|
|
|
334 |
df_list.append(kb_res)
|
335 |
self.connPool.release_conn(inf_conn)
|
336 |
res = concat_dataframes(df_list, selectFields)
|
337 |
-
|
|
|
|
|
338 |
return res, total_hits_count
|
339 |
|
340 |
def get(
|
@@ -350,12 +355,10 @@ class InfinityConnection(DocStoreConnection):
|
|
350 |
table_list.append(table_name)
|
351 |
table_instance = db_instance.get_table(table_name)
|
352 |
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
self.connPool.release_conn(inf_conn)
|
357 |
res = concat_dataframes(df_list, ["id"])
|
358 |
-
logger.debug(f"INFINITY get tables: {str(table_list)}, result: {str(res)}")
|
359 |
res_fields = self.getFields(res, res.columns)
|
360 |
return res_fields.get(chunkId, None)
|
361 |
|
@@ -421,8 +424,10 @@ class InfinityConnection(DocStoreConnection):
|
|
421 |
db_instance = inf_conn.get_database(self.dbName)
|
422 |
table_name = f"{indexName}_{knowledgebaseId}"
|
423 |
table_instance = db_instance.get_table(table_name)
|
|
|
|
|
424 |
filter = equivalent_condition_to_str(condition)
|
425 |
-
for k, v in newValue.items():
|
426 |
if k.endswith("_kwd") and isinstance(v, list):
|
427 |
newValue[k] = " ".join(v)
|
428 |
elif k == 'kb_id':
|
@@ -435,6 +440,9 @@ class InfinityConnection(DocStoreConnection):
|
|
435 |
elif k in ["page_num_int", "top_int"]:
|
436 |
assert isinstance(v, list)
|
437 |
newValue[k] = "_".join(f"{num:08x}" for num in v)
|
|
|
|
|
|
|
438 |
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
|
439 |
table_instance.update(filter, newValue)
|
440 |
self.connPool.release_conn(inf_conn)
|
|
|
46 |
cond.append(f"{k}='{v}'")
|
47 |
else:
|
48 |
cond.append(f"{k}={str(v)}")
|
49 |
+
return " AND ".join(cond) if cond else "1=1"
|
50 |
|
51 |
|
52 |
def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
|
53 |
"""
|
54 |
Concatenate multiple dataframes into one.
|
55 |
"""
|
56 |
+
df_list = [df for df in df_list if not df.is_empty()]
|
57 |
if df_list:
|
58 |
return pl.concat(df_list)
|
59 |
schema = dict()
|
|
|
247 |
db_instance = inf_conn.get_database(self.dbName)
|
248 |
df_list = list()
|
249 |
table_list = list()
|
250 |
+
for essential_field in ["id", "score()", "pagerank_fea"]:
|
251 |
+
if essential_field not in selectFields:
|
252 |
+
selectFields.append(essential_field)
|
253 |
|
254 |
# Prepare expressions common to all tables
|
255 |
filter_cond = None
|
|
|
333 |
kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
|
334 |
if extra_result:
|
335 |
total_hits_count += int(extra_result["total_hits_count"])
|
336 |
+
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
|
337 |
df_list.append(kb_res)
|
338 |
self.connPool.release_conn(inf_conn)
|
339 |
res = concat_dataframes(df_list, selectFields)
|
340 |
+
res = res.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True, maintain_order=True)
|
341 |
+
res = res.limit(limit)
|
342 |
+
logger.debug(f"INFINITY search final result: {str(res)}")
|
343 |
return res, total_hits_count
|
344 |
|
345 |
def get(
|
|
|
355 |
table_list.append(table_name)
|
356 |
table_instance = db_instance.get_table(table_name)
|
357 |
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
|
358 |
+
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
|
359 |
+
df_list.append(kb_res)
|
|
|
360 |
self.connPool.release_conn(inf_conn)
|
361 |
res = concat_dataframes(df_list, ["id"])
|
|
|
362 |
res_fields = self.getFields(res, res.columns)
|
363 |
return res_fields.get(chunkId, None)
|
364 |
|
|
|
424 |
db_instance = inf_conn.get_database(self.dbName)
|
425 |
table_name = f"{indexName}_{knowledgebaseId}"
|
426 |
table_instance = db_instance.get_table(table_name)
|
427 |
+
if "exist" in condition:
|
428 |
+
del condition["exist"]
|
429 |
filter = equivalent_condition_to_str(condition)
|
430 |
+
for k, v in list(newValue.items()):
|
431 |
if k.endswith("_kwd") and isinstance(v, list):
|
432 |
newValue[k] = " ".join(v)
|
433 |
elif k == 'kb_id':
|
|
|
440 |
elif k in ["page_num_int", "top_int"]:
|
441 |
assert isinstance(v, list)
|
442 |
newValue[k] = "_".join(f"{num:08x}" for num in v)
|
443 |
+
elif k == "remove" and v in ["pagerank_fea"]:
|
444 |
+
del newValue[k]
|
445 |
+
newValue[v] = 0
|
446 |
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
|
447 |
table_instance.update(filter, newValue)
|
448 |
self.connPool.release_conn(inf_conn)
|