Kevin Hu committed on
Commit
2be6429
·
1 Parent(s): 4a6bb1f

Make Infinity able to calculate embedding similarity only. (#4644)

Browse files

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (1) hide show
  1. rag/utils/infinity_conn.py +18 -3
rag/utils/infinity_conn.py CHANGED
@@ -273,9 +273,22 @@ class InfinityConnection(DocStoreConnection):
273
  for essential_field in ["id"]:
274
  if essential_field not in selectFields:
275
  selectFields.append(essential_field)
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  if matchExprs:
277
- for essential_field in ["score()", PAGERANK_FLD]:
278
- selectFields.append(essential_field)
279
 
280
  # Prepare expressions common to all tables
281
  filter_cond = None
@@ -364,7 +377,9 @@ class InfinityConnection(DocStoreConnection):
364
  self.connPool.release_conn(inf_conn)
365
  res = concat_dataframes(df_list, selectFields)
366
  if matchExprs:
367
- res = res.sort(pl.col("SCORE") + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
 
 
368
  res = res.limit(limit)
369
  logger.debug(f"INFINITY search final result: {str(res)}")
370
  return res, total_hits_count
 
273
  for essential_field in ["id"]:
274
  if essential_field not in selectFields:
275
  selectFields.append(essential_field)
276
+ score_func = ""
277
+ score_column = ""
278
+ for matchExpr in matchExprs:
279
+ if isinstance(matchExpr, MatchTextExpr):
280
+ score_func = "score()"
281
+ score_column = "SCORE"
282
+ break
283
+ if not score_func:
284
+ for matchExpr in matchExprs:
285
+ if isinstance(matchExpr, MatchDenseExpr):
286
+ score_func = "similarity()"
287
+ score_column = "SIMILARITY"
288
+ break
289
  if matchExprs:
290
+ selectFields.append(score_func)
291
+ selectFields.append(PAGERANK_FLD)
292
 
293
  # Prepare expressions common to all tables
294
  filter_cond = None
 
377
  self.connPool.release_conn(inf_conn)
378
  res = concat_dataframes(df_list, selectFields)
379
  if matchExprs:
380
+ res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
381
+ if score_column and score_column != "SCORE":
382
+ res = res.rename({score_column: "SCORE"})
383
  res = res.limit(limit)
384
  logger.debug(f"INFINITY search final result: {str(res)}")
385
  return res, total_hits_count