KevinHuSh committed on
Commit
3198faf
·
1 Parent(s): 3079197

add a lot of APIs (#23)

Browse files

* clean rust version project

* clean rust version project

* build python version rag-flow

* add a lot of APIs

rag/llm/embedding_model.py CHANGED
@@ -35,7 +35,7 @@ class Base(ABC):
35
 
36
 
37
  class HuEmbedding(Base):
38
- def __init__(self):
39
  """
40
  If you have trouble downloading HuggingFace models, -_^ this might help!!
41
 
 
35
 
36
 
37
  class HuEmbedding(Base):
38
+ def __init__(self, key="", model_name=""):
39
  """
40
  If you have trouble downloading HuggingFace models, -_^ this might help!!
41
 
rag/nlp/huchunk.py CHANGED
@@ -411,9 +411,12 @@ class TextChunker(HuChunker):
411
  flds = self.Fields()
412
  if self.is_binary_file(fnm):
413
  return flds
414
- with open(fnm, "r") as f:
415
- txt = f.read()
416
- flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
 
 
 
417
  flds.table_chunks = []
418
  return flds
419
 
 
411
  flds = self.Fields()
412
  if self.is_binary_file(fnm):
413
  return flds
414
+ txt = ""
415
+ if isinstance(fnm, str):
416
+ with open(fnm, "r") as f:
417
+ txt = f.read()
418
+ else: txt = fnm.decode("utf-8")
419
+ flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
420
  flds.table_chunks = []
421
  return flds
422
 
rag/nlp/search.py CHANGED
@@ -8,7 +8,7 @@ from rag.nlp import huqie, query
8
  import numpy as np
9
 
10
 
11
- def index_name(uid): return f"docgpt_{uid}"
12
 
13
 
14
  class Dealer:
 
8
  import numpy as np
9
 
10
 
11
def index_name(uid):
    """Return the search index name derived from ``uid`` (``ragflow_<uid>``)."""
    return "ragflow_{}".format(uid)
12
 
13
 
14
  class Dealer:
rag/svr/parse_user_docs.py CHANGED
@@ -14,6 +14,7 @@
14
  # limitations under the License.
15
  #
16
  import json
 
17
  import os
18
  import hashlib
19
  import copy
@@ -24,9 +25,10 @@ from timeit import default_timer as timer
24
 
25
  from rag.llm import EmbeddingModel, CvModel
26
  from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
27
- from rag.utils import ELASTICSEARCH, num_tokens_from_string
28
  from rag.utils import MINIO
29
- from rag.utils import rmSpace, findMaxDt
 
30
  from rag.nlp import huchunk, huqie, search
31
  from io import BytesIO
32
  import pandas as pd
@@ -47,6 +49,7 @@ from rag.nlp.huchunk import (
47
  from web_server.db import LLMType
48
  from web_server.db.services.document_service import DocumentService
49
  from web_server.db.services.llm_service import TenantLLMService
 
50
  from web_server.utils import get_format_time
51
  from web_server.utils.file_utils import get_project_base_directory
52
 
@@ -83,7 +86,7 @@ def collect(comm, mod, tm):
83
  if len(docs) == 0:
84
  return pd.DataFrame()
85
  docs = pd.DataFrame(docs)
86
- mtm = str(docs["update_time"].max())[:19]
87
  cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
88
  return docs
89
 
@@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
99
  cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
100
 
101
 
102
- def build(row):
103
  if row["size"] > DOC_MAXIMUM_SIZE:
104
  set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
105
  (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
106
  return []
 
107
  res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
108
  if ELASTICSEARCH.getTotal(res) > 0:
109
  ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
@@ -120,7 +124,8 @@ def build(row):
120
  set_progress(row["id"], random.randint(0, 20) /
121
  100., "Finished preparing! Start to slice file!", True)
122
  try:
123
- obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
 
124
  except Exception as e:
125
  if re.search("(No such file|not found)", str(e)):
126
  set_progress(
@@ -131,6 +136,9 @@ def build(row):
131
  row["id"], -1, f"Internal server error: %s" %
132
  str(e).replace(
133
  "'", ""))
 
 
 
134
  return []
135
 
136
  if not obj.text_chunks and not obj.table_chunks:
@@ -144,7 +152,7 @@ def build(row):
144
  "Finished slicing files. Start to embedding the content.")
145
 
146
  doc = {
147
- "doc_id": row["did"],
148
  "kb_id": [str(row["kb_id"])],
149
  "docnm_kwd": os.path.split(row["location"])[-1],
150
  "title_tks": huqie.qie(row["name"]),
@@ -164,10 +172,10 @@ def build(row):
164
  docs.append(d)
165
  continue
166
 
167
- if isinstance(img, Image):
168
- img.save(output_buffer, format='JPEG')
169
- else:
170
  output_buffer = BytesIO(img)
 
 
171
 
172
  MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
173
  d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
@@ -215,15 +223,16 @@ def embedding(docs, mdl):
215
 
216
 
217
  def model_instance(tenant_id, llm_type):
218
- model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
219
- if not model_config:return
220
- model_config = model_config[0]
 
221
  if llm_type == LLMType.EMBEDDING:
222
- if model_config.llm_factory not in EmbeddingModel: return
223
- return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
224
  if llm_type == LLMType.IMAGE2TEXT:
225
- if model_config.llm_factory not in CvModel: return
226
- return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
227
 
228
 
229
  def main(comm, mod):
@@ -231,7 +240,7 @@ def main(comm, mod):
231
  from rag.llm import HuEmbedding
232
  model = HuEmbedding()
233
  tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
234
- tm = findMaxDt(tm_fnm)
235
  rows = collect(comm, mod, tm)
236
  if len(rows) == 0:
237
  return
@@ -247,7 +256,7 @@ def main(comm, mod):
247
  st_tm = timer()
248
  cks = build(r, cv_mdl)
249
  if not cks:
250
- tmf.write(str(r["updated_at"]) + "\n")
251
  continue
252
  # TODO: exception handler
253
  ## set_progress(r["did"], -1, "ERROR: ")
@@ -268,12 +277,19 @@ def main(comm, mod):
268
  cron_logger.error(str(es_r))
269
  else:
270
  set_progress(r["id"], 1., "Done!")
271
- DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
 
 
272
  tmf.write(str(r["update_time"]) + "\n")
273
  tmf.close()
274
 
275
 
276
  if __name__ == "__main__":
 
 
 
 
 
277
  from mpi4py import MPI
278
  comm = MPI.COMM_WORLD
279
  main(comm.Get_size(), comm.Get_rank())
 
14
  # limitations under the License.
15
  #
16
  import json
17
+ import logging
18
  import os
19
  import hashlib
20
  import copy
 
25
 
26
  from rag.llm import EmbeddingModel, CvModel
27
  from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
28
+ from rag.utils import ELASTICSEARCH
29
  from rag.utils import MINIO
30
+ from rag.utils import rmSpace, findMaxTm
31
+
32
  from rag.nlp import huchunk, huqie, search
33
  from io import BytesIO
34
  import pandas as pd
 
49
  from web_server.db import LLMType
50
  from web_server.db.services.document_service import DocumentService
51
  from web_server.db.services.llm_service import TenantLLMService
52
+ from web_server.settings import database_logger
53
  from web_server.utils import get_format_time
54
  from web_server.utils.file_utils import get_project_base_directory
55
 
 
86
  if len(docs) == 0:
87
  return pd.DataFrame()
88
  docs = pd.DataFrame(docs)
89
+ mtm = docs["update_time"].max()
90
  cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
91
  return docs
92
 
 
102
  cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
103
 
104
 
105
+ def build(row, cvmdl):
106
  if row["size"] > DOC_MAXIMUM_SIZE:
107
  set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
108
  (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
109
  return []
110
+
111
  res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
112
  if ELASTICSEARCH.getTotal(res) > 0:
113
  ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
 
124
  set_progress(row["id"], random.randint(0, 20) /
125
  100., "Finished preparing! Start to slice file!", True)
126
  try:
127
+ cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
128
+ obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
129
  except Exception as e:
130
  if re.search("(No such file|not found)", str(e)):
131
  set_progress(
 
136
  row["id"], -1, f"Internal server error: %s" %
137
  str(e).replace(
138
  "'", ""))
139
+
140
+ cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
141
+
142
  return []
143
 
144
  if not obj.text_chunks and not obj.table_chunks:
 
152
  "Finished slicing files. Start to embedding the content.")
153
 
154
  doc = {
155
+ "doc_id": row["id"],
156
  "kb_id": [str(row["kb_id"])],
157
  "docnm_kwd": os.path.split(row["location"])[-1],
158
  "title_tks": huqie.qie(row["name"]),
 
172
  docs.append(d)
173
  continue
174
 
175
+ if isinstance(img, bytes):
 
 
176
  output_buffer = BytesIO(img)
177
+ else:
178
+ img.save(output_buffer, format='JPEG')
179
 
180
  MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
181
  d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
 
223
 
224
 
225
  def model_instance(tenant_id, llm_type):
226
+ model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
227
+ if not model_config:
228
+ model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
229
+ else: model_config = model_config[0].to_dict()
230
  if llm_type == LLMType.EMBEDDING:
231
+ if model_config["llm_factory"] not in EmbeddingModel: return
232
+ return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
233
  if llm_type == LLMType.IMAGE2TEXT:
234
+ if model_config["llm_factory"] not in CvModel: return
235
+ return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
236
 
237
 
238
  def main(comm, mod):
 
240
  from rag.llm import HuEmbedding
241
  model = HuEmbedding()
242
  tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
243
+ tm = findMaxTm(tm_fnm)
244
  rows = collect(comm, mod, tm)
245
  if len(rows) == 0:
246
  return
 
256
  st_tm = timer()
257
  cks = build(r, cv_mdl)
258
  if not cks:
259
+ tmf.write(str(r["update_time"]) + "\n")
260
  continue
261
  # TODO: exception handler
262
  ## set_progress(r["did"], -1, "ERROR: ")
 
277
  cron_logger.error(str(es_r))
278
  else:
279
  set_progress(r["id"], 1., "Done!")
280
+ DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
281
+ cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
282
+
283
  tmf.write(str(r["update_time"]) + "\n")
284
  tmf.close()
285
 
286
 
287
  if __name__ == "__main__":
288
+ peewee_logger = logging.getLogger('peewee')
289
+ peewee_logger.propagate = False
290
+ peewee_logger.addHandler(database_logger.handlers[0])
291
+ peewee_logger.setLevel(database_logger.level)
292
+
293
  from mpi4py import MPI
294
  comm = MPI.COMM_WORLD
295
  main(comm.Get_size(), comm.Get_rank())
rag/utils/__init__.py CHANGED
@@ -40,6 +40,25 @@ def findMaxDt(fnm):
40
  print("WARNING: can't find " + fnm)
41
  return m
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def num_tokens_from_string(string: str) -> int:
44
  """Returns the number of tokens in a text string."""
45
  encoding = tiktoken.get_encoding('cl100k_base')
 
40
  print("WARNING: can't find " + fnm)
41
  return m
42
 
43
+
44
def findMaxTm(fnm):
    """Return the largest timestamp recorded in the file ``fnm``.

    The file holds one value per line. Blank lines, ``nan`` entries and lines
    that cannot be parsed as a number are skipped individually, so a single
    bad line does not hide the valid values that follow it. Values written in
    float notation (e.g. ``"1700.0"``) are truncated to an integer.

    Returns 0 when the file cannot be read or contains no usable value.
    """
    m = 0
    try:
        with open(fnm, "r") as f:
            for line in f:
                line = line.strip()
                if not line or line == "nan":
                    continue
                try:
                    value = int(line)
                except ValueError:
                    try:
                        value = int(float(line))
                    except (ValueError, OverflowError):
                        # Not a number at all (or inf/NaN): ignore this line.
                        continue
                if value > m:
                    m = value
    except (OSError, UnicodeError) as e:
        # Best effort: a missing or unreadable file means "no progress yet".
        print("WARNING: can't read {}: {}".format(fnm, e))
    return m
60
+
61
+
62
  def num_tokens_from_string(string: str) -> int:
63
  """Returns the number of tokens in a text string."""
64
  encoding = tiktoken.get_encoding('cl100k_base')
rag/utils/es_conn.py CHANGED
@@ -294,6 +294,7 @@ class HuEs:
294
  except Exception as e:
295
  es_logger.error("ES updateByQuery deleteByQuery: " +
296
  str(e) + "【Q】:" + str(query.to_dict()))
 
297
  if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
298
  continue
299
 
 
294
  except Exception as e:
295
  es_logger.error("ES updateByQuery deleteByQuery: " +
296
  str(e) + "【Q】:" + str(query.to_dict()))
297
+ if str(e).find("NotFoundError") > 0: return True
298
  if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
299
  continue
300
 
web_server/apps/document_app.py CHANGED
@@ -13,6 +13,7 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import pathlib
17
 
18
  from elasticsearch_dsl import Q
@@ -195,11 +196,15 @@ def rm():
195
  e, doc = DocumentService.get_by_id(req["doc_id"])
196
  if not e:
197
  return get_data_error_result(retmsg="Document not found!")
 
 
 
 
198
  if not DocumentService.delete_by_id(req["doc_id"]):
199
  return get_data_error_result(
200
  retmsg="Database error (Document removal)!")
201
- e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
202
- MINIO.rm(kb.id, doc.location)
203
  return get_json_result(data=True)
204
  except Exception as e:
205
  return server_error_response(e)
@@ -233,3 +238,43 @@ def rename():
233
  return get_json_result(data=True)
234
  except Exception as e:
235
  return server_error_response(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import base64
17
  import pathlib
18
 
19
  from elasticsearch_dsl import Q
 
196
  e, doc = DocumentService.get_by_id(req["doc_id"])
197
  if not e:
198
  return get_data_error_result(retmsg="Document not found!")
199
+ if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
200
+ return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
201
+
202
+ DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
203
  if not DocumentService.delete_by_id(req["doc_id"]):
204
  return get_data_error_result(
205
  retmsg="Database error (Document removal)!")
206
+
207
+ MINIO.rm(doc.kb_id, doc.location)
208
  return get_json_result(data=True)
209
  except Exception as e:
210
  return server_error_response(e)
 
238
  return get_json_result(data=True)
239
  except Exception as e:
240
  return server_error_response(e)
241
+
242
+
243
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Return the stored content of a document, base64-encoded, as JSON.

    Reads ``doc_id`` from the query string, loads the document record and
    fetches its blob from MINIO using the document's ``kb_id`` and
    ``location``. Responds with a data error when the document does not exist
    and with a server error response when anything raises.
    """
    doc_id = request.args["doc_id"]
    try:
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        blob = MINIO.get(doc.kb_id, doc.location)
        # The blob must be *encoded* for transport: b64decode would try to
        # interpret the raw file content as base64 text. The result is decoded
        # to str because bytes are not JSON-serializable.
        # NOTE(review): assumes MINIO.get returns bytes -- confirm.
        encoded = base64.b64encode(blob).decode("ascii")
        return get_json_result(data={"base64": encoded})
    except Exception as e:
        return server_error_response(e)
256
+
257
+
258
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
    """Switch a document to another parser and reset its processing state.

    When the requested parser equals the current one (case-insensitive) the
    call succeeds without touching anything. Otherwise the document's parser
    is updated, its progress fields are cleared, and its token, chunk and
    duration figures are passed, negated, to ``increment_chunk_num``.
    """
    req = request.json
    try:
        found, doc = DocumentService.get_by_id(req["doc_id"])
        if not found:
            return get_data_error_result(retmsg="Document not found!")

        new_parser = req["parser_id"]
        if doc.parser_id.lower() == new_parser.lower():
            return get_json_result(data=True)

        reset_fields = {"parser_id": new_parser, "progress": 0, "progress_msg": ""}
        if not DocumentService.update_by_id(doc.id, reset_fields):
            return get_data_error_result(retmsg="Document not found!")

        # Negated counters: the document's previous contribution is taken back.
        rolled_back = DocumentService.increment_chunk_num(
            doc.id,
            doc.kb_id,
            doc.token_num * -1,
            doc.chunk_num * -1,
            doc.process_duation * -1,
        )
        if not rolled_back:
            return get_data_error_result(retmsg="Document not found!")

        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
280
+
web_server/apps/kb_app.py CHANGED
@@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result
29
 
30
  @manager.route('/create', methods=['post'])
31
  @login_required
32
- @validate_request("name", "description", "permission", "embd_id", "parser_id")
33
  def create():
34
  req = request.json
35
  req["name"] = req["name"].strip()
@@ -46,7 +46,7 @@ def create():
46
 
47
  @manager.route('/update', methods=['post'])
48
  @login_required
49
- @validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id")
50
  def update():
51
  req = request.json
52
  req["name"] = req["name"].strip()
@@ -72,6 +72,18 @@ def update():
72
  return server_error_response(e)
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  @manager.route('/list', methods=['GET'])
76
  @login_required
77
  def list():
 
29
 
30
  @manager.route('/create', methods=['post'])
31
  @login_required
32
+ @validate_request("name", "description", "permission", "parser_id")
33
  def create():
34
  req = request.json
35
  req["name"] = req["name"].strip()
 
46
 
47
  @manager.route('/update', methods=['post'])
48
  @login_required
49
+ @validate_request("kb_id", "name", "description", "permission", "parser_id")
50
  def update():
51
  req = request.json
52
  req["name"] = req["name"].strip()
 
72
  return server_error_response(e)
73
 
74
 
75
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
    """Return the detail record of the knowledge base named by ``kb_id``.

    ``kb_id`` is read from the query string. A data error is returned when
    the lookup yields nothing; any exception becomes a server error response.
    """
    kb_id = request.args["kb_id"]
    try:
        record = KnowledgebaseService.get_detail(kb_id)
        if record:
            return get_json_result(data=record)
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    except Exception as e:
        return server_error_response(e)
85
+
86
+
87
  @manager.route('/list', methods=['GET'])
88
  @login_required
89
  def list():
web_server/apps/llm_app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2019 The FATE Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from flask import request
17
+ from flask_login import login_required, current_user
18
+
19
+ from web_server.db.services import duplicate_name
20
+ from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
21
+ from web_server.db.services.user_service import TenantService, UserTenantService
22
+ from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
23
+ from web_server.utils import get_uuid, get_format_time
24
+ from web_server.db import StatusEnum, UserTenantRole
25
+ from web_server.db.services.kb_service import KnowledgebaseService
26
+ from web_server.db.db_models import Knowledgebase, TenantLLM
27
+ from web_server.settings import stat_logger, RetCode
28
+ from web_server.utils.api_utils import get_json_result
29
+
30
+
31
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
    """Return every LLM factory known to ``LLMFactoriesService`` as JSON."""
    try:
        payload = LLMFactoriesService.get_all().to_json()
        return get_json_result(data=payload)
    except Exception as e:
        return server_error_response(e)
39
+
40
+
41
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
    """Store (insert or replace) an API key for one of the caller's LLM factories.

    Required JSON fields: ``llm_factory`` and ``api_key``. The optional fields
    ``model_type`` and ``llm_name`` are stored when present. The row is keyed
    with the current user's id as ``tenant_id``.

    Returns a JSON result with ``data=True`` on success, or a server error
    response when the write fails, matching the other routes in this module.
    """
    req = request.json
    try:
        llm = {
            "tenant_id": current_user.id,
            "llm_factory": req["llm_factory"],
            "api_key": req["api_key"]
        }
        # TODO: Test api_key
        for n in ["model_type", "llm_name"]:
            if n in req:
                llm[n] = req[n]

        TenantLLM.insert(**llm).on_conflict("replace").execute()
        return get_json_result(data=True)
    except Exception as e:
        # Without this, a database failure escaped as an unhandled exception.
        return server_error_response(e)
57
+
58
+
59
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
    """List the current user's tenant LLM entries with ``api_key`` removed."""
    try:
        rows = [rec.to_dict() for rec in TenantLLMService.query(tenant_id=current_user.id)]
        for row in rows:
            # Never send the stored key back to the client.
            del row["api_key"]
        return get_json_result(data=rows)
    except Exception as e:
        return server_error_response(e)
69
+
70
+
71
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """Return all valid LLMs grouped by factory id, each flagged ``available``.

    An LLM is available when the current user holds a tenant entry with an
    API key for its factory, and that factory either has no model names
    recorded or includes this LLM's name.
    """
    try:
        keyed = [
            rec.to_dict()
            for rec in TenantLLMService.query(tenant_id=current_user.id)
            if rec.api_key
        ]
        # factory name -> model names explicitly configured for it
        names_by_factory = {}
        for cfg in keyed:
            names = names_by_factory.setdefault(cfg["llm_factory"], [])
            if cfg["llm_name"]:
                names.append(cfg["llm_name"])

        valid = [
            llm.to_dict()
            for llm in LLMService.get_all()
            if llm.status == StatusEnum.VALID.value
        ]
        grouped = {}
        for entry in valid:
            fid = entry["fid"]
            if fid in names_by_factory:
                configured = names_by_factory[fid]
                entry["available"] = (not configured) or (entry["llm_name"] in configured)
            else:
                entry["available"] = False
            grouped.setdefault(fid, []).append(entry)

        return get_json_result(data=grouped)
    except Exception as e:
        return server_error_response(e)
web_server/apps/user_app.py CHANGED
@@ -16,9 +16,12 @@
16
  from flask import request, session, redirect, url_for
17
  from werkzeug.security import generate_password_hash, check_password_hash
18
  from flask_login import login_required, current_user, login_user, logout_user
 
 
 
19
  from web_server.utils.api_utils import server_error_response, validate_request
20
  from web_server.utils import get_uuid, get_format_time, decrypt, download_img
21
- from web_server.db import UserTenantRole
22
  from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
23
  from web_server.db.services.user_service import UserService, TenantService, UserTenantService
24
  from web_server.settings import stat_logger
@@ -47,8 +50,9 @@ def login():
47
  avatar = download_img(userinfo["avatar_url"])
48
  except Exception as e:
49
  stat_logger.exception(e)
 
50
  try:
51
- users = user_register({
52
  "access_token": session["access_token"],
53
  "email": userinfo["email"],
54
  "avatar": avatar,
@@ -63,6 +67,7 @@ def login():
63
  login_user(user)
64
  return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
65
  except Exception as e:
 
66
  stat_logger.exception(e)
67
  return server_error_response(e)
68
  elif not request.json:
@@ -162,7 +167,25 @@ def user_info():
162
  return get_json_result(data=current_user.to_dict())
163
 
164
 
165
- def user_register(user):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  user_id = get_uuid()
167
  user["id"] = user_id
168
  tenant = {
@@ -180,10 +203,12 @@ def user_register(user):
180
  "invited_by": user_id,
181
  "role": UserTenantRole.OWNER
182
  }
 
183
 
184
  if not UserService.save(**user):return
185
  TenantService.save(**tenant)
186
  UserTenantService.save(**usr_tenant)
 
187
  return UserService.query(email=user["email"])
188
 
189
 
@@ -203,14 +228,17 @@ def user_add():
203
  "last_login_time": get_format_time(),
204
  "is_superuser": False,
205
  }
 
 
206
  try:
207
- users = user_register(user_dict)
208
  if not users: raise Exception('Register user failure.')
209
  if len(users) > 1: raise Exception('Same E-mail exist!')
210
  user = users[0]
211
  login_user(user)
212
  return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
213
  except Exception as e:
 
214
  stat_logger.exception(e)
215
  return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
216
 
@@ -220,7 +248,7 @@ def user_add():
220
  @login_required
221
  def tenant_info():
222
  try:
223
- tenants = TenantService.get_by_user_id(current_user.id)
224
  return get_json_result(data=tenants)
225
  except Exception as e:
226
  return server_error_response(e)
 
16
  from flask import request, session, redirect, url_for
17
  from werkzeug.security import generate_password_hash, check_password_hash
18
  from flask_login import login_required, current_user, login_user, logout_user
19
+
20
+ from web_server.db.db_models import TenantLLM
21
+ from web_server.db.services.llm_service import TenantLLMService
22
  from web_server.utils.api_utils import server_error_response, validate_request
23
  from web_server.utils import get_uuid, get_format_time, decrypt, download_img
24
+ from web_server.db import UserTenantRole, LLMType
25
  from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
26
  from web_server.db.services.user_service import UserService, TenantService, UserTenantService
27
  from web_server.settings import stat_logger
 
50
  avatar = download_img(userinfo["avatar_url"])
51
  except Exception as e:
52
  stat_logger.exception(e)
53
+ user_id = get_uuid()
54
  try:
55
+ users = user_register(user_id, {
56
  "access_token": session["access_token"],
57
  "email": userinfo["email"],
58
  "avatar": avatar,
 
67
  login_user(user)
68
  return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
69
  except Exception as e:
70
+ rollback_user_registration(user_id)
71
  stat_logger.exception(e)
72
  return server_error_response(e)
73
  elif not request.json:
 
167
  return get_json_result(data=current_user.to_dict())
168
 
169
 
170
def rollback_user_registration(user_id):
    """Best-effort cleanup of the rows created by a failed registration.

    Removes the tenant, the first user-tenant link found for the tenant, and
    the tenant's LLM rows, all keyed by ``user_id``. The steps are independent:
    a failure in one is logged and the remaining steps still run. Nothing is
    raised to the caller.
    """
    try:
        TenantService.delete_by_id(user_id)
    except Exception as e:
        stat_logger.exception(e)
    try:
        u = UserTenantService.query(tenant_id=user_id)
        if u:
            UserTenantService.delete_by_id(u[0].id)
    except Exception as e:
        stat_logger.exception(e)
    try:
        # The query only runs when execute() is called; the earlier "excute"
        # spelling raised AttributeError, which was swallowed, so these rows
        # were never deleted.
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception as e:
        stat_logger.exception(e)
185
+
186
+
187
+ def user_register(user_id, user):
188
+
189
  user_id = get_uuid()
190
  user["id"] = user_id
191
  tenant = {
 
203
  "invited_by": user_id,
204
  "role": UserTenantRole.OWNER
205
  }
206
+ tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
207
 
208
  if not UserService.save(**user):return
209
  TenantService.save(**tenant)
210
  UserTenantService.save(**usr_tenant)
211
+ TenantLLMService.save(**tenant_llm)
212
  return UserService.query(email=user["email"])
213
 
214
 
 
228
  "last_login_time": get_format_time(),
229
  "is_superuser": False,
230
  }
231
+
232
+ user_id = get_uuid()
233
  try:
234
+ users = user_register(user_id, user_dict)
235
  if not users: raise Exception('Register user failure.')
236
  if len(users) > 1: raise Exception('Same E-mail exist!')
237
  user = users[0]
238
  login_user(user)
239
  return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
240
  except Exception as e:
241
+ rollback_user_registration(user_id)
242
  stat_logger.exception(e)
243
  return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
244
 
 
248
  @login_required
249
  def tenant_info():
250
  try:
251
+ tenants = TenantService.get_by_user_id(current_user.id)[0]
252
  return get_json_result(data=tenants)
253
  except Exception as e:
254
  return server_error_response(e)
web_server/db/db_models.py CHANGED
@@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel):
428
  class LLM(DataBaseModel):
429
  # default LLMs for every user
430
  llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
 
431
  fid = CharField(max_length=128, null=False, help_text="LLM factory id")
432
  tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
433
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
@@ -442,8 +443,8 @@ class LLM(DataBaseModel):
442
  class TenantLLM(DataBaseModel):
443
  tenant_id = CharField(max_length=32, null=False)
444
  llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
445
- model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
446
- llm_name = CharField(max_length=128, null=False, help_text="LLM name")
447
  api_key = CharField(max_length=255, null=True, help_text="API KEY")
448
  api_base = CharField(max_length=255, null=True, help_text="API Base")
449
 
@@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel):
452
 
453
  class Meta:
454
  db_table = "tenant_llm"
455
- primary_key = CompositeKey('tenant_id', 'llm_factory')
456
 
457
 
458
  class Knowledgebase(DataBaseModel):
@@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel):
464
  permission = CharField(max_length=16, null=False, help_text="me|team")
465
  created_by = CharField(max_length=32, null=False)
466
  doc_num = IntegerField(default=0)
467
- embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID")
 
 
468
  parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
469
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
470
 
 
428
  class LLM(DataBaseModel):
429
  # default LLMs for every user
430
  llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
431
+ model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
432
  fid = CharField(max_length=128, null=False, help_text="LLM factory id")
433
  tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
434
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
443
  class TenantLLM(DataBaseModel):
444
  tenant_id = CharField(max_length=32, null=False)
445
  llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
446
+ model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
447
+ llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
448
  api_key = CharField(max_length=255, null=True, help_text="API KEY")
449
  api_base = CharField(max_length=255, null=True, help_text="API Base")
450
 
 
453
 
454
  class Meta:
455
  db_table = "tenant_llm"
456
+ primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
457
 
458
 
459
  class Knowledgebase(DataBaseModel):
 
465
  permission = CharField(max_length=16, null=False, help_text="me|team")
466
  created_by = CharField(max_length=32, null=False)
467
  doc_num = IntegerField(default=0)
468
+ token_num = IntegerField(default=0)
469
+ chunk_num = IntegerField(default=0)
470
+
471
  parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
472
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
473
 
web_server/db/services/document_service.py CHANGED
@@ -13,12 +13,13 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
 
16
  from web_server.db import TenantPermission, FileType
17
- from web_server.db.db_models import DB, Knowledgebase
18
  from web_server.db.db_models import Document
19
  from web_server.db.services.common_service import CommonService
20
  from web_server.db.services.kb_service import KnowledgebaseService
21
- from web_server.utils import get_uuid, get_format_time
22
  from web_server.db.db_utils import StatusEnum
23
 
24
 
@@ -61,15 +62,28 @@ class DocumentService(CommonService):
61
  @classmethod
62
  @DB.connection_context()
63
  def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
64
- fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id]
65
- docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where(
66
- cls.model.status == StatusEnum.VALID.value,
67
- cls.model.type != FileType.VIRTUAL,
68
- cls.model.progress == 0,
69
- cls.model.update_time >= tm,
70
- cls.model.create_time %
71
- comm == mod).order_by(
72
- cls.model.update_time.asc()).paginate(
73
- 1,
74
- items_per_page)
 
75
  return list(docs.dicts())
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ from peewee import Expression
17
+
18
  from web_server.db import TenantPermission, FileType
19
+ from web_server.db.db_models import DB, Knowledgebase, Tenant
20
  from web_server.db.db_models import Document
21
  from web_server.db.services.common_service import CommonService
22
  from web_server.db.services.kb_service import KnowledgebaseService
 
23
  from web_server.db.db_utils import StatusEnum
24
 
25
 
 
62
  @classmethod
63
  @DB.connection_context()
64
  def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
65
+ fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
66
+ docs = cls.model.select(*fields) \
67
+ .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
68
+ .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
69
+ .where(
70
+ cls.model.status == StatusEnum.VALID.value,
71
+ ~(cls.model.type == FileType.VIRTUAL.value),
72
+ cls.model.progress == 0,
73
+ cls.model.update_time >= tm,
74
+ (Expression(cls.model.create_time, "%%", comm) == mod))\
75
+ .order_by(cls.model.update_time.asc())\
76
+ .paginate(1, items_per_page)
77
  return list(docs.dicts())
78
+
79
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
    """Add token/chunk counts to a document and to its knowledge base.

    The document row also has ``duation`` added to ``process_duation``.

    Args:
        doc_id: id of the document row to update.
        kb_id: id of the knowledge base row to update.
        token_num: amount added to ``token_num`` on both rows.
        chunk_num: amount added to ``chunk_num`` on both rows.
        duation: amount added to the document's ``process_duation``.

    Returns:
        The result of executing the knowledge base update.

    Raises:
        LookupError: if the document update reports 0 affected rows.
    """
    doc = cls.model
    doc_update = doc.update(
        token_num=doc.token_num + token_num,
        chunk_num=doc.chunk_num + chunk_num,
        process_duation=doc.process_duation + duation,
    ).where(doc.id == doc_id)
    if doc_update.execute() == 0:
        raise LookupError("Document not found which is supposed to be there")

    # The knowledge base counters are only touched once the document row
    # has been updated successfully.
    kb_update = Knowledgebase.update(
        token_num=Knowledgebase.token_num + token_num,
        chunk_num=Knowledgebase.chunk_num + chunk_num,
    ).where(Knowledgebase.id == kb_id)
    return kb_update.execute()
89
+
web_server/db/services/kb_service.py CHANGED
@@ -17,7 +17,7 @@ import peewee
17
  from werkzeug.security import generate_password_hash, check_password_hash
18
 
19
  from web_server.db import TenantPermission
20
- from web_server.db.db_models import DB, UserTenant
21
  from web_server.db.db_models import Knowledgebase
22
  from web_server.db.services.common_service import CommonService
23
  from web_server.utils import get_uuid, get_format_time
@@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService):
29
 
30
  @classmethod
31
  @DB.connection_context()
32
- def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc):
 
33
  kbs = cls.model.select().where(
34
- ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
35
- & (cls.model.status==StatusEnum.VALID.value)
 
36
  )
37
- if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
38
- else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
 
 
39
 
40
  kbs = kbs.paginate(page_number, items_per_page)
41
 
42
  return list(kbs.dicts())
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  from werkzeug.security import generate_password_hash, check_password_hash
18
 
19
  from web_server.db import TenantPermission
20
+ from web_server.db.db_models import DB, UserTenant, Tenant
21
  from web_server.db.db_models import Knowledgebase
22
  from web_server.db.services.common_service import CommonService
23
  from web_server.utils import get_uuid, get_format_time
 
29
 
30
  @classmethod
31
  @DB.connection_context()
32
+ def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
33
+ page_number, items_per_page, orderby, desc):
34
  kbs = cls.model.select().where(
35
+ ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
36
+ TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
37
+ & (cls.model.status == StatusEnum.VALID.value)
38
  )
