File size: 7,361 Bytes
3079197 484e5ab 3079197 484e5ab e32ef75 9bf75d4 3079197 3198faf 484e5ab e32ef75 3198faf 6be3dd5 e32ef75 c127ae4 484e5ab c1bdfb8 6be3dd5 79ada0b e32ef75 484e5ab e32ef75 ba51460 e32ef75 484e5ab ba51460 e32ef75 ba51460 3069c36 ba51460 801a3c1 ba51460 484e5ab e32ef75 0c30cc9 484e5ab e32ef75 e06e08c e32ef75 0c30cc9 e32ef75 2587709 e32ef75 41c7a59 e32ef75 79ada0b e32ef75 886ae57 79ada0b e32ef75 886ae57 79ada0b e32ef75 79ada0b e32ef75 79ada0b e32ef75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.db.services.user_service import TenantService
from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
from api.db.db_models import LLMFactories, LLM, TenantLLM
from api.db.services.common_service import CommonService
class LLMFactoriesService(CommonService):
    """Service class that binds ``CommonService`` to the ``LLMFactories`` model."""

    # Model class exposed to service methods as ``cls.model``.
    model = LLMFactories
class LLMService(CommonService):
    """Service class that binds ``CommonService`` to the ``LLM`` model."""

    # Model class exposed to service methods as ``cls.model``.
    model = LLM
class TenantLLMService(CommonService):
    """Service class for the per-tenant model configuration (``TenantLLM``).

    Provides lookups of a tenant's configured models, construction of the
    matching model client, and token-usage accounting.
    """

    # Model class exposed to service methods as ``cls.model``.
    model = TenantLLM

    @staticmethod
    def _resolve_model_name(tenant, llm_type, llm_name=None):
        """Return the model name to use for ``llm_type``.

        An explicit ``llm_name`` overrides the tenant default for embedding
        and chat models; speech-to-text and image-to-text always use the
        tenant default.

        Raises:
            AssertionError: if ``llm_type`` is not one of the handled types.
        """
        if llm_type == LLMType.EMBEDDING.value:
            return tenant.embd_id if not llm_name else llm_name
        if llm_type == LLMType.SPEECH2TEXT.value:
            return tenant.asr_id
        if llm_type == LLMType.IMAGE2TEXT.value:
            return tenant.img2txt_id
        if llm_type == LLMType.CHAT.value:
            return tenant.llm_id if not llm_name else llm_name
        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError("LLM type error")

    @classmethod
    @DB.connection_context()
    def get_api_key(cls, tenant_id, model_name):
        """Return the first ``TenantLLM`` row for the tenant and model name.

        Returns ``None`` when the query yields no rows.
        """
        objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
        if not objs:
            return None
        return objs[0]

    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        """List the tenant's models that have a non-null API key.

        Each entry is a dict with the factory name, the factory's logo and
        tags, the model type, the model name and the used-token counter.
        """
        fields = [
            cls.model.llm_factory,
            LLMFactories.logo,
            LLMFactories.tags,
            cls.model.model_type,
            cls.model.llm_name,
            cls.model.used_tokens,
        ]
        objs = cls.model.select(*fields).join(
            LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)
        ).where(
            cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()
        ).dicts()
        return list(objs)

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type,
                       llm_name=None, lang="Chinese"):
        """Build the model client a tenant should use for ``llm_type``.

        Args:
            tenant_id: id of the tenant whose configuration is consulted.
            llm_type: one of the ``LLMType`` values (compared by ``.value``).
            llm_name: optional explicit model name; honoured for embedding
                and chat models only.
            lang: language forwarded to image-to-text models.

        Returns:
            An instance taken from ``EmbeddingModel``, ``CvModel`` or
            ``ChatModel``. ``None`` when the configured factory is not
            present in the corresponding registry, or when ``llm_type`` has
            no constructor branch here (e.g. speech-to-text).

        Raises:
            LookupError: if the tenant does not exist or no configuration
                can be found for the resolved model.
            AssertionError: if ``llm_type`` is not a handled type.
        """
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")

        mdlnm = cls._resolve_model_name(tenant, llm_type, llm_name)

        model_config = cls.get_api_key(tenant_id, mdlnm)
        if model_config:
            model_config = model_config.to_dict()

        if not model_config:
            # The fallbacks below must use the *resolved* name: ``llm_name``
            # is None whenever the tenant default is being used.
            if llm_type == LLMType.EMBEDDING.value:
                llm = LLMService.query(llm_name=mdlnm)
                # These factories are usable with an empty API key.
                if llm and llm[0].fid in ["Youdao", "FastEmbed"]:
                    model_config = {"llm_factory": llm[0].fid, "api_key": "",
                                    "llm_name": mdlnm, "api_base": ""}
            if not model_config:
                if mdlnm == "flag-embedding":
                    model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
                                    "llm_name": mdlnm, "api_base": ""}
                else:
                    raise LookupError("Model({}) not authorized".format(mdlnm))

        factory = model_config["llm_factory"]

        if llm_type == LLMType.EMBEDDING.value:
            if factory not in EmbeddingModel:
                return None
            return EmbeddingModel[factory](
                model_config["api_key"], model_config["llm_name"],
                base_url=model_config["api_base"])

        if llm_type == LLMType.IMAGE2TEXT.value:
            if factory not in CvModel:
                return None
            return CvModel[factory](
                model_config["api_key"], model_config["llm_name"], lang,
                base_url=model_config["api_base"]
            )

        if llm_type == LLMType.CHAT.value:
            if factory not in ChatModel:
                return None
            return ChatModel[factory](
                model_config["api_key"], model_config["llm_name"],
                base_url=model_config["api_base"])

        # No constructor branch for the remaining types in this module.
        return None

    @classmethod
    @DB.connection_context()
    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
        """Add ``used_tokens`` to the tenant's counter for the resolved model.

        The model name is resolved exactly as in ``model_instance`` so usage
        is booked against the model that was actually instantiated.

        Returns:
            The row count reported by the UPDATE; ``0`` means nothing was
            updated (callers treat a falsy result as failure).

        Raises:
            LookupError: if the tenant does not exist.
            AssertionError: if ``llm_type`` is not a handled type.
        """
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")

        mdlnm = cls._resolve_model_name(tenant, llm_type, llm_name)

        # Increment inside the UPDATE itself so concurrent callers cannot
        # overwrite each other's counts with a stale value read in Python.
        return cls.model.update(
            used_tokens=cls.model.used_tokens + used_tokens
        ).where(
            cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm
        ).execute()
