KevinHuSh committed on
Commit
c037a22
·
1 Parent(s): 459ac83

add rerank model (#969)

Browse files

### What problem does this PR solve?

feat: add rerank models to the project #724 #162

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/chunk_app.py CHANGED
@@ -257,8 +257,15 @@ def retrieval_test():
257
 
258
  embd_mdl = TenantLLMService.model_instance(
259
  kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
260
- ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
261
- vector_similarity_weight, top, doc_ids)
 
 
 
 
 
 
 
262
  for c in ranks["chunks"]:
263
  if "vector" in c:
264
  del c["vector"]
 
257
 
258
  embd_mdl = TenantLLMService.model_instance(
259
  kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
260
+
261
+ rerank_mdl = None
262
+ if req.get("rerank_id"):
263
+ rerank_mdl = TenantLLMService.model_instance(
264
+ kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
265
+
266
+ ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
267
+ similarity_threshold, vector_similarity_weight, top,
268
+ doc_ids, rerank_mdl=rerank_mdl)
269
  for c in ranks["chunks"]:
270
  if "vector" in c:
271
  del c["vector"]
api/apps/dialog_app.py CHANGED
@@ -33,6 +33,9 @@ def set_dialog():
33
  name = req.get("name", "New Dialog")
34
  description = req.get("description", "A helpful Dialog")
35
  top_n = req.get("top_n", 6)
 
 
 
36
  similarity_threshold = req.get("similarity_threshold", 0.1)
37
  vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
38
  llm_setting = req.get("llm_setting", {})
@@ -83,6 +86,8 @@ def set_dialog():
83
  "llm_setting": llm_setting,
84
  "prompt_config": prompt_config,
85
  "top_n": top_n,
 
 
86
  "similarity_threshold": similarity_threshold,
87
  "vector_similarity_weight": vector_similarity_weight
88
  }
 
33
  name = req.get("name", "New Dialog")
34
  description = req.get("description", "A helpful Dialog")
35
  top_n = req.get("top_n", 6)
36
+ top_k = req.get("top_k", 1024)
37
+ rerank_id = req.get("rerank_id", "")
38
+ if not rerank_id: req["rerank_id"] = ""
39
  similarity_threshold = req.get("similarity_threshold", 0.1)
40
  vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
41
  llm_setting = req.get("llm_setting", {})
 
86
  "llm_setting": llm_setting,
87
  "prompt_config": prompt_config,
88
  "top_n": top_n,
89
+ "top_k": top_k,
90
+ "rerank_id": rerank_id,
91
  "similarity_threshold": similarity_threshold,
92
  "vector_similarity_weight": vector_similarity_weight
93
  }
api/apps/llm_app.py CHANGED
@@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
20
  from api.db import StatusEnum, LLMType
21
  from api.db.db_models import TenantLLM
22
  from api.utils.api_utils import get_json_result
23
- from rag.llm import EmbeddingModel, ChatModel
24
 
25
 
26
  @manager.route('/factories', methods=['GET'])
@@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
28
  def factories():
29
  try:
30
  fac = LLMFactoriesService.get_all()
31
- return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
32
  except Exception as e:
33
  return server_error_response(e)
34
 
@@ -64,6 +64,16 @@ def set_api_key():
64
  except Exception as e:
65
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
66
  e)
 
 
 
 
 
 
 
 
 
 
67
 
68
  if msg:
69
  return get_data_error_result(retmsg=msg)
@@ -199,7 +209,7 @@ def list_app():
199
  llms = [m.to_dict()
200
  for m in llms if m.status == StatusEnum.VALID.value]
201
  for m in llms:
202
- m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
203
 
204
  llm_set = set([m["llm_name"] for m in llms])
205
  for o in objs:
 
20
  from api.db import StatusEnum, LLMType
21
  from api.db.db_models import TenantLLM
22
  from api.utils.api_utils import get_json_result
23
+ from rag.llm import EmbeddingModel, ChatModel, RerankModel
24
 
25
 
26
  @manager.route('/factories', methods=['GET'])
 
28
  def factories():
29
  try:
30
  fac = LLMFactoriesService.get_all()
31
+ return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
32
  except Exception as e:
33
  return server_error_response(e)
34
 
 
64
  except Exception as e:
65
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
66
  e)
67
+ elif llm.model_type == LLMType.RERANK:
68
+ mdl = RerankModel[factory](
69
+ req["api_key"], llm.llm_name, base_url=req.get("base_url"))
70
+ try:
71
+ m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
72
+ if len(arr[0]) == 0 or tc == 0:
73
+ raise Exception("Fail")
74
+ except Exception as e:
75
+ msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
76
+ e)
77
 
78
  if msg:
79
  return get_data_error_result(retmsg=msg)
 
209
  llms = [m.to_dict()
210
  for m in llms if m.status == StatusEnum.VALID.value]
211
  for m in llms:
212
+ m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]
213
 
214
  llm_set = set([m["llm_name"] for m in llms])
215
  for o in objs:
api/apps/user_app.py CHANGED
@@ -26,8 +26,9 @@ from api.db.services.llm_service import TenantLLMService, LLMService
26
  from api.utils.api_utils import server_error_response, validate_request
27
  from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
28
  from api.db import UserTenantRole, LLMType, FileType
29
- from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
30
- LLM_FACTORY, LLM_BASE_URL
 
31
  from api.db.services.user_service import UserService, TenantService, UserTenantService
32
  from api.db.services.file_service import FileService
33
  from api.settings import stat_logger
@@ -288,7 +289,8 @@ def user_register(user_id, user):
288
  "embd_id": EMBEDDING_MDL,