39
+ if desc:
40
+ kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
41
+ else:
42
+ kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
43
 
44
  kbs = kbs.paginate(page_number, items_per_page)
45
 
46
  return list(kbs.dicts())
47
 
48
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
    """Return the detail dict of one valid knowledge base, or None.

    The knowledge base is joined with its tenant (which must also be
    valid) so that the tenant's ``embd_id`` can be included.

    Args:
        kb_id: id of the knowledge base to look up.

    Returns:
        The first matching row's ``to_dict()`` result with an added
        ``"embd_id"`` key taken from the joined tenant, or ``None`` when
        the query yields no row.
    """
    kb = cls.model
    columns = (
        kb.id,
        Tenant.embd_id,
        kb.avatar,
        kb.name,
        kb.description,
        kb.permission,
        kb.doc_num,
        kb.token_num,
        kb.chunk_num,
        kb.parser_id,
    )
    tenant_join = (Tenant.id == kb.tenant_id) & (
        Tenant.status == StatusEnum.VALID.value)
    kbs = kb.select(*columns).join(Tenant, on=tenant_join).where(
        (kb.id == kb_id),
        (kb.status == StatusEnum.VALID.value),
    )
    if not kbs:
        return None

    row = kbs[0]
    detail = row.to_dict()
    # The joined tenant columns are read through the ``tenant`` attribute.
    detail["embd_id"] = row.tenant.embd_id
    return detail
