KevinHuSh committed on
Commit
f666f56
·
1 Parent(s): 4873964

fix user login issue (#85)

Browse files
api/apps/user_app.py CHANGED
@@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse
33
 
34
  @manager.route('/login', methods=['POST', 'GET'])
35
  def login():
36
- userinfo = None
37
  login_channel = "password"
38
- if session.get("access_token"):
39
- login_channel = session["access_token_from"]
40
- if session["access_token_from"] == "github":
41
- userinfo = user_info_from_github(session["access_token"])
42
- elif not request.json:
43
  return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
44
  retmsg='Unautherized!')
45
 
46
- email = request.json.get('email') if not userinfo else userinfo["email"]
47
  users = UserService.query(email=email)
48
- if not users:
49
- if request.json is not None:
50
- return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
51
- avatar = ""
52
- try:
53
- avatar = download_img(userinfo["avatar_url"])
54
- except Exception as e:
55
- stat_logger.exception(e)
56
- user_id = get_uuid()
57
- try:
58
- users = user_register(user_id, {
59
- "access_token": session["access_token"],
60
- "email": userinfo["email"],
61
- "avatar": avatar,
62
- "nickname": userinfo["login"],
63
- "login_channel": login_channel,
64
- "last_login_time": get_format_time(),
65
- "is_superuser": False,
66
- })
67
- if not users: raise Exception('Register user failure.')
68
- if len(users) > 1: raise Exception('Same E-mail exist!')
69
- user = users[0]
70
- login_user(user)
71
- return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
72
- except Exception as e:
73
- rollback_user_registration(user_id)
74
- stat_logger.exception(e)
75
- return server_error_response(e)
76
- elif not request.json:
77
- login_user(users[0])
78
- return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
79
 
80
  password = request.json.get('password')
81
  try:
@@ -97,28 +62,50 @@ def login():
97
 
98
  @manager.route('/github_callback', methods=['GET'])
99
  def github_callback():
100
- try:
101
- import requests
102
- res = requests.post(GITHUB_OAUTH.get("url"), data={
103
- "client_id": GITHUB_OAUTH.get("client_id"),
104
- "client_secret": GITHUB_OAUTH.get("secret_key"),
105
- "code": request.args.get('code')
106
- },headers={"Accept": "application/json"})
107
- res = res.json()
108
- if "error" in res:
109
- return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
110
- retmsg=res["error_description"])
111
-
112
- if "user:email" not in res["scope"].split(","):
113
- return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
114
-
115
- session["access_token"] = res["access_token"]
116
- session["access_token_from"] = "github"
117
- return redirect(url_for("user.login"), code=307)
118
 
119
- except Exception as e:
120
- stat_logger.exception(e)
121
- return server_error_response(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
 
124
  def user_info_from_github(access_token):
@@ -208,7 +195,7 @@ def user_register(user_id, user):
208
  for llm in LLMService.query(fid=LLM_FACTORY):
209
  tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
210
 
211
- if not UserService.insert(**user):return
212
  TenantService.insert(**tenant)
213
  UserTenantService.insert(**usr_tenant)
214
  TenantLLMService.insert_many(tenant_llm)
 
33
 
34
  @manager.route('/login', methods=['POST', 'GET'])
35
  def login():
 
36
  login_channel = "password"
37
+ if not request.json:
 
 
 
 
38
  return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
39
  retmsg='Unautherized!')
40
 
41
+ email = request.json.get('email', "")
42
  users = UserService.query(email=email)
43
+ if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  password = request.json.get('password')
46
  try:
 
62
 
63
  @manager.route('/github_callback', methods=['GET'])
64
  def github_callback():
65
+ import requests
66
+ res = requests.post(GITHUB_OAUTH.get("url"), data={
67
+ "client_id": GITHUB_OAUTH.get("client_id"),
68
+ "client_secret": GITHUB_OAUTH.get("secret_key"),
69
+ "code": request.args.get('code')
70
+ }, headers={"Accept": "application/json"})
71
+ res = res.json()
72
+ if "error" in res:
73
+ return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
74
+ retmsg=res["error_description"])
 
 
 
 
 
 
 
 
75
 
