KevinHuSh
commited on
Commit
·
ba51460
1
Parent(s):
67dea7a
Add bce-embedding and fastembed (#383)
Browse files### What problem does this PR solve?
Issue link:#326
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- README.md +2 -0
- README_ja.md +2 -0
- README_zh.md +2 -0
- api/apps/chunk_app.py +1 -1
- api/apps/document_app.py +1 -0
- api/apps/llm_app.py +2 -2
- api/db/init_data.py +19 -11
- api/db/services/dialog_service.py +5 -1
- api/db/services/llm_service.py +8 -3
- api/db/services/task_service.py +1 -1
- rag/llm/__init__.py +2 -2
- rag/llm/embedding_model.py +49 -14
- rag/nlp/search.py +1 -1
- rag/svr/task_executor.py +2 -1
- requirements.txt +2 -0
README.md
CHANGED
@@ -55,6 +55,8 @@
|
|
55 |
|
56 |
## 📌 Latest Features
|
57 |
|
|
|
|
|
58 |
- 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
|
59 |
- 2024-04-10 Add a new layout recognization model for analyzing Laws documentation.
|
60 |
- 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.
|
|
|
55 |
|
56 |
## 📌 Latest Features
|
57 |
|
58 |
+
- 2024-04-16 Add an embedding model 'bce-embedding-base_v1' from [QAnything](https://github.com/netease-youdao/QAnything).
|
59 |
+
- 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed) is designed for light and speeding embedding.
|
60 |
- 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
|
61 |
- 2024-04-10 Add a new layout recognization model for analyzing Laws documentation.
|
62 |
- 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.
|
README_ja.md
CHANGED
@@ -55,6 +55,8 @@
|
|
55 |
|
56 |
## 📌 最新の機能
|
57 |
|
|
|
|
|
58 |
- 2024-04-11 ローカル LLM デプロイメント用に [Xinference](./docs/xinference.md) をサポートします。
|
59 |
- 2024-04-10 メソッド「Laws」に新しいレイアウト認識モデルを追加します。
|
60 |
- 2024-04-08 [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
|
|
|
55 |
|
56 |
## 📌 最新の機能
|
57 |
|
58 |
+
- 2024-04-16 [QAnything](https://github.com/netease-youdao/QAnything) から埋め込みモデル「bce-embedding-base_v1」を追加します。
|
59 |
+
- 2024-04-16 [FastEmbed](https://github.com/qdrant/fastembed) は、軽量かつ高速な埋め込み用に設計されています。
|
60 |
- 2024-04-11 ローカル LLM デプロイメント用に [Xinference](./docs/xinference.md) をサポートします。
|
61 |
- 2024-04-10 メソッド「Laws」に新しいレイアウト認識モデルを追加します。
|
62 |
- 2024-04-08 [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
|
README_zh.md
CHANGED
@@ -55,6 +55,8 @@
|
|
55 |
|
56 |
## 📌 新增功能
|
57 |
|
|
|
|
|
58 |
- 2024-04-11 支持用 [Xinference](./docs/xinference.md) 本地化部署大模型。
|
59 |
- 2024-04-10 为‘Laws’版面分析增加了底层模型。
|
60 |
- 2024-04-08 支持用 [Ollama](./docs/ollama.md) 本地化部署大模型。
|
|
|
55 |
|
56 |
## 📌 新增功能
|
57 |
|
58 |
+
- 2024-04-16 添加嵌入模型 [QAnything的bce-embedding-base_v1](https://github.com/netease-youdao/QAnything) 。
|
59 |
+
- 2024-04-16 添加 [FastEmbed](https://github.com/qdrant/fastembed) 专为轻型和高速嵌入而设计。
|
60 |
- 2024-04-11 支持用 [Xinference](./docs/xinference.md) 本地化部署大模型。
|
61 |
- 2024-04-10 为‘Laws’版面分析增加了底层模型。
|
62 |
- 2024-04-08 支持用 [Ollama](./docs/ollama.md) 本地化部署大模型。
|
api/apps/chunk_app.py
CHANGED
@@ -252,7 +252,7 @@ def retrieval_test():
|
|
252 |
return get_data_error_result(retmsg="Knowledgebase not found!")
|
253 |
|
254 |
embd_mdl = TenantLLMService.model_instance(
|
255 |
-
kb.tenant_id, LLMType.EMBEDDING.value)
|
256 |
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
|
257 |
vector_similarity_weight, top, doc_ids)
|
258 |
for c in ranks["chunks"]:
|
|
|
252 |
return get_data_error_result(retmsg="Knowledgebase not found!")
|
253 |
|
254 |
embd_mdl = TenantLLMService.model_instance(
|
255 |
+
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
256 |
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
|
257 |
vector_similarity_weight, top, doc_ids)
|
258 |
for c in ranks["chunks"]:
|
api/apps/document_app.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15 |
#
|
16 |
|
17 |
import base64
|
|
|
18 |
import pathlib
|
19 |
import re
|
20 |
|
|
|
15 |
#
|
16 |
|
17 |
import base64
|
18 |
+
import os
|
19 |
import pathlib
|
20 |
import re
|
21 |
|
api/apps/llm_app.py
CHANGED
@@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
|
|
28 |
def factories():
|
29 |
try:
|
30 |
fac = LLMFactoriesService.get_all()
|
31 |
-
return get_json_result(data=[f.to_dict() for f in fac])
|
32 |
except Exception as e:
|
33 |
return server_error_response(e)
|
34 |
|
@@ -174,7 +174,7 @@ def list():
|
|
174 |
llms = [m.to_dict()
|
175 |
for m in llms if m.status == StatusEnum.VALID.value]
|
176 |
for m in llms:
|
177 |
-
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
|
178 |
|
179 |
llm_set = set([m["llm_name"] for m in llms])
|
180 |
for o in objs:
|
|
|
28 |
def factories():
|
29 |
try:
|
30 |
fac = LLMFactoriesService.get_all()
|
31 |
+
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
|
32 |
except Exception as e:
|
33 |
return server_error_response(e)
|
34 |
|
|
|
174 |
llms = [m.to_dict()
|
175 |
for m in llms if m.status == StatusEnum.VALID.value]
|
176 |
for m in llms:
|
177 |
+
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"]
|
178 |
|
179 |
llm_set = set([m["llm_name"] for m in llms])
|
180 |
for o in objs:
|
api/db/init_data.py
CHANGED
@@ -18,7 +18,7 @@ import time
|
|
18 |
import uuid
|
19 |
|
20 |
from api.db import LLMType, UserTenantRole
|
21 |
-
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
|
22 |
from api.db.services import UserService
|
23 |
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
24 |
from api.db.services.user_service import TenantService, UserTenantService
|
@@ -114,12 +114,16 @@ factory_infos = [{
|
|
114 |
"logo": "",
|
115 |
"tags": "TEXT EMBEDDING",
|
116 |
"status": "1",
|
117 |
-
},
|
118 |
-
{
|
119 |
"name": "Xinference",
|
120 |
"logo": "",
|
121 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
122 |
"status": "1",
|
|
|
|
|
|
|
|
|
|
|
123 |
},
|
124 |
# {
|
125 |
# "name": "文心一言",
|
@@ -254,12 +258,6 @@ def init_llm_factory():
|
|
254 |
"tags": "LLM,CHAT,",
|
255 |
"max_tokens": 7900,
|
256 |
"model_type": LLMType.CHAT.value
|
257 |
-
}, {
|
258 |
-
"fid": factory_infos[4]["name"],
|
259 |
-
"llm_name": "flag-embedding",
|
260 |
-
"tags": "TEXT EMBEDDING,",
|
261 |
-
"max_tokens": 128 * 1000,
|
262 |
-
"model_type": LLMType.EMBEDDING.value
|
263 |
}, {
|
264 |
"fid": factory_infos[4]["name"],
|
265 |
"llm_name": "moonshot-v1-32k",
|
@@ -325,6 +323,14 @@ def init_llm_factory():
|
|
325 |
"max_tokens": 2147483648,
|
326 |
"model_type": LLMType.EMBEDDING.value
|
327 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
]
|
329 |
for info in factory_infos:
|
330 |
try:
|
@@ -337,8 +343,10 @@ def init_llm_factory():
|
|
337 |
except Exception as e:
|
338 |
pass
|
339 |
|
340 |
-
LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
|
341 |
-
LLMService.filter_delete([LLM.fid=="Local"])
|
|
|
|
|
342 |
|
343 |
"""
|
344 |
drop table llm;
|
|
|
18 |
import uuid
|
19 |
|
20 |
from api.db import LLMType, UserTenantRole
|
21 |
+
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
|
22 |
from api.db.services import UserService
|
23 |
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
24 |
from api.db.services.user_service import TenantService, UserTenantService
|
|
|
114 |
"logo": "",
|
115 |
"tags": "TEXT EMBEDDING",
|
116 |
"status": "1",
|
117 |
+
}, {
|
|
|
118 |
"name": "Xinference",
|
119 |
"logo": "",
|
120 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
121 |
"status": "1",
|
122 |
+
},{
|
123 |
+
"name": "QAnything",
|
124 |
+
"logo": "",
|
125 |
+
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
126 |
+
"status": "1",
|
127 |
},
|
128 |
# {
|
129 |
# "name": "文心一言",
|
|
|
258 |
"tags": "LLM,CHAT,",
|
259 |
"max_tokens": 7900,
|
260 |
"model_type": LLMType.CHAT.value
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
}, {
|
262 |
"fid": factory_infos[4]["name"],
|
263 |
"llm_name": "moonshot-v1-32k",
|
|
|
323 |
"max_tokens": 2147483648,
|
324 |
"model_type": LLMType.EMBEDDING.value
|
325 |
},
|
326 |
+
# ------------------------ QAnything -----------------------
|
327 |
+
{
|
328 |
+
"fid": factory_infos[7]["name"],
|
329 |
+
"llm_name": "maidalun1020/bce-embedding-base_v1",
|
330 |
+
"tags": "TEXT EMBEDDING,",
|
331 |
+
"max_tokens": 512,
|
332 |
+
"model_type": LLMType.EMBEDDING.value
|
333 |
+
},
|
334 |
]
|
335 |
for info in factory_infos:
|
336 |
try:
|
|
|
343 |
except Exception as e:
|
344 |
pass
|
345 |
|
346 |
+
LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
|
347 |
+
LLMService.filter_delete([LLM.fid == "Local"])
|
348 |
+
LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
|
349 |
+
TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
|
350 |
|
351 |
"""
|
352 |
drop table llm;
|
api/db/services/dialog_service.py
CHANGED
@@ -80,8 +80,12 @@ def chat(dialog, messages, **kwargs):
|
|
80 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
81 |
max_tokens = 1024
|
82 |
else: max_tokens = llm[0].max_tokens
|
|
|
|
|
|
|
|
|
83 |
questions = [m["content"] for m in messages if m["role"] == "user"]
|
84 |
-
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
85 |
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
86 |
|
87 |
prompt_config = dialog.prompt_config
|
|
|
80 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
81 |
max_tokens = 1024
|
82 |
else: max_tokens = llm[0].max_tokens
|
83 |
+
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
84 |
+
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
85 |
+
assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
|
86 |
+
|
87 |
questions = [m["content"] for m in messages if m["role"] == "user"]
|
88 |
+
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
89 |
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
90 |
|
91 |
prompt_config = dialog.prompt_config
|
api/db/services/llm_service.py
CHANGED
@@ -66,7 +66,7 @@ class TenantLLMService(CommonService):
|
|
66 |
raise LookupError("Tenant not found")
|
67 |
|
68 |
if llm_type == LLMType.EMBEDDING.value:
|
69 |
-
mdlnm = tenant.embd_id
|
70 |
elif llm_type == LLMType.SPEECH2TEXT.value:
|
71 |
mdlnm = tenant.asr_id
|
72 |
elif llm_type == LLMType.IMAGE2TEXT.value:
|
@@ -77,9 +77,14 @@ class TenantLLMService(CommonService):
|
|
77 |
assert False, "LLM type error"
|
78 |
|
79 |
model_config = cls.get_api_key(tenant_id, mdlnm)
|
|
|
80 |
if not model_config:
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
83 |
if llm_type == LLMType.EMBEDDING.value:
|
84 |
if model_config["llm_factory"] not in EmbeddingModel:
|
85 |
return
|
|
|
66 |
raise LookupError("Tenant not found")
|
67 |
|
68 |
if llm_type == LLMType.EMBEDDING.value:
|
69 |
+
mdlnm = tenant.embd_id if not llm_name else llm_name
|
70 |
elif llm_type == LLMType.SPEECH2TEXT.value:
|
71 |
mdlnm = tenant.asr_id
|
72 |
elif llm_type == LLMType.IMAGE2TEXT.value:
|
|
|
77 |
assert False, "LLM type error"
|
78 |
|
79 |
model_config = cls.get_api_key(tenant_id, mdlnm)
|
80 |
+
if model_config: model_config = model_config.to_dict()
|
81 |
if not model_config:
|
82 |
+
if llm_type == LLMType.EMBEDDING.value:
|
83 |
+
llm = LLMService.query(llm_name=llm_name)
|
84 |
+
if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
|
85 |
+
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
86 |
+
if not model_config: raise LookupError("Model({}) not authorized".format(mdlnm))
|
87 |
+
|
88 |
if llm_type == LLMType.EMBEDDING.value:
|
89 |
if model_config["llm_factory"] not in EmbeddingModel:
|
90 |
return
|
api/db/services/task_service.py
CHANGED
@@ -41,7 +41,7 @@ class TaskService(CommonService):
|
|
41 |
Document.size,
|
42 |
Knowledgebase.tenant_id,
|
43 |
Knowledgebase.language,
|
44 |
-
|
45 |
Tenant.img2txt_id,
|
46 |
Tenant.asr_id,
|
47 |
cls.model.update_time]
|
|
|
41 |
Document.size,
|
42 |
Knowledgebase.tenant_id,
|
43 |
Knowledgebase.language,
|
44 |
+
Knowledgebase.embd_id,
|
45 |
Tenant.img2txt_id,
|
46 |
Tenant.asr_id,
|
47 |
cls.model.update_time]
|
rag/llm/__init__.py
CHANGED
@@ -24,8 +24,8 @@ EmbeddingModel = {
|
|
24 |
"Xinference": XinferenceEmbed,
|
25 |
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
|
26 |
"ZHIPU-AI": ZhipuEmbed,
|
27 |
-
"
|
28 |
-
"
|
29 |
}
|
30 |
|
31 |
|
|
|
24 |
"Xinference": XinferenceEmbed,
|
25 |
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
|
26 |
"ZHIPU-AI": ZhipuEmbed,
|
27 |
+
"FastEmbed": FastEmbed,
|
28 |
+
"QAnything": QAnythingEmbed
|
29 |
}
|
30 |
|
31 |
|
rag/llm/embedding_model.py
CHANGED
@@ -20,7 +20,6 @@ from abc import ABC
|
|
20 |
from ollama import Client
|
21 |
import dashscope
|
22 |
from openai import OpenAI
|
23 |
-
from fastembed import TextEmbedding
|
24 |
from FlagEmbedding import FlagModel
|
25 |
import torch
|
26 |
import numpy as np
|
@@ -28,16 +27,17 @@ import numpy as np
|
|
28 |
from api.utils.file_utils import get_project_base_directory
|
29 |
from rag.utils import num_tokens_from_string
|
30 |
|
|
|
31 |
try:
|
32 |
flag_model = FlagModel(os.path.join(
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
except Exception as e:
|
38 |
flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
|
39 |
-
|
40 |
-
|
41 |
|
42 |
|
43 |
class Base(ABC):
|
@@ -82,8 +82,10 @@ class HuEmbedding(Base):
|
|
82 |
|
83 |
|
84 |
class OpenAIEmbed(Base):
|
85 |
-
def __init__(self, key, model_name="text-embedding-ada-002",
|
86 |
-
|
|
|
|
|
87 |
self.client = OpenAI(api_key=key, base_url=base_url)
|
88 |
self.model_name = model_name
|
89 |
|
@@ -142,7 +144,7 @@ class ZhipuEmbed(Base):
|
|
142 |
tks_num = 0
|
143 |
for txt in texts:
|
144 |
res = self.client.embeddings.create(input=txt,
|
145 |
-
|
146 |
arr.append(res.data[0].embedding)
|
147 |
tks_num += res.usage.total_tokens
|
148 |
return np.array(arr), tks_num
|
@@ -163,14 +165,14 @@ class OllamaEmbed(Base):
|
|
163 |
tks_num = 0
|
164 |
for txt in texts:
|
165 |
res = self.client.embeddings(prompt=txt,
|
166 |
-
|
167 |
arr.append(res["embedding"])
|
168 |
tks_num += 128
|
169 |
return np.array(arr), tks_num
|
170 |
|
171 |
def encode_queries(self, text):
|
172 |
res = self.client.embeddings(prompt=text,
|
173 |
-
|
174 |
return np.array(res["embedding"]), 128
|
175 |
|
176 |
|
@@ -183,10 +185,12 @@ class FastEmbed(Base):
|
|
183 |
threads: Optional[int] = None,
|
184 |
**kwargs,
|
185 |
):
|
|
|
186 |
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
187 |
|
188 |
def encode(self, texts: list, batch_size=32):
|
189 |
-
# Using the internal tokenizer to encode the texts and get the total
|
|
|
190 |
encodings = self._model.model.tokenizer.encode_batch(texts)
|
191 |
total_tokens = sum(len(e) for e in encodings)
|
192 |
|
@@ -195,7 +199,8 @@ class FastEmbed(Base):
|
|
195 |
return np.array(embeddings), total_tokens
|
196 |
|
197 |
def encode_queries(self, text: str):
|
198 |
-
# Using the internal tokenizer to encode the texts and get the total
|
|
|
199 |
encoding = self._model.model.tokenizer.encode(text)
|
200 |
embedding = next(self._model.query_embed(text)).tolist()
|
201 |
|
@@ -218,3 +223,33 @@ class XinferenceEmbed(Base):
|
|
218 |
model=self.model_name)
|
219 |
return np.array(res.data[0].embedding), res.usage.total_tokens
|
220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
from ollama import Client
|
21 |
import dashscope
|
22 |
from openai import OpenAI
|
|
|
23 |
from FlagEmbedding import FlagModel
|
24 |
import torch
|
25 |
import numpy as np
|
|
|
27 |
from api.utils.file_utils import get_project_base_directory
|
28 |
from rag.utils import num_tokens_from_string
|
29 |
|
30 |
+
|
31 |
try:
|
32 |
flag_model = FlagModel(os.path.join(
|
33 |
+
get_project_base_directory(),
|
34 |
+
"rag/res/bge-large-zh-v1.5"),
|
35 |
+
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
36 |
+
use_fp16=torch.cuda.is_available())
|
37 |
except Exception as e:
|
38 |
flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
|
39 |
+
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
40 |
+
use_fp16=torch.cuda.is_available())
|
41 |
|
42 |
|
43 |
class Base(ABC):
|
|
|
82 |
|
83 |
|
84 |
class OpenAIEmbed(Base):
|
85 |
+
def __init__(self, key, model_name="text-embedding-ada-002",
|
86 |
+
base_url="https://api.openai.com/v1"):
|
87 |
+
if not base_url:
|
88 |
+
base_url = "https://api.openai.com/v1"
|
89 |
self.client = OpenAI(api_key=key, base_url=base_url)
|
90 |
self.model_name = model_name
|
91 |
|
|
|
144 |
tks_num = 0
|
145 |
for txt in texts:
|
146 |
res = self.client.embeddings.create(input=txt,
|
147 |
+
model=self.model_name)
|
148 |
arr.append(res.data[0].embedding)
|
149 |
tks_num += res.usage.total_tokens
|
150 |
return np.array(arr), tks_num
|
|
|
165 |
tks_num = 0
|
166 |
for txt in texts:
|
167 |
res = self.client.embeddings(prompt=txt,
|
168 |
+
model=self.model_name)
|
169 |
arr.append(res["embedding"])
|
170 |
tks_num += 128
|
171 |
return np.array(arr), tks_num
|
172 |
|
173 |
def encode_queries(self, text):
|
174 |
res = self.client.embeddings(prompt=text,
|
175 |
+
model=self.model_name)
|
176 |
return np.array(res["embedding"]), 128
|
177 |
|
178 |
|
|
|
185 |
threads: Optional[int] = None,
|
186 |
**kwargs,
|
187 |
):
|
188 |
+
from fastembed import TextEmbedding
|
189 |
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
190 |
|
191 |
def encode(self, texts: list, batch_size=32):
|
192 |
+
# Using the internal tokenizer to encode the texts and get the total
|
193 |
+
# number of tokens
|
194 |
encodings = self._model.model.tokenizer.encode_batch(texts)
|
195 |
total_tokens = sum(len(e) for e in encodings)
|
196 |
|
|
|
199 |
return np.array(embeddings), total_tokens
|
200 |
|
201 |
def encode_queries(self, text: str):
|
202 |
+
# Using the internal tokenizer to encode the texts and get the total
|
203 |
+
# number of tokens
|
204 |
encoding = self._model.model.tokenizer.encode(text)
|
205 |
embedding = next(self._model.query_embed(text)).tolist()
|
206 |
|
|
|
223 |
model=self.model_name)
|
224 |
return np.array(res.data[0].embedding), res.usage.total_tokens
|
225 |
|
226 |
+
|
227 |
+
class QAnythingEmbed(Base):
|
228 |
+
_client = None
|
229 |
+
|
230 |
+
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
231 |
+
from BCEmbedding import EmbeddingModel as qanthing
|
232 |
+
if not QAnythingEmbed._client:
|
233 |
+
try:
|
234 |
+
print("LOADING BCE...")
|
235 |
+
QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
|
236 |
+
get_project_base_directory(),
|
237 |
+
"rag/res/bce-embedding-base_v1"))
|
238 |
+
except Exception as e:
|
239 |
+
QAnythingEmbed._client = qanthing(
|
240 |
+
model_name_or_path=model_name.replace(
|
241 |
+
"maidalun1020", "InfiniFlow"))
|
242 |
+
|
243 |
+
def encode(self, texts: list, batch_size=10):
|
244 |
+
res = []
|
245 |
+
token_count = 0
|
246 |
+
for t in texts:
|
247 |
+
token_count += num_tokens_from_string(t)
|
248 |
+
for i in range(0, len(texts), batch_size):
|
249 |
+
embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
|
250 |
+
res.extend(embds)
|
251 |
+
return np.array(res), token_count
|
252 |
+
|
253 |
+
def encode_queries(self, text):
|
254 |
+
embds = QAnythingEmbed._client.encode([text])
|
255 |
+
return np.array(embds[0]), num_tokens_from_string(text)
|
rag/nlp/search.py
CHANGED
@@ -46,7 +46,7 @@ class Dealer:
|
|
46 |
"k": topk,
|
47 |
"similarity": sim,
|
48 |
"num_candidates": topk * 2,
|
49 |
-
"query_vector":
|
50 |
}
|
51 |
|
52 |
def search(self, req, idxnm, emb_mdl=None):
|
|
|
46 |
"k": topk,
|
47 |
"similarity": sim,
|
48 |
"num_candidates": topk * 2,
|
49 |
+
"query_vector": [float(v) for v in qv]
|
50 |
}
|
51 |
|
52 |
def search(self, req, idxnm, emb_mdl=None):
|
rag/svr/task_executor.py
CHANGED
@@ -244,8 +244,9 @@ def main(comm, mod):
|
|
244 |
for _, r in rows.iterrows():
|
245 |
callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
|
246 |
try:
|
247 |
-
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
|
248 |
except Exception as e:
|
|
|
249 |
callback(prog=-1, msg=str(e))
|
250 |
continue
|
251 |
|
|
|
244 |
for _, r in rows.iterrows():
|
245 |
callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
|
246 |
try:
|
247 |
+
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
|
248 |
except Exception as e:
|
249 |
+
traceback.print_stack(e)
|
250 |
callback(prog=-1, msg=str(e))
|
251 |
continue
|
252 |
|
requirements.txt
CHANGED
@@ -132,3 +132,5 @@ xpinyin==0.7.6
|
|
132 |
xxhash==3.4.1
|
133 |
yarl==1.9.4
|
134 |
zhipuai==2.0.1
|
|
|
|
|
|
132 |
xxhash==3.4.1
|
133 |
yarl==1.9.4
|
134 |
zhipuai==2.0.1
|
135 |
+
BCEmbedding
|
136 |
+
loguru==0.7.2
|