KevinHuSh
commited on
Commit
·
3198faf
1
Parent(s):
3079197
add alot of api (#23)
Browse files* clean rust version project
* clean rust version project
* build python version rag-flow
* add alot of api
- rag/llm/embedding_model.py +1 -1
- rag/nlp/huchunk.py +6 -3
- rag/nlp/search.py +1 -1
- rag/svr/parse_user_docs.py +35 -19
- rag/utils/__init__.py +19 -0
- rag/utils/es_conn.py +1 -0
- web_server/apps/document_app.py +47 -2
- web_server/apps/kb_app.py +14 -2
- web_server/apps/llm_app.py +95 -0
- web_server/apps/user_app.py +33 -5
- web_server/db/db_models.py +7 -4
- web_server/db/services/document_service.py +27 -13
- web_server/db/services/kb_service.py +33 -6
- web_server/db/services/llm_service.py +18 -0
- web_server/db/services/user_service.py +1 -1
- web_server/utils/file_utils.py +1 -1
rag/llm/embedding_model.py
CHANGED
@@ -35,7 +35,7 @@ class Base(ABC):
|
|
35 |
|
36 |
|
37 |
class HuEmbedding(Base):
|
38 |
-
def __init__(self):
|
39 |
"""
|
40 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
41 |
|
|
|
35 |
|
36 |
|
37 |
class HuEmbedding(Base):
|
38 |
+
def __init__(self, key="", model_name=""):
|
39 |
"""
|
40 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
41 |
|
rag/nlp/huchunk.py
CHANGED
@@ -411,9 +411,12 @@ class TextChunker(HuChunker):
|
|
411 |
flds = self.Fields()
|
412 |
if self.is_binary_file(fnm):
|
413 |
return flds
|
414 |
-
|
415 |
-
|
416 |
-
|
|
|
|
|
|
|
417 |
flds.table_chunks = []
|
418 |
return flds
|
419 |
|
|
|
411 |
flds = self.Fields()
|
412 |
if self.is_binary_file(fnm):
|
413 |
return flds
|
414 |
+
txt = ""
|
415 |
+
if isinstance(fnm, str):
|
416 |
+
with open(fnm, "r") as f:
|
417 |
+
txt = f.read()
|
418 |
+
else: txt = fnm.decode("utf-8")
|
419 |
+
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
|
420 |
flds.table_chunks = []
|
421 |
return flds
|
422 |
|
rag/nlp/search.py
CHANGED
@@ -8,7 +8,7 @@ from rag.nlp import huqie, query
|
|
8 |
import numpy as np
|
9 |
|
10 |
|
11 |
-
def index_name(uid): return f"
|
12 |
|
13 |
|
14 |
class Dealer:
|
|
|
8 |
import numpy as np
|
9 |
|
10 |
|
11 |
+
def index_name(uid): return f"ragflow_{uid}"
|
12 |
|
13 |
|
14 |
class Dealer:
|
rag/svr/parse_user_docs.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
import json
|
|
|
17 |
import os
|
18 |
import hashlib
|
19 |
import copy
|
@@ -24,9 +25,10 @@ from timeit import default_timer as timer
|
|
24 |
|
25 |
from rag.llm import EmbeddingModel, CvModel
|
26 |
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
27 |
-
from rag.utils import ELASTICSEARCH
|
28 |
from rag.utils import MINIO
|
29 |
-
from rag.utils import rmSpace,
|
|
|
30 |
from rag.nlp import huchunk, huqie, search
|
31 |
from io import BytesIO
|
32 |
import pandas as pd
|
@@ -47,6 +49,7 @@ from rag.nlp.huchunk import (
|
|
47 |
from web_server.db import LLMType
|
48 |
from web_server.db.services.document_service import DocumentService
|
49 |
from web_server.db.services.llm_service import TenantLLMService
|
|
|
50 |
from web_server.utils import get_format_time
|
51 |
from web_server.utils.file_utils import get_project_base_directory
|
52 |
|
@@ -83,7 +86,7 @@ def collect(comm, mod, tm):
|
|
83 |
if len(docs) == 0:
|
84 |
return pd.DataFrame()
|
85 |
docs = pd.DataFrame(docs)
|
86 |
-
mtm =
|
87 |
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
88 |
return docs
|
89 |
|
@@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
|
|
99 |
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
100 |
|
101 |
|
102 |
-
def build(row):
|
103 |
if row["size"] > DOC_MAXIMUM_SIZE:
|
104 |
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
105 |
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
106 |
return []
|
|
|
107 |
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
|
108 |
if ELASTICSEARCH.getTotal(res) > 0:
|
109 |
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
@@ -120,7 +124,8 @@ def build(row):
|
|
120 |
set_progress(row["id"], random.randint(0, 20) /
|
121 |
100., "Finished preparing! Start to slice file!", True)
|
122 |
try:
|
123 |
-
|
|
|
124 |
except Exception as e:
|
125 |
if re.search("(No such file|not found)", str(e)):
|
126 |
set_progress(
|
@@ -131,6 +136,9 @@ def build(row):
|
|
131 |
row["id"], -1, f"Internal server error: %s" %
|
132 |
str(e).replace(
|
133 |
"'", ""))
|
|
|
|
|
|
|
134 |
return []
|
135 |
|
136 |
if not obj.text_chunks and not obj.table_chunks:
|
@@ -144,7 +152,7 @@ def build(row):
|
|
144 |
"Finished slicing files. Start to embedding the content.")
|
145 |
|
146 |
doc = {
|
147 |
-
"doc_id": row["
|
148 |
"kb_id": [str(row["kb_id"])],
|
149 |
"docnm_kwd": os.path.split(row["location"])[-1],
|
150 |
"title_tks": huqie.qie(row["name"]),
|
@@ -164,10 +172,10 @@ def build(row):
|
|
164 |
docs.append(d)
|
165 |
continue
|
166 |
|
167 |
-
if isinstance(img,
|
168 |
-
img.save(output_buffer, format='JPEG')
|
169 |
-
else:
|
170 |
output_buffer = BytesIO(img)
|
|
|
|
|
171 |
|
172 |
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
173 |
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
@@ -215,15 +223,16 @@ def embedding(docs, mdl):
|
|
215 |
|
216 |
|
217 |
def model_instance(tenant_id, llm_type):
|
218 |
-
model_config = TenantLLMService.
|
219 |
-
if not model_config:
|
220 |
-
|
|
|
221 |
if llm_type == LLMType.EMBEDDING:
|
222 |
-
if model_config
|
223 |
-
return EmbeddingModel[model_config
|
224 |
if llm_type == LLMType.IMAGE2TEXT:
|
225 |
-
if model_config
|
226 |
-
return CvModel[model_config.llm_factory](model_config
|
227 |
|
228 |
|
229 |
def main(comm, mod):
|
@@ -231,7 +240,7 @@ def main(comm, mod):
|
|
231 |
from rag.llm import HuEmbedding
|
232 |
model = HuEmbedding()
|
233 |
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
234 |
-
tm =
|
235 |
rows = collect(comm, mod, tm)
|
236 |
if len(rows) == 0:
|
237 |
return
|
@@ -247,7 +256,7 @@ def main(comm, mod):
|
|
247 |
st_tm = timer()
|
248 |
cks = build(r, cv_mdl)
|
249 |
if not cks:
|
250 |
-
tmf.write(str(r["
|
251 |
continue
|
252 |
# TODO: exception handler
|
253 |
## set_progress(r["did"], -1, "ERROR: ")
|
@@ -268,12 +277,19 @@ def main(comm, mod):
|
|
268 |
cron_logger.error(str(es_r))
|
269 |
else:
|
270 |
set_progress(r["id"], 1., "Done!")
|
271 |
-
DocumentService.
|
|
|
|
|
272 |
tmf.write(str(r["update_time"]) + "\n")
|
273 |
tmf.close()
|
274 |
|
275 |
|
276 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
277 |
from mpi4py import MPI
|
278 |
comm = MPI.COMM_WORLD
|
279 |
main(comm.Get_size(), comm.Get_rank())
|
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
import json
|
17 |
+
import logging
|
18 |
import os
|
19 |
import hashlib
|
20 |
import copy
|
|
|
25 |
|
26 |
from rag.llm import EmbeddingModel, CvModel
|
27 |
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
28 |
+
from rag.utils import ELASTICSEARCH
|
29 |
from rag.utils import MINIO
|
30 |
+
from rag.utils import rmSpace, findMaxTm
|
31 |
+
|
32 |
from rag.nlp import huchunk, huqie, search
|
33 |
from io import BytesIO
|
34 |
import pandas as pd
|
|
|
49 |
from web_server.db import LLMType
|
50 |
from web_server.db.services.document_service import DocumentService
|
51 |
from web_server.db.services.llm_service import TenantLLMService
|
52 |
+
from web_server.settings import database_logger
|
53 |
from web_server.utils import get_format_time
|
54 |
from web_server.utils.file_utils import get_project_base_directory
|
55 |
|
|
|
86 |
if len(docs) == 0:
|
87 |
return pd.DataFrame()
|
88 |
docs = pd.DataFrame(docs)
|
89 |
+
mtm = docs["update_time"].max()
|
90 |
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
91 |
return docs
|
92 |
|
|
|
102 |
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
103 |
|
104 |
|
105 |
+
def build(row, cvmdl):
|
106 |
if row["size"] > DOC_MAXIMUM_SIZE:
|
107 |
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
108 |
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
109 |
return []
|
110 |
+
|
111 |
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
|
112 |
if ELASTICSEARCH.getTotal(res) > 0:
|
113 |
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
|
|
124 |
set_progress(row["id"], random.randint(0, 20) /
|
125 |
100., "Finished preparing! Start to slice file!", True)
|
126 |
try:
|
127 |
+
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
|
128 |
+
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
|
129 |
except Exception as e:
|
130 |
if re.search("(No such file|not found)", str(e)):
|
131 |
set_progress(
|
|
|
136 |
row["id"], -1, f"Internal server error: %s" %
|
137 |
str(e).replace(
|
138 |
"'", ""))
|
139 |
+
|
140 |
+
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
|
141 |
+
|
142 |
return []
|
143 |
|
144 |
if not obj.text_chunks and not obj.table_chunks:
|
|
|
152 |
"Finished slicing files. Start to embedding the content.")
|
153 |
|
154 |
doc = {
|
155 |
+
"doc_id": row["id"],
|
156 |
"kb_id": [str(row["kb_id"])],
|
157 |
"docnm_kwd": os.path.split(row["location"])[-1],
|
158 |
"title_tks": huqie.qie(row["name"]),
|
|
|
172 |
docs.append(d)
|
173 |
continue
|
174 |
|
175 |
+
if isinstance(img, bytes):
|
|
|
|
|
176 |
output_buffer = BytesIO(img)
|
177 |
+
else:
|
178 |
+
img.save(output_buffer, format='JPEG')
|
179 |
|
180 |
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
181 |
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
|
|
223 |
|
224 |
|
225 |
def model_instance(tenant_id, llm_type):
|
226 |
+
model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
|
227 |
+
if not model_config:
|
228 |
+
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
|
229 |
+
else: model_config = model_config[0].to_dict()
|
230 |
if llm_type == LLMType.EMBEDDING:
|
231 |
+
if model_config["llm_factory"] not in EmbeddingModel: return
|
232 |
+
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
|
233 |
if llm_type == LLMType.IMAGE2TEXT:
|
234 |
+
if model_config["llm_factory"] not in CvModel: return
|
235 |
+
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
|
236 |
|
237 |
|
238 |
def main(comm, mod):
|
|
|
240 |
from rag.llm import HuEmbedding
|
241 |
model = HuEmbedding()
|
242 |
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
243 |
+
tm = findMaxTm(tm_fnm)
|
244 |
rows = collect(comm, mod, tm)
|
245 |
if len(rows) == 0:
|
246 |
return
|
|
|
256 |
st_tm = timer()
|
257 |
cks = build(r, cv_mdl)
|
258 |
if not cks:
|
259 |
+
tmf.write(str(r["update_time"]) + "\n")
|
260 |
continue
|
261 |
# TODO: exception handler
|
262 |
## set_progress(r["did"], -1, "ERROR: ")
|
|
|
277 |
cron_logger.error(str(es_r))
|
278 |
else:
|
279 |
set_progress(r["id"], 1., "Done!")
|
280 |
+
DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
|
281 |
+
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
|
282 |
+
|
283 |
tmf.write(str(r["update_time"]) + "\n")
|
284 |
tmf.close()
|
285 |
|
286 |
|
287 |
if __name__ == "__main__":
|
288 |
+
peewee_logger = logging.getLogger('peewee')
|
289 |
+
peewee_logger.propagate = False
|
290 |
+
peewee_logger.addHandler(database_logger.handlers[0])
|
291 |
+
peewee_logger.setLevel(database_logger.level)
|
292 |
+
|
293 |
from mpi4py import MPI
|
294 |
comm = MPI.COMM_WORLD
|
295 |
main(comm.Get_size(), comm.Get_rank())
|
rag/utils/__init__.py
CHANGED
@@ -40,6 +40,25 @@ def findMaxDt(fnm):
|
|
40 |
print("WARNING: can't find " + fnm)
|
41 |
return m
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
def num_tokens_from_string(string: str) -> int:
|
44 |
"""Returns the number of tokens in a text string."""
|
45 |
encoding = tiktoken.get_encoding('cl100k_base')
|
|
|
40 |
print("WARNING: can't find " + fnm)
|
41 |
return m
|
42 |
|
43 |
+
|
44 |
+
def findMaxTm(fnm):
|
45 |
+
m = 0
|
46 |
+
try:
|
47 |
+
with open(fnm, "r") as f:
|
48 |
+
while True:
|
49 |
+
l = f.readline()
|
50 |
+
if not l:
|
51 |
+
break
|
52 |
+
l = l.strip("\n")
|
53 |
+
if l == 'nan':
|
54 |
+
continue
|
55 |
+
if int(l) > m:
|
56 |
+
m = int(l)
|
57 |
+
except Exception as e:
|
58 |
+
print("WARNING: can't find " + fnm)
|
59 |
+
return m
|
60 |
+
|
61 |
+
|
62 |
def num_tokens_from_string(string: str) -> int:
|
63 |
"""Returns the number of tokens in a text string."""
|
64 |
encoding = tiktoken.get_encoding('cl100k_base')
|
rag/utils/es_conn.py
CHANGED
@@ -294,6 +294,7 @@ class HuEs:
|
|
294 |
except Exception as e:
|
295 |
es_logger.error("ES updateByQuery deleteByQuery: " +
|
296 |
str(e) + "【Q】:" + str(query.to_dict()))
|
|
|
297 |
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
298 |
continue
|
299 |
|
|
|
294 |
except Exception as e:
|
295 |
es_logger.error("ES updateByQuery deleteByQuery: " +
|
296 |
str(e) + "【Q】:" + str(query.to_dict()))
|
297 |
+
if str(e).find("NotFoundError") > 0: return True
|
298 |
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
299 |
continue
|
300 |
|
web_server/apps/document_app.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
16 |
import pathlib
|
17 |
|
18 |
from elasticsearch_dsl import Q
|
@@ -195,11 +196,15 @@ def rm():
|
|
195 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
196 |
if not e:
|
197 |
return get_data_error_result(retmsg="Document not found!")
|
|
|
|
|
|
|
|
|
198 |
if not DocumentService.delete_by_id(req["doc_id"]):
|
199 |
return get_data_error_result(
|
200 |
retmsg="Database error (Document removal)!")
|
201 |
-
|
202 |
-
MINIO.rm(
|
203 |
return get_json_result(data=True)
|
204 |
except Exception as e:
|
205 |
return server_error_response(e)
|
@@ -233,3 +238,43 @@ def rename():
|
|
233 |
return get_json_result(data=True)
|
234 |
except Exception as e:
|
235 |
return server_error_response(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
import base64
|
17 |
import pathlib
|
18 |
|
19 |
from elasticsearch_dsl import Q
|
|
|
196 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
197 |
if not e:
|
198 |
return get_data_error_result(retmsg="Document not found!")
|
199 |
+
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
|
200 |
+
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
|
201 |
+
|
202 |
+
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
|
203 |
if not DocumentService.delete_by_id(req["doc_id"]):
|
204 |
return get_data_error_result(
|
205 |
retmsg="Database error (Document removal)!")
|
206 |
+
|
207 |
+
MINIO.rm(doc.kb_id, doc.location)
|
208 |
return get_json_result(data=True)
|
209 |
except Exception as e:
|
210 |
return server_error_response(e)
|
|
|
238 |
return get_json_result(data=True)
|
239 |
except Exception as e:
|
240 |
return server_error_response(e)
|
241 |
+
|
242 |
+
|
243 |
+
@manager.route('/get', methods=['GET'])
|
244 |
+
@login_required
|
245 |
+
def get():
|
246 |
+
doc_id = request.args["doc_id"]
|
247 |
+
try:
|
248 |
+
e, doc = DocumentService.get_by_id(doc_id)
|
249 |
+
if not e:
|
250 |
+
return get_data_error_result(retmsg="Document not found!")
|
251 |
+
|
252 |
+
blob = MINIO.get(doc.kb_id, doc.location)
|
253 |
+
return get_json_result(data={"base64": base64.b64decode(blob)})
|
254 |
+
except Exception as e:
|
255 |
+
return server_error_response(e)
|
256 |
+
|
257 |
+
|
258 |
+
@manager.route('/change_parser', methods=['POST'])
|
259 |
+
@login_required
|
260 |
+
@validate_request("doc_id", "parser_id")
|
261 |
+
def change_parser():
|
262 |
+
req = request.json
|
263 |
+
try:
|
264 |
+
e, doc = DocumentService.get_by_id(req["doc_id"])
|
265 |
+
if not e:
|
266 |
+
return get_data_error_result(retmsg="Document not found!")
|
267 |
+
if doc.parser_id.lower() == req["parser_id"].lower():
|
268 |
+
return get_json_result(data=True)
|
269 |
+
|
270 |
+
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
|
271 |
+
if not e:
|
272 |
+
return get_data_error_result(retmsg="Document not found!")
|
273 |
+
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
|
274 |
+
if not e:
|
275 |
+
return get_data_error_result(retmsg="Document not found!")
|
276 |
+
|
277 |
+
return get_json_result(data=True)
|
278 |
+
except Exception as e:
|
279 |
+
return server_error_response(e)
|
280 |
+
|
web_server/apps/kb_app.py
CHANGED
@@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result
|
|
29 |
|
30 |
@manager.route('/create', methods=['post'])
|
31 |
@login_required
|
32 |
-
@validate_request("name", "description", "permission", "
|
33 |
def create():
|
34 |
req = request.json
|
35 |
req["name"] = req["name"].strip()
|
@@ -46,7 +46,7 @@ def create():
|
|
46 |
|
47 |
@manager.route('/update', methods=['post'])
|
48 |
@login_required
|
49 |
-
@validate_request("kb_id", "name", "description", "permission", "
|
50 |
def update():
|
51 |
req = request.json
|
52 |
req["name"] = req["name"].strip()
|
@@ -72,6 +72,18 @@ def update():
|
|
72 |
return server_error_response(e)
|
73 |
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
@manager.route('/list', methods=['GET'])
|
76 |
@login_required
|
77 |
def list():
|
|
|
29 |
|
30 |
@manager.route('/create', methods=['post'])
|
31 |
@login_required
|
32 |
+
@validate_request("name", "description", "permission", "parser_id")
|
33 |
def create():
|
34 |
req = request.json
|
35 |
req["name"] = req["name"].strip()
|
|
|
46 |
|
47 |
@manager.route('/update', methods=['post'])
|
48 |
@login_required
|
49 |
+
@validate_request("kb_id", "name", "description", "permission", "parser_id")
|
50 |
def update():
|
51 |
req = request.json
|
52 |
req["name"] = req["name"].strip()
|
|
|
72 |
return server_error_response(e)
|
73 |
|
74 |
|
75 |
+
@manager.route('/detail', methods=['GET'])
|
76 |
+
@login_required
|
77 |
+
def detail():
|
78 |
+
kb_id = request.args["kb_id"]
|
79 |
+
try:
|
80 |
+
kb = KnowledgebaseService.get_detail(kb_id)
|
81 |
+
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
82 |
+
return get_json_result(data=kb)
|
83 |
+
except Exception as e:
|
84 |
+
return server_error_response(e)
|
85 |
+
|
86 |
+
|
87 |
@manager.route('/list', methods=['GET'])
|
88 |
@login_required
|
89 |
def list():
|
web_server/apps/llm_app.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
from flask import request
|
17 |
+
from flask_login import login_required, current_user
|
18 |
+
|
19 |
+
from web_server.db.services import duplicate_name
|
20 |
+
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
21 |
+
from web_server.db.services.user_service import TenantService, UserTenantService
|
22 |
+
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
23 |
+
from web_server.utils import get_uuid, get_format_time
|
24 |
+
from web_server.db import StatusEnum, UserTenantRole
|
25 |
+
from web_server.db.services.kb_service import KnowledgebaseService
|
26 |
+
from web_server.db.db_models import Knowledgebase, TenantLLM
|
27 |
+
from web_server.settings import stat_logger, RetCode
|
28 |
+
from web_server.utils.api_utils import get_json_result
|
29 |
+
|
30 |
+
|
31 |
+
@manager.route('/factories', methods=['GET'])
|
32 |
+
@login_required
|
33 |
+
def factories():
|
34 |
+
try:
|
35 |
+
fac = LLMFactoriesService.get_all()
|
36 |
+
return get_json_result(data=fac.to_json())
|
37 |
+
except Exception as e:
|
38 |
+
return server_error_response(e)
|
39 |
+
|
40 |
+
|
41 |
+
@manager.route('/set_api_key', methods=['POST'])
|
42 |
+
@login_required
|
43 |
+
@validate_request("llm_factory", "api_key")
|
44 |
+
def set_api_key():
|
45 |
+
req = request.json
|
46 |
+
llm = {
|
47 |
+
"tenant_id": current_user.id,
|
48 |
+
"llm_factory": req["llm_factory"],
|
49 |
+
"api_key": req["api_key"]
|
50 |
+
}
|
51 |
+
# TODO: Test api_key
|
52 |
+
for n in ["model_type", "llm_name"]:
|
53 |
+
if n in req: llm[n] = req[n]
|
54 |
+
|
55 |
+
TenantLLM.insert(**llm).on_conflict("replace").execute()
|
56 |
+
return get_json_result(data=True)
|
57 |
+
|
58 |
+
|
59 |
+
@manager.route('/my_llms', methods=['GET'])
|
60 |
+
@login_required
|
61 |
+
def my_llms():
|
62 |
+
try:
|
63 |
+
objs = TenantLLMService.query(tenant_id=current_user.id)
|
64 |
+
objs = [o.to_dict() for o in objs]
|
65 |
+
for o in objs: del o["api_key"]
|
66 |
+
return get_json_result(data=objs)
|
67 |
+
except Exception as e:
|
68 |
+
return server_error_response(e)
|
69 |
+
|
70 |
+
|
71 |
+
@manager.route('/list', methods=['GET'])
|
72 |
+
@login_required
|
73 |
+
def list():
|
74 |
+
try:
|
75 |
+
objs = TenantLLMService.query(tenant_id=current_user.id)
|
76 |
+
objs = [o.to_dict() for o in objs if o.api_key]
|
77 |
+
fct = {}
|
78 |
+
for o in objs:
|
79 |
+
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
|
80 |
+
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
|
81 |
+
|
82 |
+
llms = LLMService.get_all()
|
83 |
+
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
|
84 |
+
for m in llms:
|
85 |
+
m["available"] = False
|
86 |
+
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
|
87 |
+
m["available"] = True
|
88 |
+
res = {}
|
89 |
+
for m in llms:
|
90 |
+
if m["fid"] not in res: res[m["fid"]] = []
|
91 |
+
res[m["fid"]].append(m)
|
92 |
+
|
93 |
+
return get_json_result(data=res)
|
94 |
+
except Exception as e:
|
95 |
+
return server_error_response(e)
|
web_server/apps/user_app.py
CHANGED
@@ -16,9 +16,12 @@
|
|
16 |
from flask import request, session, redirect, url_for
|
17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
18 |
from flask_login import login_required, current_user, login_user, logout_user
|
|
|
|
|
|
|
19 |
from web_server.utils.api_utils import server_error_response, validate_request
|
20 |
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
21 |
-
from web_server.db import UserTenantRole
|
22 |
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
23 |
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
24 |
from web_server.settings import stat_logger
|
@@ -47,8 +50,9 @@ def login():
|
|
47 |
avatar = download_img(userinfo["avatar_url"])
|
48 |
except Exception as e:
|
49 |
stat_logger.exception(e)
|
|
|
50 |
try:
|
51 |
-
users = user_register({
|
52 |
"access_token": session["access_token"],
|
53 |
"email": userinfo["email"],
|
54 |
"avatar": avatar,
|
@@ -63,6 +67,7 @@ def login():
|
|
63 |
login_user(user)
|
64 |
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
65 |
except Exception as e:
|
|
|
66 |
stat_logger.exception(e)
|
67 |
return server_error_response(e)
|
68 |
elif not request.json:
|
@@ -162,7 +167,25 @@ def user_info():
|
|
162 |
return get_json_result(data=current_user.to_dict())
|
163 |
|
164 |
|
165 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
user_id = get_uuid()
|
167 |
user["id"] = user_id
|
168 |
tenant = {
|
@@ -180,10 +203,12 @@ def user_register(user):
|
|
180 |
"invited_by": user_id,
|
181 |
"role": UserTenantRole.OWNER
|
182 |
}
|
|
|
183 |
|
184 |
if not UserService.save(**user):return
|
185 |
TenantService.save(**tenant)
|
186 |
UserTenantService.save(**usr_tenant)
|
|
|
187 |
return UserService.query(email=user["email"])
|
188 |
|
189 |
|
@@ -203,14 +228,17 @@ def user_add():
|
|
203 |
"last_login_time": get_format_time(),
|
204 |
"is_superuser": False,
|
205 |
}
|
|
|
|
|
206 |
try:
|
207 |
-
users = user_register(user_dict)
|
208 |
if not users: raise Exception('Register user failure.')
|
209 |
if len(users) > 1: raise Exception('Same E-mail exist!')
|
210 |
user = users[0]
|
211 |
login_user(user)
|
212 |
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
213 |
except Exception as e:
|
|
|
214 |
stat_logger.exception(e)
|
215 |
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
216 |
|
@@ -220,7 +248,7 @@ def user_add():
|
|
220 |
@login_required
|
221 |
def tenant_info():
|
222 |
try:
|
223 |
-
tenants = TenantService.get_by_user_id(current_user.id)
|
224 |
return get_json_result(data=tenants)
|
225 |
except Exception as e:
|
226 |
return server_error_response(e)
|
|
|
16 |
from flask import request, session, redirect, url_for
|
17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
18 |
from flask_login import login_required, current_user, login_user, logout_user
|
19 |
+
|
20 |
+
from web_server.db.db_models import TenantLLM
|
21 |
+
from web_server.db.services.llm_service import TenantLLMService
|
22 |
from web_server.utils.api_utils import server_error_response, validate_request
|
23 |
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
24 |
+
from web_server.db import UserTenantRole, LLMType
|
25 |
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
26 |
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
27 |
from web_server.settings import stat_logger
|
|
|
50 |
avatar = download_img(userinfo["avatar_url"])
|
51 |
except Exception as e:
|
52 |
stat_logger.exception(e)
|
53 |
+
user_id = get_uuid()
|
54 |
try:
|
55 |
+
users = user_register(user_id, {
|
56 |
"access_token": session["access_token"],
|
57 |
"email": userinfo["email"],
|
58 |
"avatar": avatar,
|
|
|
67 |
login_user(user)
|
68 |
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
69 |
except Exception as e:
|
70 |
+
rollback_user_registration(user_id)
|
71 |
stat_logger.exception(e)
|
72 |
return server_error_response(e)
|
73 |
elif not request.json:
|
|
|
167 |
return get_json_result(data=current_user.to_dict())
|
168 |
|
169 |
|
170 |
+
def rollback_user_registration(user_id):
|
171 |
+
try:
|
172 |
+
TenantService.delete_by_id(user_id)
|
173 |
+
except Exception as e:
|
174 |
+
pass
|
175 |
+
try:
|
176 |
+
u = UserTenantService.query(tenant_id=user_id)
|
177 |
+
if u:
|
178 |
+
UserTenantService.delete_by_id(u[0].id)
|
179 |
+
except Exception as e:
|
180 |
+
pass
|
181 |
+
try:
|
182 |
+
TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
|
183 |
+
except Exception as e:
|
184 |
+
pass
|
185 |
+
|
186 |
+
|
187 |
+
def user_register(user_id, user):
|
188 |
+
|
189 |
user_id = get_uuid()
|
190 |
user["id"] = user_id
|
191 |
tenant = {
|
|
|
203 |
"invited_by": user_id,
|
204 |
"role": UserTenantRole.OWNER
|
205 |
}
|
206 |
+
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
|
207 |
|
208 |
if not UserService.save(**user):return
|
209 |
TenantService.save(**tenant)
|
210 |
UserTenantService.save(**usr_tenant)
|
211 |
+
TenantLLMService.save(**tenant_llm)
|
212 |
return UserService.query(email=user["email"])
|
213 |
|
214 |
|
|
|
228 |
"last_login_time": get_format_time(),
|
229 |
"is_superuser": False,
|
230 |
}
|
231 |
+
|
232 |
+
user_id = get_uuid()
|
233 |
try:
|
234 |
+
users = user_register(user_id, user_dict)
|
235 |
if not users: raise Exception('Register user failure.')
|
236 |
if len(users) > 1: raise Exception('Same E-mail exist!')
|
237 |
user = users[0]
|
238 |
login_user(user)
|
239 |
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
240 |
except Exception as e:
|
241 |
+
rollback_user_registration(user_id)
|
242 |
stat_logger.exception(e)
|
243 |
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
244 |
|
|
|
248 |
@login_required
|
249 |
def tenant_info():
|
250 |
try:
|
251 |
+
tenants = TenantService.get_by_user_id(current_user.id)[0]
|
252 |
return get_json_result(data=tenants)
|
253 |
except Exception as e:
|
254 |
return server_error_response(e)
|
web_server/db/db_models.py
CHANGED
@@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel):
|
|
428 |
class LLM(DataBaseModel):
|
429 |
# defautlt LLMs for every users
|
430 |
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
|
|
431 |
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
432 |
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
433 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
@@ -442,8 +443,8 @@ class LLM(DataBaseModel):
|
|
442 |
class TenantLLM(DataBaseModel):
|
443 |
tenant_id = CharField(max_length=32, null=False)
|
444 |
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
445 |
-
model_type = CharField(max_length=128, null=
|
446 |
-
llm_name = CharField(max_length=128, null=
|
447 |
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
448 |
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
449 |
|
@@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel):
|
|
452 |
|
453 |
class Meta:
|
454 |
db_table = "tenant_llm"
|
455 |
-
primary_key = CompositeKey('tenant_id', 'llm_factory')
|
456 |
|
457 |
|
458 |
class Knowledgebase(DataBaseModel):
|
@@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel):
|
|
464 |
permission = CharField(max_length=16, null=False, help_text="me|team")
|
465 |
created_by = CharField(max_length=32, null=False)
|
466 |
doc_num = IntegerField(default=0)
|
467 |
-
|
|
|
|
|
468 |
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
469 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
470 |
|
|
|
428 |
class LLM(DataBaseModel):
|
429 |
# defautlt LLMs for every users
|
430 |
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
431 |
+
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
432 |
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
433 |
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
434 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
|
|
443 |
class TenantLLM(DataBaseModel):
|
444 |
tenant_id = CharField(max_length=32, null=False)
|
445 |
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
446 |
+
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
|
447 |
+
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
|
448 |
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
449 |
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
450 |
|
|
|
453 |
|
454 |
class Meta:
|
455 |
db_table = "tenant_llm"
|
456 |
+
primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
|
457 |
|
458 |
|
459 |
class Knowledgebase(DataBaseModel):
|
|
|
465 |
permission = CharField(max_length=16, null=False, help_text="me|team")
|
466 |
created_by = CharField(max_length=32, null=False)
|
467 |
doc_num = IntegerField(default=0)
|
468 |
+
token_num = IntegerField(default=0)
|
469 |
+
chunk_num = IntegerField(default=0)
|
470 |
+
|
471 |
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
472 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
473 |
|
web_server/db/services/document_service.py
CHANGED
@@ -13,12 +13,13 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
|
|
16 |
from web_server.db import TenantPermission, FileType
|
17 |
-
from web_server.db.db_models import DB, Knowledgebase
|
18 |
from web_server.db.db_models import Document
|
19 |
from web_server.db.services.common_service import CommonService
|
20 |
from web_server.db.services.kb_service import KnowledgebaseService
|
21 |
-
from web_server.utils import get_uuid, get_format_time
|
22 |
from web_server.db.db_utils import StatusEnum
|
23 |
|
24 |
|
@@ -61,15 +62,28 @@ class DocumentService(CommonService):
|
|
61 |
@classmethod
|
62 |
@DB.connection_context()
|
63 |
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
|
64 |
-
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id]
|
65 |
-
docs = cls.model.select(fields)
|
66 |
-
cls.model.
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
75 |
return list(docs.dicts())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
from peewee import Expression
|
17 |
+
|
18 |
from web_server.db import TenantPermission, FileType
|
19 |
+
from web_server.db.db_models import DB, Knowledgebase, Tenant
|
20 |
from web_server.db.db_models import Document
|
21 |
from web_server.db.services.common_service import CommonService
|
22 |
from web_server.db.services.kb_service import KnowledgebaseService
|
|
|
23 |
from web_server.db.db_utils import StatusEnum
|
24 |
|
25 |
|
|
|
62 |
@classmethod
|
63 |
@DB.connection_context()
|
64 |
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
|
65 |
+
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
|
66 |
+
docs = cls.model.select(*fields) \
|
67 |
+
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
68 |
+
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
|
69 |
+
.where(
|
70 |
+
cls.model.status == StatusEnum.VALID.value,
|
71 |
+
~(cls.model.type == FileType.VIRTUAL.value),
|
72 |
+
cls.model.progress == 0,
|
73 |
+
cls.model.update_time >= tm,
|
74 |
+
(Expression(cls.model.create_time, "%%", comm) == mod))\
|
75 |
+
.order_by(cls.model.update_time.asc())\
|
76 |
+
.paginate(1, items_per_page)
|
77 |
return list(docs.dicts())
|
78 |
+
|
79 |
+
@classmethod
|
80 |
+
@DB.connection_context()
|
81 |
+
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
82 |
+
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
83 |
+
chunk_num=cls.model.chunk_num + chunk_num,
|
84 |
+
process_duation=cls.model.process_duation+duation).where(
|
85 |
+
cls.model.id == doc_id).execute()
|
86 |
+
if num == 0:raise LookupError("Document not found which is supposed to be there")
|
87 |
+
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
|
88 |
+
return num
|
89 |
+
|
web_server/db/services/kb_service.py
CHANGED
@@ -17,7 +17,7 @@ import peewee
|
|
17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
18 |
|
19 |
from web_server.db import TenantPermission
|
20 |
-
from web_server.db.db_models import DB, UserTenant
|
21 |
from web_server.db.db_models import Knowledgebase
|
22 |
from web_server.db.services.common_service import CommonService
|
23 |
from web_server.utils import get_uuid, get_format_time
|
@@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService):
|
|
29 |
|
30 |
@classmethod
|
31 |
@DB.connection_context()
|
32 |
-
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
|
|
33 |
kbs = cls.model.select().where(
|
34 |
-
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
35 |
-
|
|
|
36 |
)
|
37 |
-
if desc:
|
38 |
-
|
|
|
|
|
39 |
|
40 |
kbs = kbs.paginate(page_number, items_per_page)
|
41 |
|
42 |
return list(kbs.dicts())
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
18 |
|
19 |
from web_server.db import TenantPermission
|
20 |
+
from web_server.db.db_models import DB, UserTenant, Tenant
|
21 |
from web_server.db.db_models import Knowledgebase
|
22 |
from web_server.db.services.common_service import CommonService
|
23 |
from web_server.utils import get_uuid, get_format_time
|
|
|
29 |
|
30 |
@classmethod
|
31 |
@DB.connection_context()
|
32 |
+
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
33 |
+
page_number, items_per_page, orderby, desc):
|
34 |
kbs = cls.model.select().where(
|
35 |
+
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
36 |
+
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
|
37 |
+
& (cls.model.status == StatusEnum.VALID.value)
|
38 |
)
|
39 |
+
if desc:
|
40 |
+
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
41 |
+
else:
|
42 |
+
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
43 |
|
44 |
kbs = kbs.paginate(page_number, items_per_page)
|
45 |
|
46 |
return list(kbs.dicts())
|
47 |
|
48 |
+
@classmethod
|
49 |
+
@DB.connection_context()
|
50 |
+
def get_detail(cls, kb_id):
|
51 |
+
fields = [
|
52 |
+
cls.model.id,
|
53 |
+
Tenant.embd_id,
|
54 |
+
cls.model.avatar,
|
55 |
+
cls.model.name,
|
56 |
+
cls.model.description,
|
57 |
+
cls.model.permission,
|
58 |
+
cls.model.doc_num,
|
59 |
+
cls.model.token_num,
|
60 |
+
cls.model.chunk_num,
|
61 |
+
cls.model.parser_id]
|
62 |
+
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
|
63 |
+
(cls.model.id == kb_id),
|
64 |
+
(cls.model.status == StatusEnum.VALID.value)
|
65 |
+
)
|
66 |
+
if not kbs:
|
67 |
+
return
|
68 |
+
d = kbs[0].to_dict()
|
69 |
+
d["embd_id"] = kbs[0].tenant.embd_id
|
70 |
+
return d
|
web_server/db/services/llm_service.py
CHANGED
@@ -33,3 +33,21 @@ class LLMService(CommonService):
|
|
33 |
|
34 |
class TenantLLMService(CommonService):
|
35 |
model = TenantLLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
class TenantLLMService(CommonService):
|
35 |
model = TenantLLM
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
@DB.connection_context()
|
39 |
+
def get_api_key(cls, tenant_id, model_type):
|
40 |
+
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
|
41 |
+
if objs and len(objs)>0 and objs[0].llm_name:
|
42 |
+
return objs[0]
|
43 |
+
|
44 |
+
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
|
45 |
+
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
|
46 |
+
(cls.model.tenant_id == tenant_id),
|
47 |
+
(cls.model.model_type == model_type),
|
48 |
+
(LLM.status == StatusEnum.VALID)
|
49 |
+
)
|
50 |
+
|
51 |
+
if not objs:return
|
52 |
+
return objs[0]
|
53 |
+
|
web_server/db/services/user_service.py
CHANGED
@@ -79,7 +79,7 @@ class TenantService(CommonService):
|
|
79 |
@classmethod
|
80 |
@DB.connection_context()
|
81 |
def get_by_user_id(cls, user_id):
|
82 |
-
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
|
83 |
return list(cls.model.select(*fields)\
|
84 |
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
85 |
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
|
|
79 |
@classmethod
|
80 |
@DB.connection_context()
|
81 |
def get_by_user_id(cls, user_id):
|
82 |
+
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
|
83 |
return list(cls.model.select(*fields)\
|
84 |
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
85 |
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
web_server/utils/file_utils.py
CHANGED
@@ -143,7 +143,7 @@ def filename_type(filename):
|
|
143 |
if re.match(r".*\.pdf$", filename):
|
144 |
return FileType.PDF.value
|
145 |
|
146 |
-
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename):
|
147 |
return FileType.DOC.value
|
148 |
|
149 |
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|
|
|
143 |
if re.match(r".*\.pdf$", filename):
|
144 |
return FileType.PDF.value
|
145 |
|
146 |
+
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
|
147 |
return FileType.DOC.value
|
148 |
|
149 |
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|