76
+ if "user:email" not in res["scope"].split(","):
77
+ return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
78
+
79
+ session["access_token"] = res["access_token"]
80
+ session["access_token_from"] = "github"
81
+ userinfo = user_info_from_github(session["access_token"])
82
+ users = UserService.query(email=userinfo["email"])
83
+ user_id = get_uuid()
84
+ if not users:
85
+ try:
86
+ try:
87
+ avatar = download_img(userinfo["avatar_url"])
88
+ except Exception as e:
89
+ stat_logger.exception(e)
90
+ avatar = ""
91
+ users = user_register(user_id, {
92
+ "access_token": session["access_token"],
93
+ "email": userinfo["email"],
94
+ "avatar": avatar,
95
+ "nickname": userinfo["login"],
96
+ "login_channel": "github",
97
+ "last_login_time": get_format_time(),
98
+ "is_superuser": False,
99
+ })
100
+ if not users: raise Exception('Register user failure.')
101
+ if len(users) > 1: raise Exception('Same E-mail exist!')
102
+ user = users[0]
103
+ login_user(user)
104
+ except Exception as e:
105
+ rollback_user_registration(user_id)
106
+ stat_logger.exception(e)
107
+
108
+ return redirect("/knowledge")
109
 
110
 
111
  def user_info_from_github(access_token):
 
195
  for llm in LLMService.query(fid=LLM_FACTORY):
196
  tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
197
 
198
+ if not UserService.save(**user):return
199
  TenantService.insert(**tenant)
200
  UserTenantService.insert(**usr_tenant)
201
  TenantLLMService.insert_many(tenant_llm)
api/db/__init__.py CHANGED
@@ -69,7 +69,6 @@ class TaskStatus(StrEnum):
69
 
70
 
71
  class ParserType(StrEnum):
72
- GENERAL = "general"
73
  PRESENTATION = "presentation"
74
  LAWS = "laws"
75
  MANUAL = "manual"
 
69
 
70
 
71
  class ParserType(StrEnum):
 
72
  PRESENTATION = "presentation"
73
  LAWS = "laws"
74
  MANUAL = "manual"
api/db/db_models.py CHANGED
@@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel):
475
  similarity_threshold = FloatField(default=0.2)
476
  vector_similarity_weight = FloatField(default=0.3)
477
 
478
- parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
479
  parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
480
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
481
 
 
475
  similarity_threshold = FloatField(default=0.2)
476
  vector_similarity_weight = FloatField(default=0.3)
477
 
478
+ parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
479
  parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
480
  status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
481
 
api/db/init_data.py CHANGED
@@ -30,7 +30,7 @@ def init_superuser():
30
  "password": "admin",
31
  "nickname": "admin",
32
  "is_superuser": True,
33
- "email": "kai.hu@infiniflow.org",
34
  "creator": "system",
35
  "status": "1",
36
  }
@@ -61,7 +61,7 @@ def init_superuser():
61
  TenantService.insert(**tenant)
62
  UserTenantService.insert(**usr_tenant)
63
  TenantLLMService.insert_many(tenant_llm)
64
- print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
65
 
66
  chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
67
  msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
 
30
  "password": "admin",
31
  "nickname": "admin",
32
  "is_superuser": True,
33
+ "email": "admin@ragflow.io",
34
  "creator": "system",
35
  "status": "1",
36
  }
 
61
  TenantService.insert(**tenant)
62
  UserTenantService.insert(**usr_tenant)
63
  TenantLLMService.insert_many(tenant_llm)
64
+ print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
65
 
66
  chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
67
  msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
api/db/services/user_service.py CHANGED
@@ -13,6 +13,8 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
 
16
  import peewee
17
  from werkzeug.security import generate_password_hash, check_password_hash
18
 
@@ -20,7 +22,7 @@ from api.db import UserTenantRole
20
  from api.db.db_models import DB, UserTenant
21
  from api.db.db_models import User, Tenant
22
  from api.db.services.common_service import CommonService
23
- from api.utils import get_uuid, get_format_time
24
  from api.db import StatusEnum
25
 
26
 
@@ -53,6 +55,11 @@ class UserService(CommonService):
53
  kwargs["id"] = get_uuid()
54
  if "password" in kwargs:
55
  kwargs["password"] = generate_password_hash(str(kwargs["password"]))
 
 
 
 
 
56
  obj = cls.model(**kwargs).save(force_insert=True)
57
  return obj
58
 
@@ -66,10 +73,10 @@ class UserService(CommonService):
66
  @classmethod
67
  @DB.connection_context()
68
  def update_user(cls, user_id, user_dict):
69
- date_time = get_format_time()
70
  with DB.atomic():
71
  if user_dict:
72
- user_dict["update_time"] = date_time
 
73
  cls.model.update(user_dict).where(cls.model.id == user_id).execute()
74
 
75
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ from datetime import datetime
17
+
18
  import peewee
19
  from werkzeug.security import generate_password_hash, check_password_hash
20
 
 
22
  from api.db.db_models import DB, UserTenant