web_server/db/services/llm_service.py CHANGED
@@ -33,3 +33,21 @@ class LLMService(CommonService):
33
 
34
  class TenantLLMService(CommonService):
35
  model = TenantLLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  class TenantLLMService(CommonService):
35
  model = TenantLLM
36
+
37
+ @classmethod
38
+ @DB.connection_context()
39
+ def get_api_key(cls, tenant_id, model_type):
40
+ objs = cls.query(tenant_id=tenant_id, model_type=model_type)
41
+ if objs and len(objs)>0 and objs[0].llm_name:
42
+ return objs[0]
43
+
44
+ fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
45
+ objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
46
+ (cls.model.tenant_id == tenant_id),
47
+ (cls.model.model_type == model_type),
48
+ (LLM.status == StatusEnum.VALID)
49
+ )
50
+
51
+ if not objs:return
52
+ return objs[0]
53
+
web_server/db/services/user_service.py CHANGED
@@ -79,7 +79,7 @@ class TenantService(CommonService):
79
  @classmethod
80
  @DB.connection_context()
81
  def get_by_user_id(cls, user_id):
82
- fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
83
  return list(cls.model.select(*fields)\
84
  .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
85
  .where(cls.model.status == StatusEnum.VALID.value).dicts())
 
79
  @classmethod
80
  @DB.connection_context()
81
  def get_by_user_id(cls, user_id):
82
+ fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
83
  return list(cls.model.select(*fields)\
84
  .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
85
  .where(cls.model.status == StatusEnum.VALID.value).dicts())
web_server/utils/file_utils.py CHANGED
@@ -143,7 +143,7 @@ def filename_type(filename):
143
  if re.match(r".*\.pdf$", filename):
144
  return FileType.PDF.value
145
 
146
- if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename):
147
  return FileType.DOC.value
148
 
149
  if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
 
143
  if re.match(r".*\.pdf$", filename):
144
  return FileType.PDF.value
145
 
146
+ if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
147
  return FileType.DOC.value
148
 
149
  if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):