289
  "asr_id": ASR_MDL,
290
  "parser_ids": PARSERS,
291
- "img2txt_id": IMAGE2TEXT_MDL
 
292
  }
293
  usr_tenant = {
294
  "tenant_id": user_id,
 
26
  from api.utils.api_utils import server_error_response, validate_request
27
  from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
28
  from api.db import UserTenantRole, LLMType, FileType
29
+ from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
30
+ API_KEY, \
31
+ LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
32
  from api.db.services.user_service import UserService, TenantService, UserTenantService
33
  from api.db.services.file_service import FileService
34
  from api.settings import stat_logger
 
289
  "embd_id": EMBEDDING_MDL,
290
  "asr_id": ASR_MDL,
291
  "parser_ids": PARSERS,
292
+ "img2txt_id": IMAGE2TEXT_MDL,
293
+ "rerank_id": RERANK_MDL
294
  }
295
  usr_tenant = {
296
  "tenant_id": user_id,
api/db/__init__.py CHANGED
@@ -54,6 +54,7 @@ class LLMType(StrEnum):
54
  EMBEDDING = 'embedding'
55
  SPEECH2TEXT = 'speech2text'
56
  IMAGE2TEXT = 'image2text'
 
57
 
58
 
59
  class ChatStyle(StrEnum):
 
54
  EMBEDDING = 'embedding'
55
  SPEECH2TEXT = 'speech2text'
56
  IMAGE2TEXT = 'image2text'
57
+ RERANK = 'rerank'
58
 
59
 
60
  class ChatStyle(StrEnum):
api/db/db_models.py CHANGED
@@ -437,6 +437,10 @@ class Tenant(DataBaseModel):
437
  max_length=128,
438
  null=False,
439
  help_text="default image to text model ID")
 
 
 
 