23
  from api.db.db_models import User, Tenant
24
  from api.db.services.common_service import CommonService
25
+ from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
26
  from api.db import StatusEnum
27
 
28
 
 
55
  kwargs["id"] = get_uuid()
56
  if "password" in kwargs:
57
  kwargs["password"] = generate_password_hash(str(kwargs["password"]))
58
+
59
+ kwargs["create_time"] = current_timestamp()
60
+ kwargs["create_date"] = datetime_format(datetime.now())
61
+ kwargs["update_time"] = current_timestamp()
62
+ kwargs["update_date"] = datetime_format(datetime.now())
63
  obj = cls.model(**kwargs).save(force_insert=True)
64
  return obj
65
 
 
73
  @classmethod
74
  @DB.connection_context()
75
  def update_user(cls, user_id, user_dict):
 
76
  with DB.atomic():
77
  if user_dict:
78
+ user_dict["update_time"] = current_timestamp()
79
+ user_dict["update_date"] = datetime_format(datetime.now())
80
  cls.model.update(user_dict).where(cls.model.id == user_id).execute()
81
 
82
 
api/settings.py CHANGED
@@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
76
  IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
77
 
78
  API_KEY = LLM.get("api_key", "infiniflow API Key")
79
- PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
80
 
81
  # distribution
82
  DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
 
76
  IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
77
 
78
  API_KEY = LLM.get("api_key", "infiniflow API Key")
79
+ PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
80
 
81
  # distribution
82
  DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
deepdoc/parser/pdf_parser.py CHANGED
@@ -25,7 +25,7 @@ class HuParser:
25
  def __init__(self):
26
  self.ocr = OCR()
27
  if not hasattr(self, "model_speciess"):
28
- self.model_speciess = ParserType.GENERAL.value
29
  self.layouter = LayoutRecognizer("layout."+self.model_speciess)
30
  self.tbl_det = TableStructureRecognizer()
31
 
 
25
  def __init__(self):
26
  self.ocr = OCR()
27
  if not hasattr(self, "model_speciess"):
28
+ self.model_speciess = ParserType.NAIVE.value
29
  self.layouter = LayoutRecognizer("layout."+self.model_speciess)
30
  self.tbl_det = TableStructureRecognizer()
31
 
deepdoc/vision/layout_recognizer.py CHANGED
@@ -34,8 +34,7 @@ class LayoutRecognizer(Recognizer):
34
  "Equation",
35
  ]
36
  def __init__(self, domain):
37
- super().__init__(self.labels, domain,
38
- os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
39
 
40
  def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
41
  def __is_garbage(b):
 
34
  "Equation",
35
  ]
36
  def __init__(self, domain):
37
+ super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
 
38
 
39
  def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
40
  def __is_garbage(b):
deepdoc/vision/table_structure_recognizer.py CHANGED
@@ -33,8 +33,7 @@ class TableStructureRecognizer(Recognizer):
33
  ]
34
 
35
  def __init__(self):
36
- super().__init__(self.labels, "tsr",
37
- os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
38
 
39
  def __call__(self, images, thr=0.2):
40
  tbls = super().__call__(images, thr)
 
33
  ]
34
 
35
  def __init__(self):
36
+ super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
 
37
 
38
  def __call__(self, images, thr=0.2):
39
  tbls = super().__call__(images, thr)
rag/app/manual.py CHANGED
@@ -1,11 +1,17 @@
1
  import copy
2
  import re
 
 
3
  from rag.nlp import huqie, tokenize
4
  from deepdoc.parser import PdfParser
5
  from rag.utils import num_tokens_from_string
6
 
7
 
8
  class Pdf(PdfParser):
 
 
 
 
9
  def __call__(self, filename, binary=None, from_page=0,
10
  to_page=100000, zoomin=3, callback=None):
