Kevin Hu
commited on
Commit
·
bf00d96
1
Parent(s):
8b574ab
fix duplicated llm name betweeen different suppliers (#2477)
Browse files### What problem does this PR solve?
#2465
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/apps/chunk_app.py +6 -10
- api/db/services/dialog_service.py +9 -2
- api/db/services/llm_service.py +12 -5
- rag/app/naive.py +1 -1
api/apps/chunk_app.py
CHANGED
@@ -27,7 +27,7 @@ from rag.utils.es_conn import ELASTICSEARCH
|
|
27 |
from rag.utils import rmSpace
|
28 |
from api.db import LLMType, ParserType
|
29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
30 |
-
from api.db.services.llm_service import
|
31 |
from api.db.services.user_service import UserTenantService
|
32 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
33 |
from api.db.services.document_service import DocumentService
|
@@ -141,8 +141,7 @@ def set():
|
|
141 |
return get_data_error_result(retmsg="Tenant not found!")
|
142 |
|
143 |
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
144 |
-
embd_mdl =
|
145 |
-
tenant_id, LLMType.EMBEDDING.value, embd_id)
|
146 |
|
147 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
148 |
if not e:
|
@@ -235,8 +234,7 @@ def create():
|
|
235 |
return get_data_error_result(retmsg="Tenant not found!")
|
236 |
|
237 |
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
238 |
-
embd_mdl =
|
239 |
-
tenant_id, LLMType.EMBEDDING.value, embd_id)
|
240 |
|
241 |
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
242 |
v = 0.1 * v[0] + 0.9 * v[1]
|
@@ -281,16 +279,14 @@ def retrieval_test():
|
|
281 |
if not e:
|
282 |
return get_data_error_result(retmsg="Knowledgebase not found!")
|
283 |
|
284 |
-
embd_mdl =
|
285 |
-
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
286 |
|
287 |
rerank_mdl = None
|
288 |
if req.get("rerank_id"):
|
289 |
-
rerank_mdl =
|
290 |
-
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
291 |
|
292 |
if req.get("keyword", False):
|
293 |
-
chat_mdl =
|
294 |
question += keyword_extraction(chat_mdl, question)
|
295 |
|
296 |
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
|
|
27 |
from rag.utils import rmSpace
|
28 |
from api.db import LLMType, ParserType
|
29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
30 |
+
from api.db.services.llm_service import LLMBundle
|
31 |
from api.db.services.user_service import UserTenantService
|
32 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
33 |
from api.db.services.document_service import DocumentService
|
|
|
141 |
return get_data_error_result(retmsg="Tenant not found!")
|
142 |
|
143 |
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
144 |
+
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
|
|
145 |
|
146 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
147 |
if not e:
|
|
|
234 |
return get_data_error_result(retmsg="Tenant not found!")
|
235 |
|
236 |
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
237 |
+
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
|
|
238 |
|
239 |
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
240 |
v = 0.1 * v[0] + 0.9 * v[1]
|
|
|
279 |
if not e:
|
280 |
return get_data_error_result(retmsg="Knowledgebase not found!")
|
281 |
|
282 |
+
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
|
|
283 |
|
284 |
rerank_mdl = None
|
285 |
if req.get("rerank_id"):
|
286 |
+
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
|
|
287 |
|
288 |
if req.get("keyword", False):
|
289 |
+
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
290 |
question += keyword_extraction(chat_mdl, question)
|
291 |
|
292 |
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
api/db/services/dialog_service.py
CHANGED
@@ -78,6 +78,7 @@ def message_fit_in(msg, max_length=4000):
|
|
78 |
|
79 |
|
80 |
def llm_id2llm_type(llm_id):
|
|
|
81 |
fnm = os.path.join(get_project_base_directory(), "conf")
|
82 |
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
|
83 |
for llm_factory in llm_factories["factory_llm_infos"]:
|
@@ -89,9 +90,15 @@ def llm_id2llm_type(llm_id):
|
|
89 |
def chat(dialog, messages, stream=True, **kwargs):
|
90 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
91 |
st = timer()
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
93 |
if not llm:
|
94 |
-
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=
|
|
|
95 |
if not llm:
|
96 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
97 |
max_tokens = 8192
|
|
|
78 |
|
79 |
|
80 |
def llm_id2llm_type(llm_id):
|
81 |
+
llm_id = llm_id.split("@")[0]
|
82 |
fnm = os.path.join(get_project_base_directory(), "conf")
|
83 |
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
|
84 |
for llm_factory in llm_factories["factory_llm_infos"]:
|
|
|
90 |
def chat(dialog, messages, stream=True, **kwargs):
|
91 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
92 |
st = timer()
|
93 |
+
tmp = dialog.llm_id.split("@")
|
94 |
+
fid = None
|
95 |
+
llm_id = tmp[0]
|
96 |
+
if len(tmp)>1: fid = tmp[1]
|
97 |
+
|
98 |
+
llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
|
99 |
if not llm:
|
100 |
+
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
|
101 |
+
TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
|
102 |
if not llm:
|
103 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
104 |
max_tokens = 8192
|
api/db/services/llm_service.py
CHANGED
@@ -17,7 +17,7 @@ from api.db.services.user_service import TenantService
|
|
17 |
from api.settings import database_logger
|
18 |
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
|
19 |
from api.db import LLMType
|
20 |
-
from api.db.db_models import DB
|
21 |
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
22 |
from api.db.services.common_service import CommonService
|
23 |
|
@@ -36,7 +36,11 @@ class TenantLLMService(CommonService):
|
|
36 |
@classmethod
|
37 |
@DB.connection_context()
|
38 |
def get_api_key(cls, tenant_id, model_name):
|
39 |
-
|
|
|
|
|
|
|
|
|
40 |
if not objs:
|
41 |
return
|
42 |
return objs[0]
|
@@ -81,14 +85,17 @@ class TenantLLMService(CommonService):
|
|
81 |
assert False, "LLM type error"
|
82 |
|
83 |
model_config = cls.get_api_key(tenant_id, mdlnm)
|
|
|
|
|
|
|
84 |
if model_config: model_config = model_config.to_dict()
|
85 |
if not model_config:
|
86 |
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
87 |
-
llm = LLMService.query(llm_name=
|
88 |
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
89 |
-
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name":
|
90 |
if not model_config:
|
91 |
-
if
|
92 |
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
93 |
"llm_name": llm_name, "api_base": ""}
|
94 |
else:
|
|
|
17 |
from api.settings import database_logger
|
18 |
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
|
19 |
from api.db import LLMType
|
20 |
+
from api.db.db_models import DB
|
21 |
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
22 |
from api.db.services.common_service import CommonService
|
23 |
|
|
|
36 |
@classmethod
|
37 |
@DB.connection_context()
|
38 |
def get_api_key(cls, tenant_id, model_name):
|
39 |
+
arr = model_name.split("@")
|
40 |
+
if len(arr) < 2:
|
41 |
+
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
|
42 |
+
else:
|
43 |
+
objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
|
44 |
if not objs:
|
45 |
return
|
46 |
return objs[0]
|
|
|
85 |
assert False, "LLM type error"
|
86 |
|
87 |
model_config = cls.get_api_key(tenant_id, mdlnm)
|
88 |
+
tmp = mdlnm.split("@")
|
89 |
+
fid = None if len(tmp) < 2 else tmp[1]
|
90 |
+
mdlnm = tmp[0]
|
91 |
if model_config: model_config = model_config.to_dict()
|
92 |
if not model_config:
|
93 |
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
94 |
+
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
95 |
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
96 |
+
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
|
97 |
if not model_config:
|
98 |
+
if mdlnm == "flag-embedding":
|
99 |
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
100 |
"llm_name": llm_name, "api_base": ""}
|
101 |
else:
|
rag/app/naive.py
CHANGED
@@ -76,7 +76,7 @@ class Docx(DocxParser):
|
|
76 |
if last_image:
|
77 |
image_list.insert(0, last_image)
|
78 |
last_image = None
|
79 |
-
lines.append((self.__clean(p.text), image_list, p.style.name))
|
80 |
else:
|
81 |
if current_image := self.get_picture(self.doc, p):
|
82 |
if lines:
|
|
|
76 |
if last_image:
|
77 |
image_list.insert(0, last_image)
|
78 |
last_image = None
|
79 |
+
lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
|
80 |
else:
|
81 |
if current_image := self.get_picture(self.doc, p):
|
82 |
if lines:
|