Kevin Hu committed on
Commit
1e02591
·
1 Parent(s): c50cfc1

Fix @ in model name issue. (#3821)

Browse files

### What problem does this PR solve?

#3814

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

api/db/services/dialog_service.py CHANGED
@@ -120,7 +120,7 @@ def message_fit_in(msg, max_length=4000):
120
 
121
 
122
  def llm_id2llm_type(llm_id):
123
- llm_id = llm_id.split("@")[0]
124
  fnm = os.path.join(get_project_base_directory(), "conf")
125
  llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
126
  for llm_factory in llm_factories["factory_llm_infos"]:
@@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id):
132
  def chat(dialog, messages, stream=True, **kwargs):
133
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
134
  st = timer()
135
- tmp = dialog.llm_id.split("@")
136
- fid = None
137
- llm_id = tmp[0]
138
- if len(tmp)>1: fid = tmp[1]
139
-
140
  llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
141
  if not llm:
142
  llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
 
120
 
121
 
122
  def llm_id2llm_type(llm_id):
123
+ llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
124
  fnm = os.path.join(get_project_base_directory(), "conf")
125
  llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
126
  for llm_factory in llm_factories["factory_llm_infos"]:
 
132
  def chat(dialog, messages, stream=True, **kwargs):
133
  assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
134
  st = timer()
135
+ llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
 
 
 
 
136
  llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
137
  if not llm:
138
  llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
api/db/services/llm_service.py CHANGED
@@ -13,8 +13,12 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import logging
 
 
17
  from api.db.services.user_service import TenantService
 
18
  from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
19
  from api.db import LLMType
20
  from api.db.db_models import DB
@@ -36,11 +40,11 @@ class TenantLLMService(CommonService):
36
  @classmethod
37
  @DB.connection_context()
38
  def get_api_key(cls, tenant_id, model_name):
39
- arr = model_name.split("@")
40
- if len(arr) < 2:
41
- objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
42
  else:
43
- objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
44
  if not objs:
45
  return
46
  return objs[0]
@@ -61,6 +65,23 @@ class TenantLLMService(CommonService):
61
 
62
  return list(objs)
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  @classmethod
65
  @DB.connection_context()
66
  def model_instance(cls, tenant_id, llm_type,
@@ -85,9 +106,7 @@ class TenantLLMService(CommonService):
85
  assert False, "LLM type error"
86
 
87
  model_config = cls.get_api_key(tenant_id, mdlnm)
88
- tmp = mdlnm.split("@")
89
- fid = None if len(tmp) < 2 else tmp[1]
90
- mdlnm = tmp[0]
91
  if model_config: model_config = model_config.to_dict()
92
  if not model_config:
93
  if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
@@ -168,7 +187,7 @@ class TenantLLMService(CommonService):
168
  else:
169
  assert False, "LLM type error"
170
 
171
- llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
172
 
173
  num = 0
174
  try:
@@ -179,7 +198,7 @@ class TenantLLMService(CommonService):
179
  .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
180
  .execute()
181
  else:
182
- llm_factory = mdlnm.split("@")[1] if "@" in mdlnm else mdlnm
183
  num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
184
  except Exception:
185
  logging.exception("TenantLLMService.increase_usage got exception")
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import json
17
  import logging
18
+ import os
19
+
20
  from api.db.services.user_service import TenantService
21
+ from api.utils.file_utils import get_project_base_directory
22
  from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
23
  from api.db import LLMType
24
  from api.db.db_models import DB
 
40
  @classmethod
41
  @DB.connection_context()
42
  def get_api_key(cls, tenant_id, model_name):
43
+ mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
44
+ if not fid:
45
+ objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
46
  else:
47
+ objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
48
  if not objs:
49
  return
50
  return objs[0]
 
65
 
66
  return list(objs)
67
 
68
+ @staticmethod
69
+ def split_model_name_and_factory(model_name):
70
+ arr = model_name.split("@")
71
+ if len(arr) < 2:
72
+ return model_name, None
73
+ if len(arr) > 2:
74
+ return "@".join(arr[0:-1]), arr[-1]
75
+ try:
76
+ fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
77
+ fact = set([f["name"] for f in fact])
78
+ if arr[-1] not in fact:
79
+ return model_name, None
80
+ return arr[0], arr[-1]
81
+ except Exception as e:
82
+ logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
83
+ return model_name, None
84
+
85
  @classmethod
86
  @DB.connection_context()
87
  def model_instance(cls, tenant_id, llm_type,
 
106
  assert False, "LLM type error"
107
 
108
  model_config = cls.get_api_key(tenant_id, mdlnm)
109
+ mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
 
 
110
  if model_config: model_config = model_config.to_dict()
111
  if not model_config:
112
  if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
 
187
  else:
188
  assert False, "LLM type error"
189
 
190
+ llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
191
 
192
  num = 0
193
  try:
 
198
  .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
199
  .execute()
200
  else:
201
+ if not llm_factory: llm_factory = mdlnm
202
  num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
203
  except Exception:
204
  logging.exception("TenantLLMService.increase_usage got exception")