440
  parser_ids = CharField(
441
  max_length=256,
442
  null=False,
@@ -771,11 +775,16 @@ class Dialog(DataBaseModel):
771
  similarity_threshold = FloatField(default=0.2)
772
  vector_similarity_weight = FloatField(default=0.3)
773
  top_n = IntegerField(default=6)
 
774
  do_refer = CharField(
775
  max_length=1,
776
  null=False,
777
  help_text="it needs to insert reference index into answer or not",
778
  default="1")
 
 
 
 
779
 
780
  kb_ids = JSONField(null=False, default=[])
781
  status = CharField(
@@ -825,11 +834,29 @@ class API4Conversation(DataBaseModel):
825
 
826
 
827
  def migrate_db():
828
- try:
829
  with DB.transaction():
830
  migrator = MySQLMigrator(DB)
831
- migrate(
832
- migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
833
- )
834
- except Exception as e:
835
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  max_length=128,
438
  null=False,
439
  help_text="default image to text model ID")
440
+ rerank_id = CharField(
441
+ max_length=128,
442
+ null=False,
443
+ help_text="default rerank model ID")
444
  parser_ids = CharField(
445
  max_length=256,
446
  null=False,
 
775
  similarity_threshold = FloatField(default=0.2)
776
  vector_similarity_weight = FloatField(default=0.3)
777
  top_n = IntegerField(default=6)
778
+ top_k = IntegerField(default=1024)
779
  do_refer = CharField(
780
  max_length=1,
781
  null=False,
782
  help_text="it needs to insert reference index into answer or not",
783
  default="1")
784
+ rerank_id = CharField(
785
+ max_length=128,
786
+ null=False,
787
+ help_text="default rerank model ID")
788
 
789
  kb_ids = JSONField(null=False, default=[])
790
  status = CharField(
 
834
 
835
 
836
  def migrate_db():
 
837
  with DB.transaction():
838
  migrator = MySQLMigrator(DB)
839
+ try:
840
+ migrate(
841
+ migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
842
+ )
843
+ except Exception as e:
844
+ pass
845
+ try:
846
+ migrate(
847
+ migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
848
+ )
849
+ except Exception as e:
850
+ pass
851
+ try:
852
+ migrate(
853
+ migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
854
+ )
855
+ except Exception as e:
856
+ pass
857
+ try:
858
+ migrate(
859
+ migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
860
+ )
861
+ except Exception as e:
862
+ pass
api/db/init_data.py CHANGED
@@ -142,7 +142,17 @@ factory_infos = [{
142
  "logo": "",
143
  "tags": "LLM,TEXT EMBEDDING",
144
  "status": "1",
145
- },
 
 
 
 
 
 
 
 
 
 
146
  # {
147
  # "name": "文心一言",
148
  # "logo": "",
@@ -367,6 +377,13 @@ def init_llm_factory():
367
  "max_tokens": 512,
368
  "model_type": LLMType.EMBEDDING.value
369
  },
 
 
 
 
 
 
 
370
  # ------------------------ DeepSeek -----------------------
371
  {
372
  "fid": factory_infos[8]["name"],
@@ -440,6 +457,85 @@ def init_llm_factory():
440
  "max_tokens": 512,
441
  "model_type": LLMType.EMBEDDING.value
442
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  ]
444
  for info in factory_infos:
445
  try:
 
142
  "logo": "",
143
  "tags": "LLM,TEXT EMBEDDING",
144
  "status": "1",
145
+ },{
146
+ "name": "Jina",
147
+ "logo": "",
148
+ "tags": "TEXT EMBEDDING, TEXT RE-RANK",
149
+ "status": "1",
150
+ },{
151
+ "name": "BAAI",
152
+ "logo": "",
153
+ "tags": "TEXT EMBEDDING, TEXT RE-RANK",
154
+ "status": "1",
155
+ }
156
  # {
157
  # "name": "文心一言",
158
  # "logo": "",
 
377
  "max_tokens": 512,
378
  "model_type": LLMType.EMBEDDING.value
379
  },
380
+ {
381
+ "fid": factory_infos[7]["name"],
382
+ "llm_name": "maidalun1020/bce-reranker-base_v1",
383
+ "tags": "RE-RANK, 8K",
384
+ "max_tokens": 8196,
385
+ "model_type": LLMType.RERANK.value
386
+ },
387
  # ------------------------ DeepSeek -----------------------
388
  {
389
  "fid": factory_infos[8]["name"],
 
457
  "max_tokens": 512,
458
  "model_type": LLMType.EMBEDDING.value
459
  },
460
+ # ------------------------ Jina -----------------------
461
+ {
462
+ "fid": factory_infos[11]["name"],
463
+ "llm_name": "jina-reranker-v1-base-en",
464
+ "tags": "RE-RANK,8k",
465
+ "max_tokens": 8196,
466
+ "model_type": LLMType.RERANK.value
467
+ },
468
+ {
469
+ "fid": factory_infos[11]["name"],
470
+ "llm_name": "jina-reranker-v1-turbo-en",
471
+ "tags": "RE-RANK,8k",
472
+ "max_tokens": 8196,
473
+ "model_type": LLMType.RERANK.value
474
+ },
475
+ {
476
+ "fid": factory_infos[11]["name"],
477
+ "llm_name": "jina-reranker-v1-tiny-en",
478
+ "tags": "RE-RANK,8k",
479
+ "max_tokens": 8196,
480
+ "model_type": LLMType.RERANK.value
481
+ },
482
+ {
483
+ "fid": factory_infos[11]["name"],
484
+ "llm_name": "jina-colbert-v1-en",
485
+ "tags": "RE-RANK,8k",
486
+ "max_tokens": 8196,
487
+ "model_type": LLMType.RERANK.value
488
+ },
489
+ {
490
+ "fid": factory_infos[11]["name"],
491
+ "llm_name": "jina-embeddings-v2-base-en",
492
+ "tags": "TEXT EMBEDDING",
493
+ "max_tokens": 8196,
494
+ "model_type": LLMType.EMBEDDING.value
495
+ },
496
+ {
497
+ "fid": factory_infos[11]["name"],
498
+ "llm_name": "jina-embeddings-v2-base-de",
499
+ "tags": "TEXT EMBEDDING",
500
+ "max_tokens": 8196,
501
+ "model_type": LLMType.EMBEDDING.value
502
+ },
503
+ {
504
+ "fid": factory_infos[11]["name"],
505
+ "llm_name": "jina-embeddings-v2-base-es",
506
+ "tags": "TEXT EMBEDDING",
507
+ "max_tokens": 8196,
508
+ "model_type": LLMType.EMBEDDING.value
509
+ },
510
+ {
511
+ "fid": factory_infos[11]["name"],
512
+ "llm_name": "jina-embeddings-v2-base-code",
513
+ "tags": "TEXT EMBEDDING",
514
+ "max_tokens": 8196,
515
+ "model_type": LLMType.EMBEDDING.value
516
+ },
517
+ {
518
+ "fid": factory_infos[11]["name"],
519
+ "llm_name": "jina-embeddings-v2-base-zh",
520
+ "tags": "TEXT EMBEDDING",
521
+ "max_tokens": 8196,
522
+ "model_type": LLMType.EMBEDDING.value
523
+ },
524
+ # ------------------------ BAAI -----------------------
525
+ {
526
+ "fid": factory_infos[12]["name"],
527
+ "llm_name": "BAAI/bge-large-zh-v1.5",
528
+ "tags": "TEXT EMBEDDING,",
529
+ "max_tokens": 1024,
530
+ "model_type": LLMType.EMBEDDING.value
531
+ },
532
+ {
533
+ "fid": factory_infos[12]["name"],
534
+ "llm_name": "BAAI/bge-reranker-v2-m3",
535
+ "tags": "LLM,CHAT,",
536
+ "max_tokens": 16385,
537
+ "model_type": LLMType.RERANK.value
538
+ },
539
  ]
540
  for info in factory_infos:
541
  try:
api/db/services/dialog_service.py CHANGED
@@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs):
115
  if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
116
  kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
117
  else:
 
 
 
118
  kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
119
  dialog.similarity_threshold,
120
  dialog.vector_similarity_weight,
121
  doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
