KevinHuSh
commited on
Commit
·
c037a22
1
Parent(s):
459ac83
add rerank model (#969)
Browse files### What problem does this PR solve?
feat: add rerank models to the project #724 #162
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/chunk_app.py +9 -2
- api/apps/dialog_app.py +5 -0
- api/apps/llm_app.py +13 -3
- api/apps/user_app.py +5 -3
- api/db/__init__.py +1 -0
- api/db/db_models.py +33 -6
- api/db/init_data.py +97 -1
- api/db/services/dialog_service.py +5 -2
- api/db/services/llm_service.py +31 -7
- api/db/services/user_service.py +1 -0
- api/settings.py +15 -1
- rag/llm/__init__.py +10 -3
- rag/llm/embedding_model.py +56 -24
- rag/llm/rerank_model.py +113 -0
- rag/nlp/query.py +7 -4
- rag/nlp/rag_tokenizer.py +6 -3
- rag/nlp/search.py +30 -5
api/apps/chunk_app.py
CHANGED
@@ -257,8 +257,15 @@ def retrieval_test():
|
|
257 |
|
258 |
embd_mdl = TenantLLMService.model_instance(
|
259 |
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
for c in ranks["chunks"]:
|
263 |
if "vector" in c:
|
264 |
del c["vector"]
|
|
|
257 |
|
258 |
embd_mdl = TenantLLMService.model_instance(
|
259 |
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
260 |
+
|
261 |
+
rerank_mdl = None
|
262 |
+
if req.get("rerank_id"):
|
263 |
+
rerank_mdl = TenantLLMService.model_instance(
|
264 |
+
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
265 |
+
|
266 |
+
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
|
267 |
+
similarity_threshold, vector_similarity_weight, top,
|
268 |
+
doc_ids, rerank_mdl=rerank_mdl)
|
269 |
for c in ranks["chunks"]:
|
270 |
if "vector" in c:
|
271 |
del c["vector"]
|
api/apps/dialog_app.py
CHANGED
@@ -33,6 +33,9 @@ def set_dialog():
|
|
33 |
name = req.get("name", "New Dialog")
|
34 |
description = req.get("description", "A helpful Dialog")
|
35 |
top_n = req.get("top_n", 6)
|
|
|
|
|
|
|
36 |
similarity_threshold = req.get("similarity_threshold", 0.1)
|
37 |
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
38 |
llm_setting = req.get("llm_setting", {})
|
@@ -83,6 +86,8 @@ def set_dialog():
|
|
83 |
"llm_setting": llm_setting,
|
84 |
"prompt_config": prompt_config,
|
85 |
"top_n": top_n,
|
|
|
|
|
86 |
"similarity_threshold": similarity_threshold,
|
87 |
"vector_similarity_weight": vector_similarity_weight
|
88 |
}
|
|
|
33 |
name = req.get("name", "New Dialog")
|
34 |
description = req.get("description", "A helpful Dialog")
|
35 |
top_n = req.get("top_n", 6)
|
36 |
+
top_k = req.get("top_k", 1024)
|
37 |
+
rerank_id = req.get("rerank_id", "")
|
38 |
+
if not rerank_id: req["rerank_id"] = ""
|
39 |
similarity_threshold = req.get("similarity_threshold", 0.1)
|
40 |
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
41 |
llm_setting = req.get("llm_setting", {})
|
|
|
86 |
"llm_setting": llm_setting,
|
87 |
"prompt_config": prompt_config,
|
88 |
"top_n": top_n,
|
89 |
+
"top_k": top_k,
|
90 |
+
"rerank_id": rerank_id,
|
91 |
"similarity_threshold": similarity_threshold,
|
92 |
"vector_similarity_weight": vector_similarity_weight
|
93 |
}
|
api/apps/llm_app.py
CHANGED
@@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
|
20 |
from api.db import StatusEnum, LLMType
|
21 |
from api.db.db_models import TenantLLM
|
22 |
from api.utils.api_utils import get_json_result
|
23 |
-
from rag.llm import EmbeddingModel, ChatModel
|
24 |
|
25 |
|
26 |
@manager.route('/factories', methods=['GET'])
|
@@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
|
|
28 |
def factories():
|
29 |
try:
|
30 |
fac = LLMFactoriesService.get_all()
|
31 |
-
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
|
32 |
except Exception as e:
|
33 |
return server_error_response(e)
|
34 |
|
@@ -64,6 +64,16 @@ def set_api_key():
|
|
64 |
except Exception as e:
|
65 |
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
66 |
e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
if msg:
|
69 |
return get_data_error_result(retmsg=msg)
|
@@ -199,7 +209,7 @@ def list_app():
|
|
199 |
llms = [m.to_dict()
|
200 |
for m in llms if m.status == StatusEnum.VALID.value]
|
201 |
for m in llms:
|
202 |
-
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
|
203 |
|
204 |
llm_set = set([m["llm_name"] for m in llms])
|
205 |
for o in objs:
|
|
|
20 |
from api.db import StatusEnum, LLMType
|
21 |
from api.db.db_models import TenantLLM
|
22 |
from api.utils.api_utils import get_json_result
|
23 |
+
from rag.llm import EmbeddingModel, ChatModel, RerankModel
|
24 |
|
25 |
|
26 |
@manager.route('/factories', methods=['GET'])
|
|
|
28 |
def factories():
|
29 |
try:
|
30 |
fac = LLMFactoriesService.get_all()
|
31 |
+
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
|
32 |
except Exception as e:
|
33 |
return server_error_response(e)
|
34 |
|
|
|
64 |
except Exception as e:
|
65 |
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
66 |
e)
|
67 |
+
elif llm.model_type == LLMType.RERANK:
|
68 |
+
mdl = RerankModel[factory](
|
69 |
+
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
70 |
+
try:
|
71 |
+
m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
72 |
+
if len(arr[0]) == 0 or tc == 0:
|
73 |
+
raise Exception("Fail")
|
74 |
+
except Exception as e:
|
75 |
+
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
76 |
+
e)
|
77 |
|
78 |
if msg:
|
79 |
return get_data_error_result(retmsg=msg)
|
|
|
209 |
llms = [m.to_dict()
|
210 |
for m in llms if m.status == StatusEnum.VALID.value]
|
211 |
for m in llms:
|
212 |
+
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]
|
213 |
|
214 |
llm_set = set([m["llm_name"] for m in llms])
|
215 |
for o in objs:
|
api/apps/user_app.py
CHANGED
@@ -26,8 +26,9 @@ from api.db.services.llm_service import TenantLLMService, LLMService
|
|
26 |
from api.utils.api_utils import server_error_response, validate_request
|
27 |
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
|
28 |
from api.db import UserTenantRole, LLMType, FileType
|
29 |
-
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS,
|
30 |
-
|
|
|
31 |
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
32 |
from api.db.services.file_service import FileService
|
33 |
from api.settings import stat_logger
|
@@ -288,7 +289,8 @@ def user_register(user_id, user):
|
|
288 |
"embd_id": EMBEDDING_MDL,
|
289 |
"asr_id": ASR_MDL,
|
290 |
"parser_ids": PARSERS,
|
291 |
-
"img2txt_id": IMAGE2TEXT_MDL
|
|
|
292 |
}
|
293 |
usr_tenant = {
|
294 |
"tenant_id": user_id,
|
|
|
26 |
from api.utils.api_utils import server_error_response, validate_request
|
27 |
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
|
28 |
from api.db import UserTenantRole, LLMType, FileType
|
29 |
+
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
|
30 |
+
API_KEY, \
|
31 |
+
LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
|
32 |
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
33 |
from api.db.services.file_service import FileService
|
34 |
from api.settings import stat_logger
|
|
|
289 |
"embd_id": EMBEDDING_MDL,
|
290 |
"asr_id": ASR_MDL,
|
291 |
"parser_ids": PARSERS,
|
292 |
+
"img2txt_id": IMAGE2TEXT_MDL,
|
293 |
+
"rerank_id": RERANK_MDL
|
294 |
}
|
295 |
usr_tenant = {
|
296 |
"tenant_id": user_id,
|
api/db/__init__.py
CHANGED
@@ -54,6 +54,7 @@ class LLMType(StrEnum):
|
|
54 |
EMBEDDING = 'embedding'
|
55 |
SPEECH2TEXT = 'speech2text'
|
56 |
IMAGE2TEXT = 'image2text'
|
|
|
57 |
|
58 |
|
59 |
class ChatStyle(StrEnum):
|
|
|
54 |
EMBEDDING = 'embedding'
|
55 |
SPEECH2TEXT = 'speech2text'
|
56 |
IMAGE2TEXT = 'image2text'
|
57 |
+
RERANK = 'rerank'
|
58 |
|
59 |
|
60 |
class ChatStyle(StrEnum):
|
api/db/db_models.py
CHANGED
@@ -437,6 +437,10 @@ class Tenant(DataBaseModel):
|
|
437 |
max_length=128,
|
438 |
null=False,
|
439 |
help_text="default image to text model ID")
|
|
|
|
|
|
|
|
|
440 |
parser_ids = CharField(
|
441 |
max_length=256,
|
442 |
null=False,
|
@@ -771,11 +775,16 @@ class Dialog(DataBaseModel):
|
|
771 |
similarity_threshold = FloatField(default=0.2)
|
772 |
vector_similarity_weight = FloatField(default=0.3)
|
773 |
top_n = IntegerField(default=6)
|
|
|
774 |
do_refer = CharField(
|
775 |
max_length=1,
|
776 |
null=False,
|
777 |
help_text="it needs to insert reference index into answer or not",
|
778 |
default="1")
|
|
|
|
|
|
|
|
|
779 |
|
780 |
kb_ids = JSONField(null=False, default=[])
|
781 |
status = CharField(
|
@@ -825,11 +834,29 @@ class API4Conversation(DataBaseModel):
|
|
825 |
|
826 |
|
827 |
def migrate_db():
|
828 |
-
try:
|
829 |
with DB.transaction():
|
830 |
migrator = MySQLMigrator(DB)
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
max_length=128,
|
438 |
null=False,
|
439 |
help_text="default image to text model ID")
|
440 |
+
rerank_id = CharField(
|
441 |
+
max_length=128,
|
442 |
+
null=False,
|
443 |
+
help_text="default rerank model ID")
|
444 |
parser_ids = CharField(
|
445 |
max_length=256,
|
446 |
null=False,
|
|
|
775 |
similarity_threshold = FloatField(default=0.2)
|
776 |
vector_similarity_weight = FloatField(default=0.3)
|
777 |
top_n = IntegerField(default=6)
|
778 |
+
top_k = IntegerField(default=1024)
|
779 |
do_refer = CharField(
|
780 |
max_length=1,
|
781 |
null=False,
|
782 |
help_text="it needs to insert reference index into answer or not",
|
783 |
default="1")
|
784 |
+
rerank_id = CharField(
|
785 |
+
max_length=128,
|
786 |
+
null=False,
|
787 |
+
help_text="default rerank model ID")
|
788 |
|
789 |
kb_ids = JSONField(null=False, default=[])
|
790 |
status = CharField(
|
|
|
834 |
|
835 |
|
836 |
def migrate_db():
|
|
|
837 |
with DB.transaction():
|
838 |
migrator = MySQLMigrator(DB)
|
839 |
+
try:
|
840 |
+
migrate(
|
841 |
+
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
|
842 |
+
)
|
843 |
+
except Exception as e:
|
844 |
+
pass
|
845 |
+
try:
|
846 |
+
migrate(
|
847 |
+
migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
|
848 |
+
)
|
849 |
+
except Exception as e:
|
850 |
+
pass
|
851 |
+
try:
|
852 |
+
migrate(
|
853 |
+
migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
|
854 |
+
)
|
855 |
+
except Exception as e:
|
856 |
+
pass
|
857 |
+
try:
|
858 |
+
migrate(
|
859 |
+
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
|
860 |
+
)
|
861 |
+
except Exception as e:
|
862 |
+
pass
|
api/db/init_data.py
CHANGED
@@ -142,7 +142,17 @@ factory_infos = [{
|
|
142 |
"logo": "",
|
143 |
"tags": "LLM,TEXT EMBEDDING",
|
144 |
"status": "1",
|
145 |
-
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
# {
|
147 |
# "name": "文心一言",
|
148 |
# "logo": "",
|
@@ -367,6 +377,13 @@ def init_llm_factory():
|
|
367 |
"max_tokens": 512,
|
368 |
"model_type": LLMType.EMBEDDING.value
|
369 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
# ------------------------ DeepSeek -----------------------
|
371 |
{
|
372 |
"fid": factory_infos[8]["name"],
|
@@ -440,6 +457,85 @@ def init_llm_factory():
|
|
440 |
"max_tokens": 512,
|
441 |
"model_type": LLMType.EMBEDDING.value
|
442 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
]
|
444 |
for info in factory_infos:
|
445 |
try:
|
|
|
142 |
"logo": "",
|
143 |
"tags": "LLM,TEXT EMBEDDING",
|
144 |
"status": "1",
|
145 |
+
},{
|
146 |
+
"name": "Jina",
|
147 |
+
"logo": "",
|
148 |
+
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
149 |
+
"status": "1",
|
150 |
+
},{
|
151 |
+
"name": "BAAI",
|
152 |
+
"logo": "",
|
153 |
+
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
154 |
+
"status": "1",
|
155 |
+
}
|
156 |
# {
|
157 |
# "name": "文心一言",
|
158 |
# "logo": "",
|
|
|
377 |
"max_tokens": 512,
|
378 |
"model_type": LLMType.EMBEDDING.value
|
379 |
},
|
380 |
+
{
|
381 |
+
"fid": factory_infos[7]["name"],
|
382 |
+
"llm_name": "maidalun1020/bce-reranker-base_v1",
|
383 |
+
"tags": "RE-RANK, 8K",
|
384 |
+
"max_tokens": 8196,
|
385 |
+
"model_type": LLMType.RERANK.value
|
386 |
+
},
|
387 |
# ------------------------ DeepSeek -----------------------
|
388 |
{
|
389 |
"fid": factory_infos[8]["name"],
|
|
|
457 |
"max_tokens": 512,
|
458 |
"model_type": LLMType.EMBEDDING.value
|
459 |
},
|
460 |
+
# ------------------------ Jina -----------------------
|
461 |
+
{
|
462 |
+
"fid": factory_infos[11]["name"],
|
463 |
+
"llm_name": "jina-reranker-v1-base-en",
|
464 |
+
"tags": "RE-RANK,8k",
|
465 |
+
"max_tokens": 8196,
|
466 |
+
"model_type": LLMType.RERANK.value
|
467 |
+
},
|
468 |
+
{
|
469 |
+
"fid": factory_infos[11]["name"],
|
470 |
+
"llm_name": "jina-reranker-v1-turbo-en",
|
471 |
+
"tags": "RE-RANK,8k",
|
472 |
+
"max_tokens": 8196,
|
473 |
+
"model_type": LLMType.RERANK.value
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"fid": factory_infos[11]["name"],
|
477 |
+
"llm_name": "jina-reranker-v1-tiny-en",
|
478 |
+
"tags": "RE-RANK,8k",
|
479 |
+
"max_tokens": 8196,
|
480 |
+
"model_type": LLMType.RERANK.value
|
481 |
+
},
|
482 |
+
{
|
483 |
+
"fid": factory_infos[11]["name"],
|
484 |
+
"llm_name": "jina-colbert-v1-en",
|
485 |
+
"tags": "RE-RANK,8k",
|
486 |
+
"max_tokens": 8196,
|
487 |
+
"model_type": LLMType.RERANK.value
|
488 |
+
},
|
489 |
+
{
|
490 |
+
"fid": factory_infos[11]["name"],
|
491 |
+
"llm_name": "jina-embeddings-v2-base-en",
|
492 |
+
"tags": "TEXT EMBEDDING",
|
493 |
+
"max_tokens": 8196,
|
494 |
+
"model_type": LLMType.EMBEDDING.value
|
495 |
+
},
|
496 |
+
{
|
497 |
+
"fid": factory_infos[11]["name"],
|
498 |
+
"llm_name": "jina-embeddings-v2-base-de",
|
499 |
+
"tags": "TEXT EMBEDDING",
|
500 |
+
"max_tokens": 8196,
|
501 |
+
"model_type": LLMType.EMBEDDING.value
|
502 |
+
},
|
503 |
+
{
|
504 |
+
"fid": factory_infos[11]["name"],
|
505 |
+
"llm_name": "jina-embeddings-v2-base-es",
|
506 |
+
"tags": "TEXT EMBEDDING",
|
507 |
+
"max_tokens": 8196,
|
508 |
+
"model_type": LLMType.EMBEDDING.value
|
509 |
+
},
|
510 |
+
{
|
511 |
+
"fid": factory_infos[11]["name"],
|
512 |
+
"llm_name": "jina-embeddings-v2-base-code",
|
513 |
+
"tags": "TEXT EMBEDDING",
|
514 |
+
"max_tokens": 8196,
|
515 |
+
"model_type": LLMType.EMBEDDING.value
|
516 |
+
},
|
517 |
+
{
|
518 |
+
"fid": factory_infos[11]["name"],
|
519 |
+
"llm_name": "jina-embeddings-v2-base-zh",
|
520 |
+
"tags": "TEXT EMBEDDING",
|
521 |
+
"max_tokens": 8196,
|
522 |
+
"model_type": LLMType.EMBEDDING.value
|
523 |
+
},
|
524 |
+
# ------------------------ BAAI -----------------------
|
525 |
+
{
|
526 |
+
"fid": factory_infos[12]["name"],
|
527 |
+
"llm_name": "BAAI/bge-large-zh-v1.5",
|
528 |
+
"tags": "TEXT EMBEDDING,",
|
529 |
+
"max_tokens": 1024,
|
530 |
+
"model_type": LLMType.EMBEDDING.value
|
531 |
+
},
|
532 |
+
{
|
533 |
+
"fid": factory_infos[12]["name"],
|
534 |
+
"llm_name": "BAAI/bge-reranker-v2-m3",
|
535 |
+
"tags": "LLM,CHAT,",
|
536 |
+
"max_tokens": 16385,
|
537 |
+
"model_type": LLMType.RERANK.value
|
538 |
+
},
|
539 |
]
|
540 |
for info in factory_infos:
|
541 |
try:
|
api/db/services/dialog_service.py
CHANGED
@@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
115 |
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
116 |
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
117 |
else:
|
|
|
|
|
|
|
118 |
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
119 |
dialog.similarity_threshold,
|
120 |
dialog.vector_similarity_weight,
|
121 |
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
122 |
-
top=1024, aggs=False)
|
123 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
124 |
chat_logger.info(
|
125 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
@@ -130,7 +133,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
130 |
|
131 |
kwargs["knowledge"] = "\n".join(knowledges)
|
132 |
gen_conf = dialog.llm_setting
|
133 |
-
|
134 |
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
135 |
msg.extend([{"role": m["role"], "content": m["content"]}
|
136 |
for m in messages if m["role"] != "system"])
|
|
|
115 |
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
116 |
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
117 |
else:
|
118 |
+
rerank_mdl = None
|
119 |
+
if dialog.rerank_id:
|
120 |
+
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
121 |
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
122 |
dialog.similarity_threshold,
|
123 |
dialog.vector_similarity_weight,
|
124 |
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
125 |
+
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
126 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
127 |
chat_logger.info(
|
128 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
|
|
133 |
|
134 |
kwargs["knowledge"] = "\n".join(knowledges)
|
135 |
gen_conf = dialog.llm_setting
|
136 |
+
|
137 |
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
138 |
msg.extend([{"role": m["role"], "content": m["content"]}
|
139 |
for m in messages if m["role"] != "system"])
|
api/db/services/llm_service.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
#
|
16 |
from api.db.services.user_service import TenantService
|
17 |
from api.settings import database_logger
|
18 |
-
from rag.llm import EmbeddingModel, CvModel, ChatModel
|
19 |
from api.db import LLMType
|
20 |
from api.db.db_models import DB, UserTenant
|
21 |
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
@@ -73,21 +73,25 @@ class TenantLLMService(CommonService):
|
|
73 |
mdlnm = tenant.img2txt_id
|
74 |
elif llm_type == LLMType.CHAT.value:
|
75 |
mdlnm = tenant.llm_id if not llm_name else llm_name
|
|
|
|
|
76 |
else:
|
77 |
assert False, "LLM type error"
|
78 |
|
79 |
model_config = cls.get_api_key(tenant_id, mdlnm)
|
80 |
if model_config: model_config = model_config.to_dict()
|
81 |
if not model_config:
|
82 |
-
if llm_type
|
83 |
llm = LLMService.query(llm_name=llm_name)
|
84 |
-
if llm and llm[0].fid in ["Youdao", "FastEmbed", "
|
85 |
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
86 |
if not model_config:
|
87 |
if llm_name == "flag-embedding":
|
88 |
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
89 |
"llm_name": llm_name, "api_base": ""}
|
90 |
else:
|
|
|
|
|
91 |
raise LookupError("Model({}) not authorized".format(mdlnm))
|
92 |
|
93 |
if llm_type == LLMType.EMBEDDING.value:
|
@@ -96,6 +100,12 @@ class TenantLLMService(CommonService):
|
|
96 |
return EmbeddingModel[model_config["llm_factory"]](
|
97 |
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
if llm_type == LLMType.IMAGE2TEXT.value:
|
100 |
if model_config["llm_factory"] not in CvModel:
|
101 |
return
|
@@ -125,14 +135,20 @@ class TenantLLMService(CommonService):
|
|
125 |
mdlnm = tenant.img2txt_id
|
126 |
elif llm_type == LLMType.CHAT.value:
|
127 |
mdlnm = tenant.llm_id if not llm_name else llm_name
|
|
|
|
|
128 |
else:
|
129 |
assert False, "LLM type error"
|
130 |
|
131 |
num = 0
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
136 |
return num
|
137 |
|
138 |
@classmethod
|
@@ -176,6 +192,14 @@ class LLMBundle(object):
|
|
176 |
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
177 |
return emd, used_tokens
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
def describe(self, image, max_tokens=300):
|
180 |
txt, used_tokens = self.mdl.describe(image, max_tokens)
|
181 |
if not TenantLLMService.increase_usage(
|
|
|
15 |
#
|
16 |
from api.db.services.user_service import TenantService
|
17 |
from api.settings import database_logger
|
18 |
+
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel
|
19 |
from api.db import LLMType
|
20 |
from api.db.db_models import DB, UserTenant
|
21 |
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
|
|
73 |
mdlnm = tenant.img2txt_id
|
74 |
elif llm_type == LLMType.CHAT.value:
|
75 |
mdlnm = tenant.llm_id if not llm_name else llm_name
|
76 |
+
elif llm_type == LLMType.RERANK:
|
77 |
+
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
78 |
else:
|
79 |
assert False, "LLM type error"
|
80 |
|
81 |
model_config = cls.get_api_key(tenant_id, mdlnm)
|
82 |
if model_config: model_config = model_config.to_dict()
|
83 |
if not model_config:
|
84 |
+
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
85 |
llm = LLMService.query(llm_name=llm_name)
|
86 |
+
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
87 |
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
88 |
if not model_config:
|
89 |
if llm_name == "flag-embedding":
|
90 |
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
91 |
"llm_name": llm_name, "api_base": ""}
|
92 |
else:
|
93 |
+
if not mdlnm:
|
94 |
+
raise LookupError(f"Type of {llm_type} model is not set.")
|
95 |
raise LookupError("Model({}) not authorized".format(mdlnm))
|
96 |
|
97 |
if llm_type == LLMType.EMBEDDING.value:
|
|
|
100 |
return EmbeddingModel[model_config["llm_factory"]](
|
101 |
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
102 |
|
103 |
+
if llm_type == LLMType.RERANK:
|
104 |
+
if model_config["llm_factory"] not in RerankModel:
|
105 |
+
return
|
106 |
+
return RerankModel[model_config["llm_factory"]](
|
107 |
+
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
108 |
+
|
109 |
if llm_type == LLMType.IMAGE2TEXT.value:
|
110 |
if model_config["llm_factory"] not in CvModel:
|
111 |
return
|
|
|
135 |
mdlnm = tenant.img2txt_id
|
136 |
elif llm_type == LLMType.CHAT.value:
|
137 |
mdlnm = tenant.llm_id if not llm_name else llm_name
|
138 |
+
elif llm_type == LLMType.RERANK:
|
139 |
+
mdlnm = tenant.llm_id if not llm_name else llm_name
|
140 |
else:
|
141 |
assert False, "LLM type error"
|
142 |
|
143 |
num = 0
|
144 |
+
try:
|
145 |
+
for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
|
146 |
+
num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
|
147 |
+
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
|
148 |
+
.execute()
|
149 |
+
except Exception as e:
|
150 |
+
print(e)
|
151 |
+
pass
|
152 |
return num
|
153 |
|
154 |
@classmethod
|
|
|
192 |
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
193 |
return emd, used_tokens
|
194 |
|
195 |
+
def similarity(self, query: str, texts: list):
|
196 |
+
sim, used_tokens = self.mdl.similarity(query, texts)
|
197 |
+
if not TenantLLMService.increase_usage(
|
198 |
+
self.tenant_id, self.llm_type, used_tokens):
|
199 |
+
database_logger.error(
|
200 |
+
"Can't update token usage for {}/RERANK".format(self.tenant_id))
|
201 |
+
return sim, used_tokens
|
202 |
+
|
203 |
def describe(self, image, max_tokens=300):
|
204 |
txt, used_tokens = self.mdl.describe(image, max_tokens)
|
205 |
if not TenantLLMService.increase_usage(
|
api/db/services/user_service.py
CHANGED
@@ -93,6 +93,7 @@ class TenantService(CommonService):
|
|
93 |
cls.model.name,
|
94 |
cls.model.llm_id,
|
95 |
cls.model.embd_id,
|
|
|
96 |
cls.model.asr_id,
|
97 |
cls.model.img2txt_id,
|
98 |
cls.model.parser_ids,
|
|
|
93 |
cls.model.name,
|
94 |
cls.model.llm_id,
|
95 |
cls.model.embd_id,
|
96 |
+
cls.model.rerank_id,
|
97 |
cls.model.asr_id,
|
98 |
cls.model.img2txt_id,
|
99 |
cls.model.parser_ids,
|
api/settings.py
CHANGED
@@ -89,9 +89,22 @@ default_llm = {
|
|
89 |
},
|
90 |
"DeepSeek": {
|
91 |
"chat_model": "deepseek-chat",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
93 |
"image2text_model": "",
|
94 |
"asr_model": "",
|
|
|
95 |
}
|
96 |
}
|
97 |
LLM = get_base_config("user_default_llm", {})
|
@@ -104,7 +117,8 @@ if LLM_FACTORY not in default_llm:
|
|
104 |
f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
|
105 |
LLM_FACTORY = "Tongyi-Qianwen"
|
106 |
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
|
107 |
-
EMBEDDING_MDL = default_llm[
|
|
|
108 |
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
109 |
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
110 |
|
|
|
89 |
},
|
90 |
"DeepSeek": {
|
91 |
"chat_model": "deepseek-chat",
|
92 |
+
"embedding_model": "",
|
93 |
+
"image2text_model": "",
|
94 |
+
"asr_model": "",
|
95 |
+
},
|
96 |
+
"VolcEngine": {
|
97 |
+
"chat_model": "",
|
98 |
+
"embedding_model": "",
|
99 |
+
"image2text_model": "",
|
100 |
+
"asr_model": "",
|
101 |
+
},
|
102 |
+
"BAAI": {
|
103 |
+
"chat_model": "",
|
104 |
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
105 |
"image2text_model": "",
|
106 |
"asr_model": "",
|
107 |
+
"rerank_model": "BAAI/bge-reranker-v2-m3",
|
108 |
}
|
109 |
}
|
110 |
LLM = get_base_config("user_default_llm", {})
|
|
|
117 |
f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
|
118 |
LLM_FACTORY = "Tongyi-Qianwen"
|
119 |
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
|
120 |
+
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
|
121 |
+
RERANK_MDL = default_llm["BAAI"]["rerank_model"]
|
122 |
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
123 |
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
124 |
|
rag/llm/__init__.py
CHANGED
@@ -16,18 +16,19 @@
|
|
16 |
from .embedding_model import *
|
17 |
from .chat_model import *
|
18 |
from .cv_model import *
|
|
|
19 |
|
20 |
|
21 |
EmbeddingModel = {
|
22 |
"Ollama": OllamaEmbed,
|
23 |
"OpenAI": OpenAIEmbed,
|
24 |
"Xinference": XinferenceEmbed,
|
25 |
-
"Tongyi-Qianwen": DefaultEmbedding
|
26 |
"ZHIPU-AI": ZhipuEmbed,
|
27 |
"FastEmbed": FastEmbed,
|
28 |
"Youdao": YoudaoEmbed,
|
29 |
-
"
|
30 |
-
"
|
31 |
}
|
32 |
|
33 |
|
@@ -52,3 +53,9 @@ ChatModel = {
|
|
52 |
"BaiChuan": BaiChuanChat
|
53 |
}
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from .embedding_model import *
|
17 |
from .chat_model import *
|
18 |
from .cv_model import *
|
19 |
+
from .rerank_model import *
|
20 |
|
21 |
|
22 |
EmbeddingModel = {
|
23 |
"Ollama": OllamaEmbed,
|
24 |
"OpenAI": OpenAIEmbed,
|
25 |
"Xinference": XinferenceEmbed,
|
26 |
+
"Tongyi-Qianwen": DefaultEmbedding,#QWenEmbed,
|
27 |
"ZHIPU-AI": ZhipuEmbed,
|
28 |
"FastEmbed": FastEmbed,
|
29 |
"Youdao": YoudaoEmbed,
|
30 |
+
"BaiChuan": BaiChuanEmbed,
|
31 |
+
"BAAI": DefaultEmbedding
|
32 |
}
|
33 |
|
34 |
|
|
|
53 |
"BaiChuan": BaiChuanChat
|
54 |
}
|
55 |
|
56 |
+
|
57 |
+
RerankModel = {
|
58 |
+
"BAAI": DefaultRerank,
|
59 |
+
"Jina": JinaRerank,
|
60 |
+
"Youdao": YoudaoRerank,
|
61 |
+
}
|
rag/llm/embedding_model.py
CHANGED
@@ -13,8 +13,10 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
16 |
from typing import Optional
|
17 |
|
|
|
18 |
from huggingface_hub import snapshot_download
|
19 |
from zhipuai import ZhipuAI
|
20 |
import os
|
@@ -26,21 +28,9 @@ from FlagEmbedding import FlagModel
|
|
26 |
import torch
|
27 |
import numpy as np
|
28 |
|
29 |
-
from api.utils.file_utils import
|
30 |
from rag.utils import num_tokens_from_string, truncate
|
31 |
|
32 |
-
try:
|
33 |
-
flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
|
34 |
-
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
35 |
-
use_fp16=torch.cuda.is_available())
|
36 |
-
except Exception as e:
|
37 |
-
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
38 |
-
local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
|
39 |
-
local_dir_use_symlinks=False)
|
40 |
-
flag_model = FlagModel(model_dir,
|
41 |
-
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
42 |
-
use_fp16=torch.cuda.is_available())
|
43 |
-
|
44 |
|
45 |
class Base(ABC):
|
46 |
def __init__(self, key, model_name):
|
@@ -54,7 +44,9 @@ class Base(ABC):
|
|
54 |
|
55 |
|
56 |
class DefaultEmbedding(Base):
|
57 |
-
|
|
|
|
|
58 |
"""
|
59 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
60 |
|
@@ -66,7 +58,18 @@ class DefaultEmbedding(Base):
|
|
66 |
^_-
|
67 |
|
68 |
"""
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
def encode(self, texts: list, batch_size=32):
|
72 |
texts = [truncate(t, 2048) for t in texts]
|
@@ -75,12 +78,12 @@ class DefaultEmbedding(Base):
|
|
75 |
token_count += num_tokens_from_string(t)
|
76 |
res = []
|
77 |
for i in range(0, len(texts), batch_size):
|
78 |
-
res.extend(self.
|
79 |
return np.array(res), token_count
|
80 |
|
81 |
def encode_queries(self, text: str):
|
82 |
token_count = num_tokens_from_string(text)
|
83 |
-
return self.
|
84 |
|
85 |
|
86 |
class OpenAIEmbed(Base):
|
@@ -189,16 +192,19 @@ class OllamaEmbed(Base):
|
|
189 |
|
190 |
|
191 |
class FastEmbed(Base):
|
|
|
|
|
192 |
def __init__(
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
):
|
200 |
from fastembed import TextEmbedding
|
201 |
-
|
|
|
202 |
|
203 |
def encode(self, texts: list, batch_size=32):
|
204 |
# Using the internal tokenizer to encode the texts and get the total
|
@@ -265,3 +271,29 @@ class YoudaoEmbed(Base):
|
|
265 |
def encode_queries(self, text):
|
266 |
embds = YoudaoEmbed._client.encode([text])
|
267 |
return np.array(embds[0]), num_tokens_from_string(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
import re
|
17 |
from typing import Optional
|
18 |
|
19 |
+
import requests
|
20 |
from huggingface_hub import snapshot_download
|
21 |
from zhipuai import ZhipuAI
|
22 |
import os
|
|
|
28 |
import torch
|
29 |
import numpy as np
|
30 |
|
31 |
+
from api.utils.file_utils import get_home_cache_dir
|
32 |
from rag.utils import num_tokens_from_string, truncate
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
class Base(ABC):
|
36 |
def __init__(self, key, model_name):
|
|
|
44 |
|
45 |
|
46 |
class DefaultEmbedding(Base):
|
47 |
+
_model = None
|
48 |
+
|
49 |
+
def __init__(self, key, model_name, **kwargs):
|
50 |
"""
|
51 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
52 |
|
|
|
58 |
^_-
|
59 |
|
60 |
"""
|
61 |
+
if not DefaultEmbedding._model:
|
62 |
+
try:
|
63 |
+
self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
64 |
+
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
65 |
+
use_fp16=torch.cuda.is_available())
|
66 |
+
except Exception as e:
|
67 |
+
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
68 |
+
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
69 |
+
local_dir_use_symlinks=False)
|
70 |
+
self._model = FlagModel(model_dir,
|
71 |
+
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
72 |
+
use_fp16=torch.cuda.is_available())
|
73 |
|
74 |
def encode(self, texts: list, batch_size=32):
|
75 |
texts = [truncate(t, 2048) for t in texts]
|
|
|
78 |
token_count += num_tokens_from_string(t)
|
79 |
res = []
|
80 |
for i in range(0, len(texts), batch_size):
|
81 |
+
res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
|
82 |
return np.array(res), token_count
|
83 |
|
84 |
def encode_queries(self, text: str):
|
85 |
token_count = num_tokens_from_string(text)
|
86 |
+
return self._model.encode_queries([text]).tolist()[0], token_count
|
87 |
|
88 |
|
89 |
class OpenAIEmbed(Base):
|
|
|
192 |
|
193 |
|
194 |
class FastEmbed(Base):
|
195 |
+
_model = None
|
196 |
+
|
197 |
def __init__(
|
198 |
+
self,
|
199 |
+
key: Optional[str] = None,
|
200 |
+
model_name: str = "BAAI/bge-small-en-v1.5",
|
201 |
+
cache_dir: Optional[str] = None,
|
202 |
+
threads: Optional[int] = None,
|
203 |
+
**kwargs,
|
204 |
):
|
205 |
from fastembed import TextEmbedding
|
206 |
+
if not FastEmbed._model:
|
207 |
+
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
208 |
|
209 |
def encode(self, texts: list, batch_size=32):
|
210 |
# Using the internal tokenizer to encode the texts and get the total
|
|
|
271 |
def encode_queries(self, text):
|
272 |
embds = YoudaoEmbed._client.encode([text])
|
273 |
return np.array(embds[0]), num_tokens_from_string(text)
|
274 |
+
|
275 |
+
|
276 |
+
class JinaEmbed(Base):
|
277 |
+
def __init__(self, key, model_name="jina-embeddings-v2-base-zh",
|
278 |
+
base_url="https://api.jina.ai/v1/embeddings"):
|
279 |
+
|
280 |
+
self.base_url = "https://api.jina.ai/v1/embeddings"
|
281 |
+
self.headers = {
|
282 |
+
"Content-Type": "application/json",
|
283 |
+
"Authorization": f"Bearer {key}"
|
284 |
+
}
|
285 |
+
self.model_name = model_name
|
286 |
+
|
287 |
+
def encode(self, texts: list, batch_size=None):
|
288 |
+
texts = [truncate(t, 8196) for t in texts]
|
289 |
+
data = {
|
290 |
+
"model": self.model_name,
|
291 |
+
"input": texts,
|
292 |
+
'encoding_type': 'float'
|
293 |
+
}
|
294 |
+
res = requests.post(self.base_url, headers=self.headers, json=data)
|
295 |
+
return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
|
296 |
+
|
297 |
+
def encode_queries(self, text):
|
298 |
+
embds, cnt = self.encode([text])
|
299 |
+
return np.array(embds[0]), cnt
|
rag/llm/rerank_model.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
import re
|
17 |
+
import requests
|
18 |
+
import torch
|
19 |
+
from FlagEmbedding import FlagReranker
|
20 |
+
from huggingface_hub import snapshot_download
|
21 |
+
import os
|
22 |
+
from abc import ABC
|
23 |
+
import numpy as np
|
24 |
+
from api.utils.file_utils import get_home_cache_dir
|
25 |
+
from rag.utils import num_tokens_from_string, truncate
|
26 |
+
|
27 |
+
|
28 |
+
class Base(ABC):
|
29 |
+
def __init__(self, key, model_name):
|
30 |
+
pass
|
31 |
+
|
32 |
+
def similarity(self, query: str, texts: list):
|
33 |
+
raise NotImplementedError("Please implement encode method!")
|
34 |
+
|
35 |
+
|
36 |
+
class DefaultRerank(Base):
|
37 |
+
_model = None
|
38 |
+
|
39 |
+
def __init__(self, key, model_name, **kwargs):
|
40 |
+
"""
|
41 |
+
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
42 |
+
|
43 |
+
For Linux:
|
44 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
45 |
+
|
46 |
+
For Windows:
|
47 |
+
Good luck
|
48 |
+
^_-
|
49 |
+
|
50 |
+
"""
|
51 |
+
if not DefaultRerank._model:
|
52 |
+
try:
|
53 |
+
self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
54 |
+
use_fp16=torch.cuda.is_available())
|
55 |
+
except Exception as e:
|
56 |
+
self._model = snapshot_download(repo_id=model_name,
|
57 |
+
local_dir=os.path.join(get_home_cache_dir(),
|
58 |
+
re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
59 |
+
local_dir_use_symlinks=False)
|
60 |
+
self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name),
|
61 |
+
use_fp16=torch.cuda.is_available())
|
62 |
+
|
63 |
+
def similarity(self, query: str, texts: list):
|
64 |
+
pairs = [(query,truncate(t, 2048)) for t in texts]
|
65 |
+
token_count = 0
|
66 |
+
for _, t in pairs:
|
67 |
+
token_count += num_tokens_from_string(t)
|
68 |
+
batch_size = 32
|
69 |
+
res = []
|
70 |
+
for i in range(0, len(pairs), batch_size):
|
71 |
+
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
|
72 |
+
res.extend(scores)
|
73 |
+
return np.array(res), token_count
|
74 |
+
|
75 |
+
|
76 |
+
class JinaRerank(Base):
|
77 |
+
def __init__(self, key, model_name="jina-reranker-v1-base-en",
|
78 |
+
base_url="https://api.jina.ai/v1/rerank"):
|
79 |
+
self.base_url = "https://api.jina.ai/v1/rerank"
|
80 |
+
self.headers = {
|
81 |
+
"Content-Type": "application/json",
|
82 |
+
"Authorization": f"Bearer {key}"
|
83 |
+
}
|
84 |
+
self.model_name = model_name
|
85 |
+
|
86 |
+
def similarity(self, query: str, texts: list):
|
87 |
+
texts = [truncate(t, 8196) for t in texts]
|
88 |
+
data = {
|
89 |
+
"model": self.model_name,
|
90 |
+
"query": query,
|
91 |
+
"documents": texts,
|
92 |
+
"top_n": len(texts)
|
93 |
+
}
|
94 |
+
res = requests.post(self.base_url, headers=self.headers, json=data)
|
95 |
+
return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
|
96 |
+
|
97 |
+
|
98 |
+
class YoudaoRerank(DefaultRerank):
|
99 |
+
_model = None
|
100 |
+
|
101 |
+
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
102 |
+
from BCEmbedding import RerankerModel
|
103 |
+
if not YoudaoRerank._model:
|
104 |
+
try:
|
105 |
+
print("LOADING BCE...")
|
106 |
+
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
107 |
+
get_home_cache_dir(),
|
108 |
+
re.sub(r"^[a-zA-Z]+/", "", model_name)))
|
109 |
+
except Exception as e:
|
110 |
+
YoudaoRerank._model = RerankerModel(
|
111 |
+
model_name_or_path=model_name.replace(
|
112 |
+
"maidalun1020", "InfiniFlow"))
|
113 |
+
|
rag/nlp/query.py
CHANGED
@@ -54,7 +54,8 @@ class EsQueryer:
|
|
54 |
if not self.isChinese(txt):
|
55 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
56 |
tks_w = self.tw.weights(tks)
|
57 |
-
|
|
|
58 |
for i in range(1, len(tks_w)):
|
59 |
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
60 |
if not q:
|
@@ -136,7 +137,11 @@ class EsQueryer:
|
|
136 |
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
137 |
import numpy as np
|
138 |
sims = CosineSimilarity([avec], bvecs)
|
|
|
|
|
|
|
139 |
|
|
|
140 |
def toDict(tks):
|
141 |
d = {}
|
142 |
if isinstance(tks, str):
|
@@ -149,9 +154,7 @@ class EsQueryer:
|
|
149 |
|
150 |
atks = toDict(atks)
|
151 |
btkss = [toDict(tks) for tks in btkss]
|
152 |
-
|
153 |
-
return np.array(sims[0]) * vtweight + \
|
154 |
-
np.array(tksim) * tkweight, tksim, sims[0]
|
155 |
|
156 |
def similarity(self, qtwt, dtwt):
|
157 |
if isinstance(dtwt, type("")):
|
|
|
54 |
if not self.isChinese(txt):
|
55 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
56 |
tks_w = self.tw.weights(tks)
|
57 |
+
tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w]
|
58 |
+
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
59 |
for i in range(1, len(tks_w)):
|
60 |
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
61 |
if not q:
|
|
|
137 |
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
138 |
import numpy as np
|
139 |
sims = CosineSimilarity([avec], bvecs)
|
140 |
+
tksim = self.token_similarity(atks, btkss)
|
141 |
+
return np.array(sims[0]) * vtweight + \
|
142 |
+
np.array(tksim) * tkweight, tksim, sims[0]
|
143 |
|
144 |
+
def token_similarity(self, atks, btkss):
|
145 |
def toDict(tks):
|
146 |
d = {}
|
147 |
if isinstance(tks, str):
|
|
|
154 |
|
155 |
atks = toDict(atks)
|
156 |
btkss = [toDict(tks) for tks in btkss]
|
157 |
+
return [self.similarity(atks, btks) for btks in btkss]
|
|
|
|
|
158 |
|
159 |
def similarity(self, qtwt, dtwt):
|
160 |
if isinstance(dtwt, type("")):
|
rag/nlp/rag_tokenizer.py
CHANGED
@@ -241,11 +241,14 @@ class RagTokenizer:
|
|
241 |
|
242 |
return self.score_(res[::-1])
|
243 |
|
|
|
|
|
|
|
244 |
def tokenize(self, line):
|
245 |
line = self._strQ2B(line).lower()
|
246 |
line = self._tradi2simp(line)
|
247 |
zh_num = len([1 for c in line if is_chinese(c)])
|
248 |
-
if zh_num
|
249 |
return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
|
250 |
|
251 |
arr = re.split(self.SPLIT_CHAR, line)
|
@@ -293,7 +296,7 @@ class RagTokenizer:
|
|
293 |
|
294 |
i = e + 1
|
295 |
|
296 |
-
res = " ".join(res)
|
297 |
if self.DEBUG:
|
298 |
print("[TKS]", self.merge_(res))
|
299 |
return self.merge_(res)
|
@@ -336,7 +339,7 @@ class RagTokenizer:
|
|
336 |
|
337 |
res.append(stk)
|
338 |
|
339 |
-
return " ".join(res)
|
340 |
|
341 |
|
342 |
def is_chinese(s):
|
|
|
241 |
|
242 |
return self.score_(res[::-1])
|
243 |
|
244 |
+
def english_normalize_(self, tks):
|
245 |
+
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
|
246 |
+
|
247 |
def tokenize(self, line):
|
248 |
line = self._strQ2B(line).lower()
|
249 |
line = self._tradi2simp(line)
|
250 |
zh_num = len([1 for c in line if is_chinese(c)])
|
251 |
+
if zh_num == 0:
|
252 |
return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
|
253 |
|
254 |
arr = re.split(self.SPLIT_CHAR, line)
|
|
|
296 |
|
297 |
i = e + 1
|
298 |
|
299 |
+
res = " ".join(self.english_normalize_(res))
|
300 |
if self.DEBUG:
|
301 |
print("[TKS]", self.merge_(res))
|
302 |
return self.merge_(res)
|
|
|
339 |
|
340 |
res.append(stk)
|
341 |
|
342 |
+
return " ".join(self.english_normalize_(res))
|
343 |
|
344 |
|
345 |
def is_chinese(s):
|
rag/nlp/search.py
CHANGED
@@ -71,8 +71,8 @@ class Dealer:
|
|
71 |
|
72 |
s = Search()
|
73 |
pg = int(req.get("page", 1)) - 1
|
74 |
-
ps = int(req.get("size", 1000))
|
75 |
topk = int(req.get("topk", 1024))
|
|
|
76 |
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
77 |
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
|
78 |
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
|
@@ -311,6 +311,26 @@ class Dealer:
|
|
311 |
ins_tw, tkweight, vtweight)
|
312 |
return sim, tksim, vtsim
|
313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
315 |
return self.qryr.hybrid_similarity(ans_embd,
|
316 |
ins_embd,
|
@@ -318,17 +338,22 @@ class Dealer:
|
|
318 |
rag_tokenizer.tokenize(inst).split(" "))
|
319 |
|
320 |
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
|
321 |
-
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
|
322 |
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
323 |
if not question:
|
324 |
return ranks
|
325 |
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
|
326 |
"question": question, "vector": True, "topk": top,
|
327 |
-
"similarity": similarity_threshold
|
|
|
328 |
sres = self.search(req, index_name(tenant_id), embd_mdl)
|
329 |
|
330 |
-
|
331 |
-
|
|
|
|
|
|
|
|
|
332 |
idx = np.argsort(sim * -1)
|
333 |
|
334 |
dim = len(sres.query_vector)
|
|
|
71 |
|
72 |
s = Search()
|
73 |
pg = int(req.get("page", 1)) - 1
|
|
|
74 |
topk = int(req.get("topk", 1024))
|
75 |
+
ps = int(req.get("size", topk))
|
76 |
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
77 |
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
|
78 |
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
|
|
|
311 |
ins_tw, tkweight, vtweight)
|
312 |
return sim, tksim, vtsim
|
313 |
|
314 |
+
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
|
315 |
+
vtweight=0.7, cfield="content_ltks"):
|
316 |
+
_, keywords = self.qryr.question(query)
|
317 |
+
|
318 |
+
for i in sres.ids:
|
319 |
+
if isinstance(sres.field[i].get("important_kwd", []), str):
|
320 |
+
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
|
321 |
+
ins_tw = []
|
322 |
+
for i in sres.ids:
|
323 |
+
content_ltks = sres.field[i][cfield].split(" ")
|
324 |
+
title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t]
|
325 |
+
important_kwd = sres.field[i].get("important_kwd", [])
|
326 |
+
tks = content_ltks + title_tks + important_kwd
|
327 |
+
ins_tw.append(tks)
|
328 |
+
|
329 |
+
tksim = self.qryr.token_similarity(keywords, ins_tw)
|
330 |
+
vtsim,_ = rerank_mdl.similarity(" ".join(keywords), [rmSpace(" ".join(tks)) for tks in ins_tw])
|
331 |
+
|
332 |
+
return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim
|
333 |
+
|
334 |
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
335 |
return self.qryr.hybrid_similarity(ans_embd,
|
336 |
ins_embd,
|
|
|
338 |
rag_tokenizer.tokenize(inst).split(" "))
|
339 |
|
340 |
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
|
341 |
+
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
|
342 |
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
343 |
if not question:
|
344 |
return ranks
|
345 |
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
|
346 |
"question": question, "vector": True, "topk": top,
|
347 |
+
"similarity": similarity_threshold,
|
348 |
+
"available_int": 1}
|
349 |
sres = self.search(req, index_name(tenant_id), embd_mdl)
|
350 |
|
351 |
+
if rerank_mdl:
|
352 |
+
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
|
353 |
+
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
|
354 |
+
else:
|
355 |
+
sim, tsim, vsim = self.rerank(
|
356 |
+
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
|
357 |
idx = np.argsort(sim * -1)
|
358 |
|
359 |
dim = len(sres.query_vector)
|