11
  self.__images__(
 
1
  import copy
2
  import re
3
+
4
+ from api.db import ParserType
5
  from rag.nlp import huqie, tokenize
6
  from deepdoc.parser import PdfParser
7
  from rag.utils import num_tokens_from_string
8
 
9
 
10
  class Pdf(PdfParser):
11
+ def __init__(self):
12
+ self.model_speciess = ParserType.MANUAL.value
13
+ super().__init__()
14
+
15
  def __call__(self, filename, binary=None, from_page=0,
16
  to_page=100000, zoomin=3, callback=None):
17
  self.__images__(
rag/app/naive.py CHANGED
@@ -30,11 +30,21 @@ class Pdf(PdfParser):
30
 
31
  from timeit import default_timer as timer
32
  start = timer()
 
33
  self._layouts_rec(zoomin)
34
- callback(0.77, "Layout analysis finished")
 
 
 
 
 
 
 
 
 
35
  cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
36
- self._naive_vertical_merge()
37
- return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
38
 
39
 
40
  def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
@@ -44,11 +54,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
44
  Successive text will be sliced into pieces using 'delimiter'.
45
  Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
46
  """
 
 
47
  doc = {
48
  "docnm_kwd": filename,
49
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
50
  }
51
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
 
52
  pdf_parser = None
53
  sections = []
54
  if re.search(r"\.docx?$", filename, re.IGNORECASE):
@@ -58,8 +71,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
58
  callback(0.8, "Finish parsing.")
59
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
60
  pdf_parser = Pdf()
61
- sections = pdf_parser(filename if not binary else binary,
62
  from_page=from_page, to_page=to_page, callback=callback)
 
 
 
 
 
 
 
 
 
 
 
63
  elif re.search(r"\.txt$", filename, re.IGNORECASE):
64
  callback(0.1, "Start to parse.")
65
  txt = ""
@@ -79,8 +103,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
79
 
80
  parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
81
  cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
82
- eng = lang.lower() == "english"#is_english(cks)
83
- res = []
84
  # wrap up to es documents
85
  for ck in cks:
86
  print("--", ck)
 
30
 
31
  from timeit import default_timer as timer
32
  start = timer()
33
+ start = timer()
34
  self._layouts_rec(zoomin)
35
+ callback(0.5, "Layout analysis finished.")
36
+ print("paddle layouts:", timer() - start)
37
+ self._table_transformer_job(zoomin)
38
+ callback(0.7, "Table analysis finished.")
39
+ self._text_merge()
40
+ self._concat_downward(concat_between_pages=False)
41
+ self._filter_forpages()
42
+ callback(0.77, "Text merging finished")
43
+ tbls = self._extract_table_figure(True, zoomin, False)
44
+
45
  cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
46
+ #self._naive_vertical_merge()
47
+ return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
48
 
49
 
50
  def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
 
54
  Successive text will be sliced into pieces using 'delimiter'.
55
  Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
56
  """
57
+
58
+ eng = lang.lower() == "english"#is_english(cks)
59
  doc = {
60
  "docnm_kwd": filename,
61
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
62
  }
63
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
64
+ res = []
65
  pdf_parser = None
66
  sections = []
67
  if re.search(r"\.docx?$", filename, re.IGNORECASE):
 
71
  callback(0.8, "Finish parsing.")
72
  elif re.search(r"\.pdf$", filename, re.IGNORECASE):
73
  pdf_parser = Pdf()
74
+ sections, tbls = pdf_parser(filename if not binary else binary,
75
  from_page=from_page, to_page=to_page, callback=callback)
76
+ # add tables
77
+ for img, rows in tbls:
78
+ bs = 10
79
+ de = ";" if eng else ";"
80
+ for i in range(0, len(rows), bs):
81
+ d = copy.deepcopy(doc)
82
+ r = de.join(rows[i:i + bs])
83
+ r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r)
84
+ tokenize(d, r, eng)
85
+ d["image"] = img
86
+ res.append(d)
87
  elif re.search(r"\.txt$", filename, re.IGNORECASE):
88
  callback(0.1, "Start to parse.")
89
  txt = ""
 
103
 
104
  parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
105
  cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
106
+
 
107
  # wrap up to es documents
108
  for ck in cks:
109
  print("--", ck)
rag/svr/task_executor.py CHANGED
@@ -37,7 +37,7 @@ from rag.nlp import search
37
  from io import BytesIO
38
  import pandas as pd
39
 
40
- from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
41
 
42
  from api.db import LLMType, ParserType
43
  from api.db.services.document_service import DocumentService
@@ -48,7 +48,7 @@ from api.utils.file_utils import get_project_base_directory
48
  BATCH_SIZE = 64
49
 
50
  FACTORY = {
51
- ParserType.GENERAL.value: laws,
52
  ParserType.PAPER.value: paper,
53
  ParserType.BOOK.value: book,
54
  ParserType.PRESENTATION.value: presentation,
 
37
  from io import BytesIO
38
  import pandas as pd
39
 
40
+ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive
41
 
42
  from api.db import LLMType, ParserType
43
  from api.db.services.document_service import DocumentService
 
48
  BATCH_SIZE = 64
49
 
50
  FACTORY = {
51
+ ParserType.NAIVE.value: naive,
52
  ParserType.PAPER.value: paper,
53
  ParserType.BOOK.value: book,
54
  ParserType.PRESENTATION.value: presentation,