122
- top=1024, aggs=False)
123
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
124
  chat_logger.info(
125
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
@@ -130,7 +133,7 @@ def chat(dialog, messages, stream=True, **kwargs):
130
 
131
  kwargs["knowledge"] = "\n".join(knowledges)
132
  gen_conf = dialog.llm_setting
133
-
134
  msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
135
  msg.extend([{"role": m["role"], "content": m["content"]}
136
  for m in messages if m["role"] != "system"])
 
115
  if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
116
  kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
117
  else:
118
+ rerank_mdl = None
119
+ if dialog.rerank_id:
120
+ rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
121
  kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
122
  dialog.similarity_threshold,
123
  dialog.vector_similarity_weight,
124
  doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
125
+ top=1024, aggs=False, rerank_mdl=rerank_mdl)
126
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
127
  chat_logger.info(
128
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
133
 
134
  kwargs["knowledge"] = "\n".join(knowledges)
135
  gen_conf = dialog.llm_setting
136
+
137
  msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
138
  msg.extend([{"role": m["role"], "content": m["content"]}
139
  for m in messages if m["role"] != "system"])
api/db/services/llm_service.py CHANGED
@@ -15,7 +15,7 @@
15
  #
16
  from api.db.services.user_service import TenantService
17
  from api.settings import database_logger
18
- from rag.llm import EmbeddingModel, CvModel, ChatModel
19
  from api.db import LLMType
20
  from api.db.db_models import DB, UserTenant
21
  from api.db.db_models import LLMFactories, LLM, TenantLLM
@@ -73,21 +73,25 @@ class TenantLLMService(CommonService):
73
  mdlnm = tenant.img2txt_id
74
  elif llm_type == LLMType.CHAT.value:
75
  mdlnm = tenant.llm_id if not llm_name else llm_name
 
 
76
  else:
77
  assert False, "LLM type error"
78
 
79
  model_config = cls.get_api_key(tenant_id, mdlnm)
80
  if model_config: model_config = model_config.to_dict()
81
  if not model_config:
82
- if llm_type == LLMType.EMBEDDING.value:
83
  llm = LLMService.query(llm_name=llm_name)
84
- if llm and llm[0].fid in ["Youdao", "FastEmbed", "DeepSeek"]:
85
  model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
86
  if not model_config:
87
  if llm_name == "flag-embedding":
88
  model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
89
  "llm_name": llm_name, "api_base": ""}
90
  else:
 
 
91
  raise LookupError("Model({}) not authorized".format(mdlnm))
92
 
93
  if llm_type == LLMType.EMBEDDING.value:
@@ -96,6 +100,12 @@ class TenantLLMService(CommonService):
96
  return EmbeddingModel[model_config["llm_factory"]](
97
  model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
98
 
 
 
 
 
 
 
99
  if llm_type == LLMType.IMAGE2TEXT.value:
100
  if model_config["llm_factory"] not in CvModel:
101
  return
@@ -125,14 +135,20 @@ class TenantLLMService(CommonService):
125
  mdlnm = tenant.img2txt_id
126
  elif llm_type == LLMType.CHAT.value:
127
  mdlnm = tenant.llm_id if not llm_name else llm_name
 
 
128
  else:
129
  assert False, "LLM type error"
130
 
131
  num = 0
132
- for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
133
- num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
134
- .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
135
- .execute()
 
 
 
 
136
  return num
137
 
138
  @classmethod
@@ -176,6 +192,14 @@ class LLMBundle(object):
176
  "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
177
  return emd, used_tokens
178
 
 
 
 
 
 
 
 
 
179
  def describe(self, image, max_tokens=300):
180
  txt, used_tokens = self.mdl.describe(image, max_tokens)
181
  if not TenantLLMService.increase_usage(
 
15
  #
16
  from api.db.services.user_service import TenantService
17
  from api.settings import database_logger
18
+ from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel
19
  from api.db import LLMType
20
  from api.db.db_models import DB, UserTenant
21
  from api.db.db_models import LLMFactories, LLM, TenantLLM
 
73
  mdlnm = tenant.img2txt_id
74
  elif llm_type == LLMType.CHAT.value:
75
  mdlnm = tenant.llm_id if not llm_name else llm_name
76
+ elif llm_type == LLMType.RERANK:
77
+ mdlnm = tenant.rerank_id if not llm_name else llm_name
78
  else:
79
  assert False, "LLM type error"
80
 
81
  model_config = cls.get_api_key(tenant_id, mdlnm)
82
  if model_config: model_config = model_config.to_dict()
83
  if not model_config:
84
+ if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
85
  llm = LLMService.query(llm_name=llm_name)
86
+ if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
87
  model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
88
  if not model_config:
89
  if llm_name == "flag-embedding":
90
  model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
91
  "llm_name": llm_name, "api_base": ""}
92
  else:
93
+ if not mdlnm:
94
+ raise LookupError(f"Type of {llm_type} model is not set.")
95
  raise LookupError("Model({}) not authorized".format(mdlnm))
96
 
97
  if llm_type == LLMType.EMBEDDING.value:
 
100
  return EmbeddingModel[model_config["llm_factory"]](
101
  model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
102
 
103
+ if llm_type == LLMType.RERANK:
104
+ if model_config["llm_factory"] not in RerankModel:
105
+ return
106
+ return RerankModel[model_config["llm_factory"]](
107
+ model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
108
+
109
  if llm_type == LLMType.IMAGE2TEXT.value:
110
  if model_config["llm_factory"] not in CvModel:
111
  return
 
135
  mdlnm = tenant.img2txt_id
136
  elif llm_type == LLMType.CHAT.value:
137
  mdlnm = tenant.llm_id if not llm_name else llm_name
138
+ elif llm_type == LLMType.RERANK:
139
+ mdlnm = tenant.llm_id if not llm_name else llm_name
140
  else:
141
  assert False, "LLM type error"
142
 
143
  num = 0
144
+ try:
145
+ for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
146
+ num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
147
+ .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
148
+ .execute()
149
+ except Exception as e:
150
+ print(e)
151
+ pass
152
  return num
153
 
154
  @classmethod
 
192
  "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
193
  return emd, used_tokens
194
 
195
+ def similarity(self, query: str, texts: list):
196
+ sim, used_tokens = self.mdl.similarity(query, texts)
197
+ if not TenantLLMService.increase_usage(
198
+ self.tenant_id, self.llm_type, used_tokens):
199
+ database_logger.error(
200
+ "Can't update token usage for {}/RERANK".format(self.tenant_id))
201
+ return sim, used_tokens
202
+
203
  def describe(self, image, max_tokens=300):
204
  txt, used_tokens = self.mdl.describe(image, max_tokens)
205
  if not TenantLLMService.increase_usage(
api/db/services/user_service.py CHANGED
@@ -93,6 +93,7 @@ class TenantService(CommonService):
93
  cls.model.name,
94
  cls.model.llm_id,
95
  cls.model.embd_id,
 
96
  cls.model.asr_id,
97
  cls.model.img2txt_id,
98
  cls.model.parser_ids,
 
93
  cls.model.name,
94
  cls.model.llm_id,
95
  cls.model.embd_id,
96
+ cls.model.rerank_id,
97
  cls.model.asr_id,
98
  cls.model.img2txt_id,
99
  cls.model.parser_ids,
api/settings.py CHANGED
@@ -89,9 +89,22 @@ default_llm = {
89
  },
90
  "DeepSeek": {
91
  "chat_model": "deepseek-chat",
 
 
 
 
 
 
 
 
 
 
 
 
92
  "embedding_model": "BAAI/bge-large-zh-v1.5",
93
  "image2text_model": "",
94
  "asr_model": "",
 
95
  }
96
  }
97
  LLM = get_base_config("user_default_llm", {})
@@ -104,7 +117,8 @@ if LLM_FACTORY not in default_llm:
104
  f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
105
  LLM_FACTORY = "Tongyi-Qianwen"
106
  CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
107
- EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
 
108
  ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
109
  IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
110
 
 
89
  },
90
  "DeepSeek": {
91
  "chat_model": "deepseek-chat",
92
+ "embedding_model": "",
93
+ "image2text_model": "",
94
+ "asr_model": "",
95
+ },
96
+ "VolcEngine": {
97
+ "chat_model": "",
98
+ "embedding_model": "",
99
+ "image2text_model": "",
100
+ "asr_model": "",
101
+ },
102
+ "BAAI": {
103
+ "chat_model": "",
104
  "embedding_model": "BAAI/bge-large-zh-v1.5",
105
  "image2text_model": "",
106
  "asr_model": "",
107
+ "rerank_model": "BAAI/bge-reranker-v2-m3",
108
  }
109
  }
110
  LLM = get_base_config("user_default_llm", {})
 
117
  f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
118
  LLM_FACTORY = "Tongyi-Qianwen"
119
  CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
120
+ EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
121
+ RERANK_MDL = default_llm["BAAI"]["rerank_model"]
122
  ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
123
  IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
124
 
rag/llm/__init__.py CHANGED
@@ -16,18 +16,19 @@
16
  from .embedding_model import *
17
  from .chat_model import *
18
  from .cv_model import *
 
19
 
20
 
21
  EmbeddingModel = {
22
  "Ollama": OllamaEmbed,
23
  "OpenAI": OpenAIEmbed,
24
  "Xinference": XinferenceEmbed,
25
- "Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed,
26
  "ZHIPU-AI": ZhipuEmbed,
27
  "FastEmbed": FastEmbed,
28
  "Youdao": YoudaoEmbed,
29
- "DeepSeek": DefaultEmbedding,
30
- "BaiChuan": BaiChuanEmbed
31
  }
32
 
33
 
@@ -52,3 +53,9 @@ ChatModel = {
52
  "BaiChuan": BaiChuanChat
53
  }
54
 
 
 
 
 
 
 
 
16
  from .embedding_model import *
17
  from .chat_model import *
18
  from .cv_model import *
19
+ from .rerank_model import *
20
 
21
 
22
  EmbeddingModel = {
23
  "Ollama": OllamaEmbed,
24
  "OpenAI": OpenAIEmbed,
25
  "Xinference": XinferenceEmbed,
26
+ "Tongyi-Qianwen": DefaultEmbedding,#QWenEmbed,
27
  "ZHIPU-AI": ZhipuEmbed,
28
  "FastEmbed": FastEmbed,
29
  "Youdao": YoudaoEmbed,
30
+ "BaiChuan": BaiChuanEmbed,
31
+ "BAAI": DefaultEmbedding
32
  }
33
 
34
 
 
53
  "BaiChuan": BaiChuanChat
54
  }
