Kevin Hu committed on
Commit
1ca7adb
·
1 Parent(s): bdb8bf3

fix term weight issue (#3306)

Browse files

### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (2) hide show
  1. rag/benchmark.py +18 -10
  2. rag/nlp/search.py +2 -2
rag/benchmark.py CHANGED
@@ -34,12 +34,13 @@ from tqdm import tqdm
34
 
35
  class Benchmark:
36
  def __init__(self, kb_id):
37
- e, kb = KnowledgebaseService.get_by_id(kb_id)
38
- self.similarity_threshold = kb.similarity_threshold
39
- self.vector_similarity_weight = kb.vector_similarity_weight
40
- self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
41
 
42
  def _get_benchmarks(self, query, dataset_idxnm, count=16):
 
43
  req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
44
  sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
45
  return sres
@@ -48,11 +49,15 @@ class Benchmark:
48
  run = defaultdict(dict)
49
  query_list = list(qrels.keys())
50
  for query in query_list:
51
- sres = self._get_benchmarks(query, dataset_idxnm)
52
- sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight,
53
- self.vector_similarity_weight)
54
- for index, id in enumerate(sres.ids):
55
- run[query][id] = sim[index]
 
 
 
 
56
  return run
57
 
58
  def embedding(self, docs, batch_size=16):
@@ -99,7 +104,8 @@ class Benchmark:
99
  query = data.iloc[i]['query']
100
  for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
101
  d = {
102
- "id": get_uuid()
 
103
  }
104
  tokenize(d, text, "english")
105
  docs.append(d)
@@ -208,6 +214,8 @@ class Benchmark:
208
  scores = sorted(scores, key=lambda kk: kk[1])
209
  for score in scores[:10]:
210
  f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
 
 
211
  print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
212
 
213
  def __call__(self, dataset, file_path, miracl_corpus=''):
 
34
 
35
  class Benchmark:
36
  def __init__(self, kb_id):
37
+ e, self.kb = KnowledgebaseService.get_by_id(kb_id)
38
+ self.similarity_threshold = self.kb.similarity_threshold
39
+ self.vector_similarity_weight = self.kb.vector_similarity_weight
40
+ self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
41
 
42
  def _get_benchmarks(self, query, dataset_idxnm, count=16):
43
+
44
  req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
45
  sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
46
  return sres
 
49
  run = defaultdict(dict)
50
  query_list = list(qrels.keys())
51
  for query in query_list:
52
+
53
+ ranks = retrievaler.retrieval(query, self.embd_mdl, dataset_idxnm.replace("ragflow_", ""),
54
+ [self.kb.id], 0, 30,
55
+ 0.0, self.vector_similarity_weight)
56
+ for c in ranks["chunks"]:
57
+ if "vector" in c:
58
+ del c["vector"]
59
+ run[query][c["chunk_id"]] = c["similarity"]
60
+
61
  return run
62
 
63
  def embedding(self, docs, batch_size=16):
 
104
  query = data.iloc[i]['query']
105
  for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
106
  d = {
107
+ "id": get_uuid(),
108
+ "kb_id": self.kb.id
109
  }
110
  tokenize(d, text, "english")
111
  docs.append(d)
 
214
  scores = sorted(scores, key=lambda kk: kk[1])
215
  for score in scores[:10]:
216
  f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
217
+ json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
218
+ json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
219
  print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
220
 
221
  def __call__(self, dataset, file_path, miracl_corpus=''):
rag/nlp/search.py CHANGED
@@ -211,8 +211,8 @@ class Dealer:
211
  continue
212
  if not isinstance(v, type("")):
213
  m[n] = str(m[n])
214
- if n.find("tks") > 0:
215
- m[n] = rmSpace(m[n])
216
 
217
  if m:
218
  res[d["id"]] = m
 
211
  continue
212
  if not isinstance(v, type("")):
213
  m[n] = str(m[n])
214
+ #if n.find("tks") > 0:
215
+ # m[n] = rmSpace(m[n])
216
 
217
  if m:
218
  res[d["id"]] = m