Kevin Hu
committed on
Commit
·
1e02591
1
Parent(s):
c50cfc1
Fix @ in model name issue. (#3821)
Browse files### What problem does this PR solve?
#3814
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
api/db/services/dialog_service.py
CHANGED
@@ -120,7 +120,7 @@ def message_fit_in(msg, max_length=4000):
|
|
120 |
|
121 |
|
122 |
def llm_id2llm_type(llm_id):
|
123 |
-
llm_id =
|
124 |
fnm = os.path.join(get_project_base_directory(), "conf")
|
125 |
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
|
126 |
for llm_factory in llm_factories["factory_llm_infos"]:
|
@@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id):
|
|
132 |
def chat(dialog, messages, stream=True, **kwargs):
|
133 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
134 |
st = timer()
|
135 |
-
|
136 |
-
fid = None
|
137 |
-
llm_id = tmp[0]
|
138 |
-
if len(tmp)>1: fid = tmp[1]
|
139 |
-
|
140 |
llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
|
141 |
if not llm:
|
142 |
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
|
|
|
120 |
|
121 |
|
122 |
def llm_id2llm_type(llm_id):
|
123 |
+
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
|
124 |
fnm = os.path.join(get_project_base_directory(), "conf")
|
125 |
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
|
126 |
for llm_factory in llm_factories["factory_llm_infos"]:
|
|
|
132 |
def chat(dialog, messages, stream=True, **kwargs):
|
133 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
134 |
st = timer()
|
135 |
+
llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
|
|
|
|
|
|
|
|
|
136 |
llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
|
137 |
if not llm:
|
138 |
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
|
api/db/services/llm_service.py
CHANGED
@@ -13,8 +13,12 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
16 |
import logging
|
|
|
|
|
17 |
from api.db.services.user_service import TenantService
|
|
|
18 |
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
|
19 |
from api.db import LLMType
|
20 |
from api.db.db_models import DB
|
@@ -36,11 +40,11 @@ class TenantLLMService(CommonService):
|
|
36 |
@classmethod
|
37 |
@DB.connection_context()
|
38 |
def get_api_key(cls, tenant_id, model_name):
|
39 |
-
|
40 |
-
if
|
41 |
-
objs = cls.query(tenant_id=tenant_id, llm_name=
|
42 |
else:
|
43 |
-
objs = cls.query(tenant_id=tenant_id, llm_name=
|
44 |
if not objs:
|
45 |
return
|
46 |
return objs[0]
|
@@ -61,6 +65,23 @@ class TenantLLMService(CommonService):
|
|
61 |
|
62 |
return list(objs)
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
@classmethod
|
65 |
@DB.connection_context()
|
66 |
def model_instance(cls, tenant_id, llm_type,
|
@@ -85,9 +106,7 @@ class TenantLLMService(CommonService):
|
|
85 |
assert False, "LLM type error"
|
86 |
|
87 |
model_config = cls.get_api_key(tenant_id, mdlnm)
|
88 |
-
|
89 |
-
fid = None if len(tmp) < 2 else tmp[1]
|
90 |
-
mdlnm = tmp[0]
|
91 |
if model_config: model_config = model_config.to_dict()
|
92 |
if not model_config:
|
93 |
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
@@ -168,7 +187,7 @@ class TenantLLMService(CommonService):
|
|
168 |
else:
|
169 |
assert False, "LLM type error"
|
170 |
|
171 |
-
llm_name =
|
172 |
|
173 |
num = 0
|
174 |
try:
|
@@ -179,7 +198,7 @@ class TenantLLMService(CommonService):
|
|
179 |
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
|
180 |
.execute()
|
181 |
else:
|
182 |
-
|
183 |
num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
|
184 |
except Exception:
|
185 |
logging.exception("TenantLLMService.increase_usage got exception")
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
import json
|
17 |
import logging
|
18 |
+
import os
|
19 |
+
|
20 |
from api.db.services.user_service import TenantService
|
21 |
+
from api.utils.file_utils import get_project_base_directory
|
22 |
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
|
23 |
from api.db import LLMType
|
24 |
from api.db.db_models import DB
|
|
|
40 |
@classmethod
|
41 |
@DB.connection_context()
|
42 |
def get_api_key(cls, tenant_id, model_name):
|
43 |
+
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
44 |
+
if not fid:
|
45 |
+
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
46 |
else:
|
47 |
+
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
48 |
if not objs:
|
49 |
return
|
50 |
return objs[0]
|
|
|
65 |
|
66 |
return list(objs)
|
67 |
|
68 |
+
@staticmethod
|
69 |
+
def split_model_name_and_factory(model_name):
|
70 |
+
arr = model_name.split("@")
|
71 |
+
if len(arr) < 2:
|
72 |
+
return model_name, None
|
73 |
+
if len(arr) > 2:
|
74 |
+
return "@".join(arr[0:-1]), arr[-1]
|
75 |
+
try:
|
76 |
+
fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
|
77 |
+
fact = set([f["name"] for f in fact])
|
78 |
+
if arr[-1] not in fact:
|
79 |
+
return model_name, None
|
80 |
+
return arr[0], arr[-1]
|
81 |
+
except Exception as e:
|
82 |
+
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
83 |
+
return model_name, None
|
84 |
+
|
85 |
@classmethod
|
86 |
@DB.connection_context()
|
87 |
def model_instance(cls, tenant_id, llm_type,
|
|
|
106 |
assert False, "LLM type error"
|
107 |
|
108 |
model_config = cls.get_api_key(tenant_id, mdlnm)
|
109 |
+
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
|
|
|
|
110 |
if model_config: model_config = model_config.to_dict()
|
111 |
if not model_config:
|
112 |
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
|
|
187 |
else:
|
188 |
assert False, "LLM type error"
|
189 |
|
190 |
+
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
191 |
|
192 |
num = 0
|
193 |
try:
|
|
|
198 |
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
|
199 |
.execute()
|
200 |
else:
|
201 |
+
if not llm_factory: llm_factory = mdlnm
|
202 |
num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
|
203 |
except Exception:
|
204 |
logging.exception("TenantLLMService.increase_usage got exception")
|