zhichyu committed on
Commit
faa9f3e
·
1 Parent(s): a3da325

Added pagerank support to infinity (#4059)

Browse files

### What problem does this PR solve?

Added pagerank support to infinity

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (2) hide show
  1. api/apps/kb_app.py +1 -0
  2. rag/utils/infinity_conn.py +17 -9
api/apps/kb_app.py CHANGED
@@ -107,6 +107,7 @@ def update():
107
  settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
108
  search.index_name(kb.tenant_id), kb.id)
109
  else:
 
110
  settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
111
  search.index_name(kb.tenant_id), kb.id)
112
 
 
107
  settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
108
  search.index_name(kb.tenant_id), kb.id)
109
  else:
110
+ # Elasticsearch requires pagerank_fea to be non-zero, so remove the field instead of storing 0.
111
  settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
112
  search.index_name(kb.tenant_id), kb.id)
113
 
rag/utils/infinity_conn.py CHANGED
@@ -46,13 +46,14 @@ def equivalent_condition_to_str(condition: dict) -> str|None:
46
  cond.append(f"{k}='{v}'")
47
  else:
48
  cond.append(f"{k}={str(v)}")
49
- return " AND ".join(cond) if cond else None
50
 
51
 
52
  def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
53
  """
54
  Concatenate multiple dataframes into one.
55
  """
 
56
  if df_list:
57
  return pl.concat(df_list)
58
  schema = dict()
@@ -246,8 +247,9 @@ class InfinityConnection(DocStoreConnection):
246
  db_instance = inf_conn.get_database(self.dbName)
247
  df_list = list()
248
  table_list = list()
249
- if "id" not in selectFields:
250
- selectFields.append("id")
 
251
 
252
  # Prepare expressions common to all tables
253
  filter_cond = None
@@ -331,10 +333,13 @@ class InfinityConnection(DocStoreConnection):
331
  kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
332
  if extra_result:
333
  total_hits_count += int(extra_result["total_hits_count"])
 
334
  df_list.append(kb_res)
335
  self.connPool.release_conn(inf_conn)
336
  res = concat_dataframes(df_list, selectFields)
337
- logger.debug(f"INFINITY search tables: {str(table_list)}, result: {str(res)}")
 
 
338
  return res, total_hits_count
339
 
340
  def get(
@@ -350,12 +355,10 @@ class InfinityConnection(DocStoreConnection):
350
  table_list.append(table_name)
351
  table_instance = db_instance.get_table(table_name)
352
  kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
353
- if len(kb_res) != 0 and kb_res.shape[0] > 0:
354
- df_list.append(kb_res)
355
-
356
  self.connPool.release_conn(inf_conn)
357
  res = concat_dataframes(df_list, ["id"])
358
- logger.debug(f"INFINITY get tables: {str(table_list)}, result: {str(res)}")
359
  res_fields = self.getFields(res, res.columns)
360
  return res_fields.get(chunkId, None)
361
 
@@ -421,8 +424,10 @@ class InfinityConnection(DocStoreConnection):
421
  db_instance = inf_conn.get_database(self.dbName)
422
  table_name = f"{indexName}_{knowledgebaseId}"
423
  table_instance = db_instance.get_table(table_name)
 
 
424
  filter = equivalent_condition_to_str(condition)
425
- for k, v in newValue.items():
426
  if k.endswith("_kwd") and isinstance(v, list):
427
  newValue[k] = " ".join(v)
428
  elif k == 'kb_id':
@@ -435,6 +440,9 @@ class InfinityConnection(DocStoreConnection):
435
  elif k in ["page_num_int", "top_int"]:
436
  assert isinstance(v, list)
437
  newValue[k] = "_".join(f"{num:08x}" for num in v)
 
 
 
438
  logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
439
  table_instance.update(filter, newValue)
440
  self.connPool.release_conn(inf_conn)
 
46
  cond.append(f"{k}='{v}'")
47
  else:
48
  cond.append(f"{k}={str(v)}")
49
+ return " AND ".join(cond) if cond else "1=1"
50
 
51
 
52
  def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
53
  """
54
  Concatenate multiple dataframes into one.
55
  """
56
+ df_list = [df for df in df_list if not df.is_empty()]
57
  if df_list:
58
  return pl.concat(df_list)
59
  schema = dict()
 
247
  db_instance = inf_conn.get_database(self.dbName)
248
  df_list = list()
249
  table_list = list()
250
+ for essential_field in ["id", "score()", "pagerank_fea"]:
251
+ if essential_field not in selectFields:
252
+ selectFields.append(essential_field)
253
 
254
  # Prepare expressions common to all tables
255
  filter_cond = None
 
333
  kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
334
  if extra_result:
335
  total_hits_count += int(extra_result["total_hits_count"])
336
+ logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
337
  df_list.append(kb_res)
338
  self.connPool.release_conn(inf_conn)
339
  res = concat_dataframes(df_list, selectFields)
340
+ res = res.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True, maintain_order=True)
341
+ res = res.limit(limit)
342
+ logger.debug(f"INFINITY search final result: {str(res)}")
343
  return res, total_hits_count
344
 
345
  def get(
 
355
  table_list.append(table_name)
356
  table_instance = db_instance.get_table(table_name)
357
  kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
358
+ logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
359
+ df_list.append(kb_res)
 
360
  self.connPool.release_conn(inf_conn)
361
  res = concat_dataframes(df_list, ["id"])
 
362
  res_fields = self.getFields(res, res.columns)
363
  return res_fields.get(chunkId, None)
364
 
 
424
  db_instance = inf_conn.get_database(self.dbName)
425
  table_name = f"{indexName}_{knowledgebaseId}"
426
  table_instance = db_instance.get_table(table_name)
427
+ if "exist" in condition:
428
+ del condition["exist"]
429
  filter = equivalent_condition_to_str(condition)
430
+ for k, v in list(newValue.items()):
431
  if k.endswith("_kwd") and isinstance(v, list):
432
  newValue[k] = " ".join(v)
433
  elif k == 'kb_id':
 
440
  elif k in ["page_num_int", "top_int"]:
441
  assert isinstance(v, list)
442
  newValue[k] = "_".join(f"{num:08x}" for num in v)
443
+ elif k == "remove" and v in ["pagerank_fea"]:
444
+ del newValue[k]
445
+ newValue[v] = 0
446
  logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
447
  table_instance.update(filter, newValue)
448
  self.connPool.release_conn(inf_conn)