Kevin Hu committed
Commit bf00d96 · 1 Parent(s): 8b574ab

fix duplicated LLM name between different suppliers (#2477)


### What problem does this PR solve?

#2465

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
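
The root cause in #2465 is that two suppliers can ship a model under the same name, so a bare `llm_name` no longer identifies a unique record. The patch therefore lets a model ID carry an optional supplier suffix, `name@factory`, and splits on `@` at every consumer. A minimal sketch of that convention (the helper `split_model_id` is illustrative, not part of the patch):

```python
def split_model_id(model_id: str):
    """Split 'name@factory' into (name, factory); factory is None when absent."""
    name, _, factory = model_id.partition("@")
    return name, factory or None

assert split_model_id("bge-large-zh-v1.5@BAAI") == ("bge-large-zh-v1.5", "BAAI")
assert split_model_id("gpt-4o") == ("gpt-4o", None)
```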

api/apps/chunk_app.py CHANGED

@@ -27,7 +27,7 @@ from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
@@ -141,8 +141,7 @@ def set():
         return get_data_error_result(retmsg="Tenant not found!")
 
     embd_id = DocumentService.get_embd_id(req["doc_id"])
-    embd_mdl = TenantLLMService.model_instance(
-        tenant_id, LLMType.EMBEDDING.value, embd_id)
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
 
     e, doc = DocumentService.get_by_id(req["doc_id"])
     if not e:
@@ -235,8 +234,7 @@ def create():
         return get_data_error_result(retmsg="Tenant not found!")
 
     embd_id = DocumentService.get_embd_id(req["doc_id"])
-    embd_mdl = TenantLLMService.model_instance(
-        tenant_id, LLMType.EMBEDDING.value, embd_id)
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
 
     v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
     v = 0.1 * v[0] + 0.9 * v[1]
@@ -281,16 +279,14 @@ def retrieval_test():
     if not e:
         return get_data_error_result(retmsg="Knowledgebase not found!")
 
-    embd_mdl = TenantLLMService.model_instance(
-        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
     rerank_mdl = None
     if req.get("rerank_id"):
-        rerank_mdl = TenantLLMService.model_instance(
-            kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+        rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
 
     if req.get("keyword", False):
-        chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+        chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
         question += keyword_extraction(chat_mdl, question)
 
     retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
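
All three endpoints above now construct models through `LLMBundle` instead of calling `TenantLLMService.model_instance` directly, so the new `@factory` resolution lives in one place. A hedged usage sketch with illustrative values (the constructor arguments mirror the call sites in this diff; it only runs inside a configured RAGFlow instance):

```python
from api.db import LLMType
from api.db.services.llm_service import LLMBundle

tenant_id = "tenant-0"              # illustrative tenant ID
embd_id = "bge-large-zh-v1.5@BAAI"  # model name qualified by its supplier

embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
v, c = embd_mdl.encode(["some text"])  # vectors and token count, as in create()
```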
api/db/services/dialog_service.py CHANGED

@@ -78,6 +78,7 @@ def message_fit_in(msg, max_length=4000):
 
 
 def llm_id2llm_type(llm_id):
+    llm_id = llm_id.split("@")[0]
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:
@@ -89,9 +90,15 @@ def llm_id2llm_type(llm_id):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    llm = LLMService.query(llm_name=dialog.llm_id)
+    tmp = dialog.llm_id.split("@")
+    fid = None
+    llm_id = tmp[0]
+    if len(tmp) > 1: fid = tmp[1]
+
+    llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
-        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
+        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
+            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
     if not llm:
         raise LookupError("LLM(%s) not found" % dialog.llm_id)
     max_tokens = 8192
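
`llm_id2llm_type()` simply drops the suffix, since `llm_factories.json` is keyed by bare model names, while `chat()` threads the factory through both lookups. The resolution order, restated as a standalone sketch (`resolve_llm` is an illustrative name, not a function in the codebase):

```python
from api.db.services.llm_service import LLMService, TenantLLMService

def resolve_llm(dialog):
    # Peel the optional "@factory" suffix off the configured model ID.
    llm_id, _, fid = dialog.llm_id.partition("@")
    fid = fid or None
    # Built-in LLM table first, filtered by factory when one was given...
    llm = (LLMService.query(llm_name=llm_id, fid=fid) if fid
           else LLMService.query(llm_name=llm_id))
    # ...then the tenant's own registrations.
    if not llm:
        llm = (TenantLLMService.query(tenant_id=dialog.tenant_id,
                                      llm_name=llm_id, llm_factory=fid) if fid
               else TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id))
    if not llm:
        raise LookupError("LLM(%s) not found" % dialog.llm_id)
    return llm[0]
```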
api/db/services/llm_service.py CHANGED

@@ -17,7 +17,7 @@ from api.db.services.user_service import TenantService
 from api.settings import database_logger
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
-from api.db.db_models import DB, UserTenant
+from api.db.db_models import DB
 from api.db.db_models import LLMFactories, LLM, TenantLLM
 from api.db.services.common_service import CommonService
 
@@ -36,7 +36,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        else:
+            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
         if not objs:
             return
         return objs[0]
@@ -81,14 +85,17 @@ class TenantLLMService(CommonService):
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
+        tmp = mdlnm.split("@")
+        fid = None if len(tmp) < 2 else tmp[1]
+        mdlnm = tmp[0]
         if model_config: model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
-                llm = LLMService.query(llm_name=llm_name if llm_name else mdlnm)
+                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                 if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
-                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": llm_name if llm_name else mdlnm, "api_base": ""}
+                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
                 if not model_config:
-                    if llm_name == "flag-embedding":
+                    if mdlnm == "flag-embedding":
                         model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
                                         "llm_name": llm_name, "api_base": ""}
                     else:
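
`get_api_key` now issues one of two query shapes, and `model_instance` strips the suffix before its keyless fallbacks, so the free local suppliers (`Youdao`, `FastEmbed`, `BAAI`) still match on the bare name. Note the whole scheme assumes a model name never itself contains `@`. The split semantics, spelled out:

```python
tmp = "bge-large-zh-v1.5@BAAI".split("@")
fid = None if len(tmp) < 2 else tmp[1]  # "BAAI"
mdlnm = tmp[0]                          # "bge-large-zh-v1.5"
# With a qualified name, the tenant lookup becomes
#   query(tenant_id=..., llm_name="bge-large-zh-v1.5", llm_factory="BAAI")
# and the no-API-key default config is built from the bare mdlnm, e.g.
#   {"llm_factory": "BAAI", "api_key": "", "llm_name": "bge-large-zh-v1.5", "api_base": ""}
```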
rag/app/naive.py CHANGED

@@ -76,7 +76,7 @@ class Docx(DocxParser):
                 if last_image:
                     image_list.insert(0, last_image)
                     last_image = None
-                lines.append((self.__clean(p.text), image_list, p.style.name))
+                lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
             else:
                 if current_image := self.get_picture(self.doc, p):
                     if lines:
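
The `naive.py` change is an unrelated hardening: python-docx can apparently hand back a paragraph whose `style` is `None` (e.g., when the style definition cannot be resolved), so the unguarded `p.style.name` could raise `AttributeError`. A standalone illustration of the guard (`example.docx` is a hypothetical input file):

```python
from docx import Document  # python-docx

doc = Document("example.docx")
for p in doc.paragraphs:
    # Guard against paragraphs with no resolvable style object.
    style_name = p.style.name if p.style else ""
    print(style_name, p.text[:40])
```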