55
 
56
+
57
+ RerankModel = {
58
+ "BAAI": DefaultRerank,
59
+ "Jina": JinaRerank,
60
+ "Youdao": YoudaoRerank,
61
+ }
rag/llm/embedding_model.py CHANGED
@@ -13,8 +13,10 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  from typing import Optional
17
 
 
18
  from huggingface_hub import snapshot_download
19
  from zhipuai import ZhipuAI
20
  import os
@@ -26,21 +28,9 @@ from FlagEmbedding import FlagModel
26
  import torch
27
  import numpy as np
28
 
29
- from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
30
  from rag.utils import num_tokens_from_string, truncate
31
 
32
- try:
33
- flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
34
- query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
35
- use_fp16=torch.cuda.is_available())
36
- except Exception as e:
37
- model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
38
- local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
39
- local_dir_use_symlinks=False)
40
- flag_model = FlagModel(model_dir,
41
- query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
42
- use_fp16=torch.cuda.is_available())
43
-
44
 
45
  class Base(ABC):
46
  def __init__(self, key, model_name):
@@ -54,7 +44,9 @@ class Base(ABC):
54
 
55
 
56
  class DefaultEmbedding(Base):
57
- def __init__(self, *args, **kwargs):
 
 
58
  """
59
  If you have trouble downloading HuggingFace models, -_^ this might help!!
60
 
@@ -66,7 +58,18 @@ class DefaultEmbedding(Base):
66
  ^_-
67
 
68
  """
69
- self.model = flag_model
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def encode(self, texts: list, batch_size=32):
72
  texts = [truncate(t, 2048) for t in texts]