class LLMBundle(object):
    """A tenant's model client bundled with token-usage accounting.

    The constructor resolves the model through
    ``TenantLLMService.model_instance``; every call wrapper then reports the
    tokens consumed through ``TenantLLMService.increase_usage``.
    """

    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
        """Resolve and store the model client for the tenant.

        Args:
            tenant_id: id of the tenant that owns the model configuration.
            llm_type: ``LLMType`` value selecting the kind of model.
            llm_name: optional explicit model name.
            lang: language forwarded to ``model_instance``.

        Raises:
            AssertionError: if no model instance could be created.
        """
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(
            tenant_id, llm_type, llm_name, lang=lang)
        # Raised explicitly (instead of ``assert``) so the check survives
        # running Python with -O; the exception type is unchanged.
        if not self.mdl:
            raise AssertionError("Can't find model for {}/{}/{}".format(
                tenant_id, llm_type, llm_name))

    def _record_usage(self, used_tokens, label):
        """Report ``used_tokens`` for this bundle; log an error when no row was updated."""
        updated = TenantLLMService.increase_usage(
            self.tenant_id, self.llm_type, used_tokens, self.llm_name)
        if not updated:
            database_logger.error(
                "Can't update token usage for {}/{}".format(self.tenant_id, label))

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts``; return ``(embeddings, used_tokens)`` from the model."""
        emd, used_tokens = self.mdl.encode(texts, batch_size)
        self._record_usage(used_tokens, "EMBEDDING")
        return emd, used_tokens

    def encode_queries(self, query: str):
        """Embed a single query; return ``(embedding, used_tokens)`` from the model."""
        emd, used_tokens = self.mdl.encode_queries(query)
        self._record_usage(used_tokens, "EMBEDDING")
        return emd, used_tokens

    def describe(self, image, max_tokens=300):
        """Return the model's text description of ``image``."""
        txt, used_tokens = self.mdl.describe(image, max_tokens)
        self._record_usage(used_tokens, "IMAGE2TEXT")
        return txt

    def chat(self, system, history, gen_conf):
        """Run a chat completion and return the generated text."""
        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
        # The error is logged only when the usage update FAILED; the original
        # condition was inverted and logged on success instead.
        self._record_usage(used_tokens, "CHAT")
        return txt
|