@@ -75,12 +78,12 @@ class DefaultEmbedding(Base):
75
  token_count += num_tokens_from_string(t)
76
  res = []
77
  for i in range(0, len(texts), batch_size):
78
- res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
79
  return np.array(res), token_count
80
 
81
  def encode_queries(self, text: str):
82
  token_count = num_tokens_from_string(text)
83
- return self.model.encode_queries([text]).tolist()[0], token_count
84
 
85
 
86
  class OpenAIEmbed(Base):
@@ -189,16 +192,19 @@ class OllamaEmbed(Base):
189
 
190
 
191
  class FastEmbed(Base):
 
 
192
  def __init__(
193
- self,
194
- key: Optional[str] = None,
195
- model_name: str = "BAAI/bge-small-en-v1.5",
196
- cache_dir: Optional[str] = None,
197
- threads: Optional[int] = None,
198
- **kwargs,
199
  ):
200
  from fastembed import TextEmbedding
201
- self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
 
202
 
203
  def encode(self, texts: list, batch_size=32):
204
  # Using the internal tokenizer to encode the texts and get the total
@@ -265,3 +271,29 @@ class YoudaoEmbed(Base):
265
  def encode_queries(self, text):
266
  embds = YoudaoEmbed._client.encode([text])
267
  return np.array(embds[0]), num_tokens_from_string(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import re
17
  from typing import Optional
18
 
19
+ import requests
20
  from huggingface_hub import snapshot_download
21
  from zhipuai import ZhipuAI
22
  import os
 
28
  import torch
29
  import numpy as np
30
 
31
+ from api.utils.file_utils import get_home_cache_dir
32
  from rag.utils import num_tokens_from_string, truncate
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  class Base(ABC):
36
  def __init__(self, key, model_name):
 
44
 
45
 
46
  class DefaultEmbedding(Base):
47
+ _model = None
48
+
49
+ def __init__(self, key, model_name, **kwargs):
50
  """
51
  If you have trouble downloading HuggingFace models, -_^ this might help!!
52
 
 
58
  ^_-
59
 
60
  """
61
+ if not DefaultEmbedding._model:
62
+ try:
63
+ self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
64
+ query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
65
+ use_fp16=torch.cuda.is_available())
66
+ except Exception as e:
67
+ model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
68
+ local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
69
+ local_dir_use_symlinks=False)
70
+ self._model = FlagModel(model_dir,
71
+ query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
72
+ use_fp16=torch.cuda.is_available())
73
 
74
  def encode(self, texts: list, batch_size=32):
75
  texts = [truncate(t, 2048) for t in texts]
 
78
  token_count += num_tokens_from_string(t)
79
  res = []
80
  for i in range(0, len(texts), batch_size):
81
+ res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
82
  return np.array(res), token_count
83
 
84
  def encode_queries(self, text: str):
85
  token_count = num_tokens_from_string(text)
86
+ return self._model.encode_queries([text]).tolist()[0], token_count
87
 
88
 
89
  class OpenAIEmbed(Base):
 
192
 
193
 
194
  class FastEmbed(Base):
195
+ _model = None
196
+
197
  def __init__(
198
+ self,
199
+ key: Optional[str] = None,
200
+ model_name: str = "BAAI/bge-small-en-v1.5",
201
+ cache_dir: Optional[str] = None,
202
+ threads: Optional[int] = None,
203
+ **kwargs,
204
  ):
205
  from fastembed import TextEmbedding
206
+ if not FastEmbed._model:
207
+ self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
208
 
209
  def encode(self, texts: list, batch_size=32):
210
  # Using the internal tokenizer to encode the texts and get the total
 
271
  def encode_queries(self, text):
272
  embds = YoudaoEmbed._client.encode([text])
273
  return np.array(embds[0]), num_tokens_from_string(text)
274
+
275
+
276
+ class JinaEmbed(Base):
277
+ def __init__(self, key, model_name="jina-embeddings-v2-base-zh",
278
+ base_url="https://api.jina.ai/v1/embeddings"):
279
+
280
+ self.base_url = "https://api.jina.ai/v1/embeddings"
281
+ self.headers = {
282
+ "Content-Type": "application/json",
283
+ "Authorization": f"Bearer {key}"
284
+ }
285
+ self.model_name = model_name
286
+
287
+ def encode(self, texts: list, batch_size=None):
288
+ texts = [truncate(t, 8196) for t in texts]
289
+ data = {
290
+ "model": self.model_name,
291
+ "input": texts,
292
+ 'encoding_type': 'float'
293
+ }
294
+ res = requests.post(self.base_url, headers=self.headers, json=data)
295
+ return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
296
+
297
+ def encode_queries(self, text):
298
+ embds, cnt = self.encode([text])
299
+ return np.array(embds[0]), cnt
rag/llm/rerank_model.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import re
17
+ import requests
18
+ import torch
19
+ from FlagEmbedding import FlagReranker
20
+ from huggingface_hub import snapshot_download
21
+ import os
22
+ from abc import ABC
23
+ import numpy as np
24
+ from api.utils.file_utils import get_home_cache_dir
25
+ from rag.utils import num_tokens_from_string, truncate
26
+
27
+
28
class Base(ABC):
    """Interface for rerank models: score a list of texts against a query."""

    def __init__(self, key, model_name):
        pass

    def similarity(self, query: str, texts: list):
        """Return (scores, token_count), scores aligned with *texts*.

        Subclasses must override this.
        """
        # The original message said "encode" — a copy-paste from the embedding
        # base class; name the actual method so the error is actionable.
        raise NotImplementedError("Please implement similarity method!")
34
+
35
+
36
class DefaultRerank(Base):
    """Local FlagEmbedding cross-encoder reranker, shared process-wide."""
    _model = None

    def __init__(self, key, model_name, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        if not DefaultRerank._model:
            # Strip the vendor prefix ("BAAI/...") once and reuse the same dir
            # for both the cache-hit path and the download fallback. The
            # original fallback loaded from an un-stripped path, i.e. a
            # different directory than it had just downloaded into.
            model_dir = os.path.join(get_home_cache_dir(),
                                     re.sub(r"^[a-zA-Z]+/", "", model_name))
            try:
                DefaultRerank._model = FlagReranker(
                    model_dir, use_fp16=torch.cuda.is_available())
            except Exception:
                # Local cache miss: fetch the snapshot, then retry from the
                # same directory.
                snapshot_download(repo_id=model_name, local_dir=model_dir,
                                  local_dir_use_symlinks=False)
                DefaultRerank._model = FlagReranker(
                    model_dir, use_fp16=torch.cuda.is_available())
        # Assign the CLASS attribute (the original wrote self._model, so the
        # singleton guard never triggered and every instance reloaded).
        self._model = DefaultRerank._model

    def similarity(self, query: str, texts: list):
        """Score each text against *query*; returns (scores, token_count)."""
        pairs = [(query, truncate(t, 2048)) for t in texts]
        token_count = sum(num_tokens_from_string(t) for _, t in pairs)
        batch_size = 32
        res = []
        for i in range(0, len(pairs), batch_size):
            scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
            # compute_score returns a bare float for a single pair; the
            # original's res.extend(scores) would raise TypeError then.
            if not isinstance(scores, list):
                scores = [scores]
            res.extend(scores)
        return np.array(res), token_count
74
+
75
+
76
class JinaRerank(Base):
    """Reranker client for the Jina AI rerank HTTP API."""

    def __init__(self, key, model_name="jina-reranker-v1-base-en",
                 base_url="https://api.jina.ai/v1/rerank"):
        # Honor the caller-supplied endpoint (the original hard-coded it).
        self.base_url = base_url
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Score *texts* against *query*; returns (scores, token_count).

        Scores are returned in the same order as *texts*.
        """
        texts = [truncate(t, 8196) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts)
        }
        # requests.post returns a Response; the original indexed it directly
        # (res["results"]) which raises TypeError. Decode the JSON body.
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        # The rerank API returns results sorted by relevance, each carrying
        # an "index" into the input list; map scores back to input order so
        # callers can zip them with *texts*.
        rank = np.zeros(len(texts), dtype=float)
        for d in res["results"]:
            rank[d["index"]] = d["relevance_score"]
        return rank, res["usage"]["total_tokens"]
96
+
97
+
98
class YoudaoRerank(DefaultRerank):
    """BCE reranker from Youdao, loaded once and shared process-wide.

    Scoring is inherited from DefaultRerank.similarity.
    """
    _model = None

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        from BCEmbedding import RerankerModel
        if not YoudaoRerank._model:
            print("LOADING BCE...")
            local_path = os.path.join(
                get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name))
            try:
                # Prefer a locally cached copy of the model.
                YoudaoRerank._model = RerankerModel(model_name_or_path=local_path)
            except Exception:
                # Fall back to the InfiniFlow mirror of the repo on the hub.
                mirror_name = model_name.replace("maidalun1020", "InfiniFlow")
                YoudaoRerank._model = RerankerModel(model_name_or_path=mirror_name)
113
+
rag/nlp/query.py CHANGED
@@ -54,7 +54,8 @@ class EsQueryer:
54
  if not self.isChinese(txt):
55
  tks = rag_tokenizer.tokenize(txt).split(" ")
56
  tks_w = self.tw.weights(tks)
57
- q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
 
58
  for i in range(1, len(tks_w)):
59
  q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
60
  if not q:
@@ -136,7 +137,11 @@ class EsQueryer:
136
  from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
137
  import numpy as np
138
  sims = CosineSimilarity([avec], bvecs)
 
 
 
139
 
 
140
  def toDict(tks):
141
  d = {}
142
  if isinstance(tks, str):
@@ -149,9 +154,7 @@ class EsQueryer:
149
 
150
  atks = toDict(atks)
151
  btkss = [toDict(tks) for tks in btkss]
152
- tksim = [self.similarity(atks, btks) for btks in btkss]
153
- return np.array(sims[0]) * vtweight + \
154
- np.array(tksim) * tkweight, tksim, sims[0]
155
 
156
  def similarity(self, qtwt, dtwt):
157
  if isinstance(dtwt, type("")):
 
54
  if not self.isChinese(txt):
55
  tks = rag_tokenizer.tokenize(txt).split(" ")
56
  tks_w = self.tw.weights(tks)
57
+ tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w]
58
+ q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
59
  for i in range(1, len(tks_w)):
60
  q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
61
  if not q:
 
137
  from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
138
  import numpy as np
139
  sims = CosineSimilarity([avec], bvecs)
140
+ tksim = self.token_similarity(atks, btkss)
141
+ return np.array(sims[0]) * vtweight + \
142
+ np.array(tksim) * tkweight, tksim, sims[0]
143
 
144
+ def token_similarity(self, atks, btkss):
145
  def toDict(tks):
146
  d = {}
147
  if isinstance(tks, str):
 
154
 
155
  atks = toDict(atks)
156
  btkss = [toDict(tks) for tks in btkss]
157
+ return [self.similarity(atks, btks) for btks in btkss]
 
 
158
 
159
  def similarity(self, qtwt, dtwt):
160
  if isinstance(dtwt, type("")):
rag/nlp/rag_tokenizer.py CHANGED
@@ -241,11 +241,14 @@ class RagTokenizer:
241
 
242
  return self.score_(res[::-1])
243
 
 
 
 
244
  def tokenize(self, line):
245
  line = self._strQ2B(line).lower()
246
  line = self._tradi2simp(line)
247
  zh_num = len([1 for c in line if is_chinese(c)])
248
- if zh_num < len(line) * 0.2:
249
  return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
250
 
251
  arr = re.split(self.SPLIT_CHAR, line)
@@ -293,7 +296,7 @@ class RagTokenizer:
293
 
294
  i = e + 1
295
 
296
- res = " ".join(res)
297
  if self.DEBUG:
298
  print("[TKS]", self.merge_(res))
299
  return self.merge_(res)
@@ -336,7 +339,7 @@ class RagTokenizer:
336
 
337
  res.append(stk)
338
 
339
- return " ".join(res)
340
 
341
 
342
  def is_chinese(s):
 
241
 
242
  return self.score_(res[::-1])
243
 
244
def english_normalize_(self, tks):
    """Lemmatize then stem purely alphabetic tokens; leave others untouched."""
    normalized = []
    for tk in tks:
        # Only tokens made of ASCII letters, '_' or '-' get English treatment.
        if re.match(r"[a-zA-Z_-]+$", tk):
            tk = self.stemmer.stem(self.lemmatizer.lemmatize(tk))
        normalized.append(tk)
    return normalized
246
+
247
  def tokenize(self, line):
248
  line = self._strQ2B(line).lower()
249
  line = self._tradi2simp(line)
250
  zh_num = len([1 for c in line if is_chinese(c)])
251
+ if zh_num == 0:
252
  return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
253
 
254
  arr = re.split(self.SPLIT_CHAR, line)
 
296
 
297
  i = e + 1
298
 
299
+ res = " ".join(self.english_normalize_(res))
300
  if self.DEBUG:
301
  print("[TKS]", self.merge_(res))
302
  return self.merge_(res)
 
339
 
340
  res.append(stk)
341
 
342
+ return " ".join(self.english_normalize_(res))
343
 
344
 
345
  def is_chinese(s):
rag/nlp/search.py CHANGED
@@ -71,8 +71,8 @@ class Dealer:
71
 
72
  s = Search()
73
  pg = int(req.get("page", 1)) - 1
74
- ps = int(req.get("size", 1000))
75
  topk = int(req.get("topk", 1024))
 
76
  src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
77
  "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
78
  "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
@@ -311,6 +311,26 @@ class Dealer:
311
  ins_tw, tkweight, vtweight)
312
  return sim, tksim, vtsim
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
315
  return self.qryr.hybrid_similarity(ans_embd,
316
  ins_embd,
@@ -318,17 +338,22 @@ class Dealer:
318
  rag_tokenizer.tokenize(inst).split(" "))
319
 
320
  def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
321
- vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
322
  ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
323
  if not question:
324
  return ranks
325
  req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
326
  "question": question, "vector": True, "topk": top,
327
- "similarity": similarity_threshold}
 
328
  sres = self.search(req, index_name(tenant_id), embd_mdl)
329
 
330
- sim, tsim, vsim = self.rerank(
331
- sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
 
 
 
 
332
  idx = np.argsort(sim * -1)
333
 
334
  dim = len(sres.query_vector)
 
71
 
72
  s = Search()
73
  pg = int(req.get("page", 1)) - 1
 
74
  topk = int(req.get("topk", 1024))
75
+ ps = int(req.get("size", topk))
76
  src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
77
  "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
78
  "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
 
311
  ins_tw, tkweight, vtweight)
312
  return sim, tksim, vtsim
313
 
314
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
                    vtweight=0.7, cfield="content_ltks"):
    """Blend keyword-token similarity with an external rerank model's scores.

    Returns (combined_sim, token_sim, model_sim); each is aligned with
    sres.ids. combined_sim = tkweight * token_sim + vtweight * model_sim.
    """
    _, keywords = self.qryr.question(query)

    ins_tw = []
    for doc_id in sres.ids:
        fields = sres.field[doc_id]
        # important_kwd may come back as a bare string; normalize it in place.
        if isinstance(fields.get("important_kwd", []), str):
            fields["important_kwd"] = [fields["important_kwd"]]
        tokens = fields[cfield].split(" ")
        tokens += [t for t in fields.get("title_tks", "").split(" ") if t]
        tokens += fields.get("important_kwd", [])
        ins_tw.append(tokens)

    tksim = self.qryr.token_similarity(keywords, ins_tw)
    vtsim, _ = rerank_mdl.similarity(
        " ".join(keywords), [rmSpace(" ".join(tks)) for tks in ins_tw])

    return tkweight * np.array(tksim) + vtweight * vtsim, tksim, vtsim
333
+
334
  def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
335
  return self.qryr.hybrid_similarity(ans_embd,
336
  ins_embd,
 
338
  rag_tokenizer.tokenize(inst).split(" "))
339
 
340
  def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
341
+ vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
342
  ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
343
  if not question:
344
  return ranks
345
  req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
346
  "question": question, "vector": True, "topk": top,
347
+ "similarity": similarity_threshold,
348
+ "available_int": 1}
349
  sres = self.search(req, index_name(tenant_id), embd_mdl)
350
 
351
+ if rerank_mdl:
352
+ sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
353
+ sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
354
+ else:
355
+ sim, tsim, vsim = self.rerank(
356
+ sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
357
  idx = np.argsort(sim * -1)
358
 
359
  dim = len(sres.query_vector)