KevinHuSh
committed on
Commit
·
f666f56
1
Parent(s):
4873964
fix user login issue (#85)
Browse files
- api/apps/user_app.py +47 -60
- api/db/__init__.py +0 -1
- api/db/db_models.py +1 -1
- api/db/init_data.py +2 -2
- api/db/services/user_service.py +10 -3
- api/settings.py +1 -1
- deepdoc/parser/pdf_parser.py +1 -1
- deepdoc/vision/layout_recognizer.py +1 -2
- deepdoc/vision/table_structure_recognizer.py +1 -2
- rag/app/manual.py +6 -0
- rag/app/naive.py +29 -6
- rag/svr/task_executor.py +2 -2
api/apps/user_app.py
CHANGED
@@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse
|
|
33 |
|
34 |
@manager.route('/login', methods=['POST', 'GET'])
|
35 |
def login():
|
36 |
-
userinfo = None
|
37 |
login_channel = "password"
|
38 |
-
if
|
39 |
-
login_channel = session["access_token_from"]
|
40 |
-
if session["access_token_from"] == "github":
|
41 |
-
userinfo = user_info_from_github(session["access_token"])
|
42 |
-
elif not request.json:
|
43 |
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
44 |
retmsg='Unautherized!')
|
45 |
|
46 |
-
email = request.json.get('email'
|
47 |
users = UserService.query(email=email)
|
48 |
-
if not users:
|
49 |
-
if request.json is not None:
|
50 |
-
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
51 |
-
avatar = ""
|
52 |
-
try:
|
53 |
-
avatar = download_img(userinfo["avatar_url"])
|
54 |
-
except Exception as e:
|
55 |
-
stat_logger.exception(e)
|
56 |
-
user_id = get_uuid()
|
57 |
-
try:
|
58 |
-
users = user_register(user_id, {
|
59 |
-
"access_token": session["access_token"],
|
60 |
-
"email": userinfo["email"],
|
61 |
-
"avatar": avatar,
|
62 |
-
"nickname": userinfo["login"],
|
63 |
-
"login_channel": login_channel,
|
64 |
-
"last_login_time": get_format_time(),
|
65 |
-
"is_superuser": False,
|
66 |
-
})
|
67 |
-
if not users: raise Exception('Register user failure.')
|
68 |
-
if len(users) > 1: raise Exception('Same E-mail exist!')
|
69 |
-
user = users[0]
|
70 |
-
login_user(user)
|
71 |
-
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
72 |
-
except Exception as e:
|
73 |
-
rollback_user_registration(user_id)
|
74 |
-
stat_logger.exception(e)
|
75 |
-
return server_error_response(e)
|
76 |
-
elif not request.json:
|
77 |
-
login_user(users[0])
|
78 |
-
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
|
79 |
|
80 |
password = request.json.get('password')
|
81 |
try:
|
@@ -97,28 +62,50 @@ def login():
|
|
97 |
|
98 |
@manager.route('/github_callback', methods=['GET'])
|
99 |
def github_callback():
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
retmsg=res["error_description"])
|
111 |
-
|
112 |
-
if "user:email" not in res["scope"].split(","):
|
113 |
-
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
114 |
-
|
115 |
-
session["access_token"] = res["access_token"]
|
116 |
-
session["access_token_from"] = "github"
|
117 |
-
return redirect(url_for("user.login"), code=307)
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
|
124 |
def user_info_from_github(access_token):
|
@@ -208,7 +195,7 @@ def user_register(user_id, user):
|
|
208 |
for llm in LLMService.query(fid=LLM_FACTORY):
|
209 |
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
210 |
|
211 |
-
if not UserService.
|
212 |
TenantService.insert(**tenant)
|
213 |
UserTenantService.insert(**usr_tenant)
|
214 |
TenantLLMService.insert_many(tenant_llm)
|
|
|
33 |
|
34 |
@manager.route('/login', methods=['POST', 'GET'])
|
35 |
def login():
|
|
|
36 |
login_channel = "password"
|
37 |
+
if not request.json:
|
|
|
|
|
|
|
|
|
38 |
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
39 |
retmsg='Unautherized!')
|
40 |
|
41 |
+
email = request.json.get('email', "")
|
42 |
users = UserService.query(email=email)
|
43 |
+
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
password = request.json.get('password')
|
46 |
try:
|
|
|
62 |
|
63 |
@manager.route('/github_callback', methods=['GET'])
|
64 |
def github_callback():
|
65 |
+
import requests
|
66 |
+
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
67 |
+
"client_id": GITHUB_OAUTH.get("client_id"),
|
68 |
+
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
69 |
+
"code": request.args.get('code')
|
70 |
+
}, headers={"Accept": "application/json"})
|
71 |
+
res = res.json()
|
72 |
+
if "error" in res:
|
73 |
+
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
74 |
+
retmsg=res["error_description"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
+
if "user:email" not in res["scope"].split(","):
|
77 |
+
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
78 |
+
|
79 |
+
session["access_token"] = res["access_token"]
|
80 |
+
session["access_token_from"] = "github"
|
81 |
+
userinfo = user_info_from_github(session["access_token"])
|
82 |
+
users = UserService.query(email=userinfo["email"])
|
83 |
+
user_id = get_uuid()
|
84 |
+
if not users:
|
85 |
+
try:
|
86 |
+
try:
|
87 |
+
avatar = download_img(userinfo["avatar_url"])
|
88 |
+
except Exception as e:
|
89 |
+
stat_logger.exception(e)
|
90 |
+
avatar = ""
|
91 |
+
users = user_register(user_id, {
|
92 |
+
"access_token": session["access_token"],
|
93 |
+
"email": userinfo["email"],
|
94 |
+
"avatar": avatar,
|
95 |
+
"nickname": userinfo["login"],
|
96 |
+
"login_channel": "github",
|
97 |
+
"last_login_time": get_format_time(),
|
98 |
+
"is_superuser": False,
|
99 |
+
})
|
100 |
+
if not users: raise Exception('Register user failure.')
|
101 |
+
if len(users) > 1: raise Exception('Same E-mail exist!')
|
102 |
+
user = users[0]
|
103 |
+
login_user(user)
|
104 |
+
except Exception as e:
|
105 |
+
rollback_user_registration(user_id)
|
106 |
+
stat_logger.exception(e)
|
107 |
+
|
108 |
+
return redirect("/knowledge")
|
109 |
|
110 |
|
111 |
def user_info_from_github(access_token):
|
|
|
195 |
for llm in LLMService.query(fid=LLM_FACTORY):
|
196 |
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
197 |
|
198 |
+
if not UserService.save(**user):return
|
199 |
TenantService.insert(**tenant)
|
200 |
UserTenantService.insert(**usr_tenant)
|
201 |
TenantLLMService.insert_many(tenant_llm)
|
api/db/__init__.py
CHANGED
@@ -69,7 +69,6 @@ class TaskStatus(StrEnum):
|
|
69 |
|
70 |
|
71 |
class ParserType(StrEnum):
|
72 |
-
GENERAL = "general"
|
73 |
PRESENTATION = "presentation"
|
74 |
LAWS = "laws"
|
75 |
MANUAL = "manual"
|
|
|
69 |
|
70 |
|
71 |
class ParserType(StrEnum):
|
|
|
72 |
PRESENTATION = "presentation"
|
73 |
LAWS = "laws"
|
74 |
MANUAL = "manual"
|
api/db/db_models.py
CHANGED
@@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel):
|
|
475 |
similarity_threshold = FloatField(default=0.2)
|
476 |
vector_similarity_weight = FloatField(default=0.3)
|
477 |
|
478 |
-
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.
|
479 |
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
480 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
481 |
|
|
|
475 |
similarity_threshold = FloatField(default=0.2)
|
476 |
vector_similarity_weight = FloatField(default=0.3)
|
477 |
|
478 |
+
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
|
479 |
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
480 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
481 |
|
api/db/init_data.py
CHANGED
@@ -30,7 +30,7 @@ def init_superuser():
|
|
30 |
"password": "admin",
|
31 |
"nickname": "admin",
|
32 |
"is_superuser": True,
|
33 |
-
"email": "
|
34 |
"creator": "system",
|
35 |
"status": "1",
|
36 |
}
|
@@ -61,7 +61,7 @@ def init_superuser():
|
|
61 |
TenantService.insert(**tenant)
|
62 |
UserTenantService.insert(**usr_tenant)
|
63 |
TenantLLMService.insert_many(tenant_llm)
|
64 |
-
print("【INFO】Super user initialized. \033[
|
65 |
|
66 |
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
67 |
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
|
|
30 |
"password": "admin",
|
31 |
"nickname": "admin",
|
32 |
"is_superuser": True,
|
33 |
+
"email": "admin@ragflow.io",
|
34 |
"creator": "system",
|
35 |
"status": "1",
|
36 |
}
|
|
|
61 |
TenantService.insert(**tenant)
|
62 |
UserTenantService.insert(**usr_tenant)
|
63 |
TenantLLMService.insert_many(tenant_llm)
|
64 |
+
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
|
65 |
|
66 |
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
67 |
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
api/db/services/user_service.py
CHANGED
@@ -13,6 +13,8 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
|
|
16 |
import peewee
|
17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
18 |
|
@@ -20,7 +22,7 @@ from api.db import UserTenantRole
|
|
20 |
from api.db.db_models import DB, UserTenant
|
21 |
from api.db.db_models import User, Tenant
|
22 |
from api.db.services.common_service import CommonService
|
23 |
-
from api.utils import get_uuid, get_format_time
|
24 |
from api.db import StatusEnum
|
25 |
|
26 |
|
@@ -53,6 +55,11 @@ class UserService(CommonService):
|
|
53 |
kwargs["id"] = get_uuid()
|
54 |
if "password" in kwargs:
|
55 |
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
|
|
|
|
|
|
|
|
|
|
56 |
obj = cls.model(**kwargs).save(force_insert=True)
|
57 |
return obj
|
58 |
|
@@ -66,10 +73,10 @@ class UserService(CommonService):
|
|
66 |
@classmethod
|
67 |
@DB.connection_context()
|
68 |
def update_user(cls, user_id, user_dict):
|
69 |
-
date_time = get_format_time()
|
70 |
with DB.atomic():
|
71 |
if user_dict:
|
72 |
-
user_dict["update_time"] =
|
|
|
73 |
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
74 |
|
75 |
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
from datetime import datetime
|
17 |
+
|
18 |
import peewee
|
19 |
from werkzeug.security import generate_password_hash, check_password_hash
|
20 |
|
|
|
22 |
from api.db.db_models import DB, UserTenant
|
23 |
from api.db.db_models import User, Tenant
|
24 |
from api.db.services.common_service import CommonService
|
25 |
+
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
|
26 |
from api.db import StatusEnum
|
27 |
|
28 |
|
|
|
55 |
kwargs["id"] = get_uuid()
|
56 |
if "password" in kwargs:
|
57 |
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
58 |
+
|
59 |
+
kwargs["create_time"] = current_timestamp()
|
60 |
+
kwargs["create_date"] = datetime_format(datetime.now())
|
61 |
+
kwargs["update_time"] = current_timestamp()
|
62 |
+
kwargs["update_date"] = datetime_format(datetime.now())
|
63 |
obj = cls.model(**kwargs).save(force_insert=True)
|
64 |
return obj
|
65 |
|
|
|
73 |
@classmethod
|
74 |
@DB.connection_context()
|
75 |
def update_user(cls, user_id, user_dict):
|
|
|
76 |
with DB.atomic():
|
77 |
if user_dict:
|
78 |
+
user_dict["update_time"] = current_timestamp()
|
79 |
+
user_dict["update_date"] = datetime_format(datetime.now())
|
80 |
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
81 |
|
82 |
|
api/settings.py
CHANGED
@@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
|
76 |
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
77 |
|
78 |
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
79 |
-
PARSERS = LLM.get("parsers", "
|
80 |
|
81 |
# distribution
|
82 |
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
|
|
76 |
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
77 |
|
78 |
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
79 |
+
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
80 |
|
81 |
# distribution
|
82 |
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
deepdoc/parser/pdf_parser.py
CHANGED
@@ -25,7 +25,7 @@ class HuParser:
|
|
25 |
def __init__(self):
|
26 |
self.ocr = OCR()
|
27 |
if not hasattr(self, "model_speciess"):
|
28 |
-
self.model_speciess = ParserType.
|
29 |
self.layouter = LayoutRecognizer("layout."+self.model_speciess)
|
30 |
self.tbl_det = TableStructureRecognizer()
|
31 |
|
|
|
25 |
def __init__(self):
|
26 |
self.ocr = OCR()
|
27 |
if not hasattr(self, "model_speciess"):
|
28 |
+
self.model_speciess = ParserType.NAIVE.value
|
29 |
self.layouter = LayoutRecognizer("layout."+self.model_speciess)
|
30 |
self.tbl_det = TableStructureRecognizer()
|
31 |
|
deepdoc/vision/layout_recognizer.py
CHANGED
@@ -34,8 +34,7 @@ class LayoutRecognizer(Recognizer):
|
|
34 |
"Equation",
|
35 |
]
|
36 |
def __init__(self, domain):
|
37 |
-
super().__init__(self.labels, domain,
|
38 |
-
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
39 |
|
40 |
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
41 |
def __is_garbage(b):
|
|
|
34 |
"Equation",
|
35 |
]
|
36 |
def __init__(self, domain):
|
37 |
+
super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
|
|
38 |
|
39 |
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
40 |
def __is_garbage(b):
|
deepdoc/vision/table_structure_recognizer.py
CHANGED
@@ -33,8 +33,7 @@ class TableStructureRecognizer(Recognizer):
|
|
33 |
]
|
34 |
|
35 |
def __init__(self):
|
36 |
-
super().__init__(self.labels, "tsr",
|
37 |
-
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
38 |
|
39 |
def __call__(self, images, thr=0.2):
|
40 |
tbls = super().__call__(images, thr)
|
|
|
33 |
]
|
34 |
|
35 |
def __init__(self):
|
36 |
+
super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
|
|
37 |
|
38 |
def __call__(self, images, thr=0.2):
|
39 |
tbls = super().__call__(images, thr)
|
rag/app/manual.py
CHANGED
@@ -1,11 +1,17 @@
|
|
1 |
import copy
|
2 |
import re
|
|
|
|
|
3 |
from rag.nlp import huqie, tokenize
|
4 |
from deepdoc.parser import PdfParser
|
5 |
from rag.utils import num_tokens_from_string
|
6 |
|
7 |
|
8 |
class Pdf(PdfParser):
|
|
|
|
|
|
|
|
|
9 |
def __call__(self, filename, binary=None, from_page=0,
|
10 |
to_page=100000, zoomin=3, callback=None):
|
11 |
self.__images__(
|
|
|
1 |
import copy
|
2 |
import re
|
3 |
+
|
4 |
+
from api.db import ParserType
|
5 |
from rag.nlp import huqie, tokenize
|
6 |
from deepdoc.parser import PdfParser
|
7 |
from rag.utils import num_tokens_from_string
|
8 |
|
9 |
|
10 |
class Pdf(PdfParser):
|
11 |
+
def __init__(self):
|
12 |
+
self.model_speciess = ParserType.MANUAL.value
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
def __call__(self, filename, binary=None, from_page=0,
|
16 |
to_page=100000, zoomin=3, callback=None):
|
17 |
self.__images__(
|
rag/app/naive.py
CHANGED
@@ -30,11 +30,21 @@ class Pdf(PdfParser):
|
|
30 |
|
31 |
from timeit import default_timer as timer
|
32 |
start = timer()
|
|
|
33 |
self._layouts_rec(zoomin)
|
34 |
-
callback(0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
|
36 |
-
self._naive_vertical_merge()
|
37 |
-
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
|
38 |
|
39 |
|
40 |
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
@@ -44,11 +54,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
44 |
Successive text will be sliced into pieces using 'delimiter'.
|
45 |
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
|
46 |
"""
|
|
|
|
|
47 |
doc = {
|
48 |
"docnm_kwd": filename,
|
49 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
50 |
}
|
51 |
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
|
|
52 |
pdf_parser = None
|
53 |
sections = []
|
54 |
if re.search(r"\.docx?$", filename, re.IGNORECASE):
|
@@ -58,8 +71,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
58 |
callback(0.8, "Finish parsing.")
|
59 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
60 |
pdf_parser = Pdf()
|
61 |
-
sections = pdf_parser(filename if not binary else binary,
|
62 |
from_page=from_page, to_page=to_page, callback=callback)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
64 |
callback(0.1, "Start to parse.")
|
65 |
txt = ""
|
@@ -79,8 +103,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
79 |
|
80 |
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
|
81 |
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
|
82 |
-
|
83 |
-
res = []
|
84 |
# wrap up to es documents
|
85 |
for ck in cks:
|
86 |
print("--", ck)
|
|
|
30 |
|
31 |
from timeit import default_timer as timer
|
32 |
start = timer()
|
33 |
+
start = timer()
|
34 |
self._layouts_rec(zoomin)
|
35 |
+
callback(0.5, "Layout analysis finished.")
|
36 |
+
print("paddle layouts:", timer() - start)
|
37 |
+
self._table_transformer_job(zoomin)
|
38 |
+
callback(0.7, "Table analysis finished.")
|
39 |
+
self._text_merge()
|
40 |
+
self._concat_downward(concat_between_pages=False)
|
41 |
+
self._filter_forpages()
|
42 |
+
callback(0.77, "Text merging finished")
|
43 |
+
tbls = self._extract_table_figure(True, zoomin, False)
|
44 |
+
|
45 |
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
|
46 |
+
#self._naive_vertical_merge()
|
47 |
+
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
|
48 |
|
49 |
|
50 |
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
|
|
54 |
Successive text will be sliced into pieces using 'delimiter'.
|
55 |
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
|
56 |
"""
|
57 |
+
|
58 |
+
eng = lang.lower() == "english"#is_english(cks)
|
59 |
doc = {
|
60 |
"docnm_kwd": filename,
|
61 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
62 |
}
|
63 |
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
64 |
+
res = []
|
65 |
pdf_parser = None
|
66 |
sections = []
|
67 |
if re.search(r"\.docx?$", filename, re.IGNORECASE):
|
|
|
71 |
callback(0.8, "Finish parsing.")
|
72 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
73 |
pdf_parser = Pdf()
|
74 |
+
sections, tbls = pdf_parser(filename if not binary else binary,
|
75 |
from_page=from_page, to_page=to_page, callback=callback)
|
76 |
+
# add tables
|
77 |
+
for img, rows in tbls:
|
78 |
+
bs = 10
|
79 |
+
de = ";" if eng else ";"
|
80 |
+
for i in range(0, len(rows), bs):
|
81 |
+
d = copy.deepcopy(doc)
|
82 |
+
r = de.join(rows[i:i + bs])
|
83 |
+
r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r)
|
84 |
+
tokenize(d, r, eng)
|
85 |
+
d["image"] = img
|
86 |
+
res.append(d)
|
87 |
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
88 |
callback(0.1, "Start to parse.")
|
89 |
txt = ""
|
|
|
103 |
|
104 |
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
|
105 |
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
|
106 |
+
|
|
|
107 |
# wrap up to es documents
|
108 |
for ck in cks:
|
109 |
print("--", ck)
|
rag/svr/task_executor.py
CHANGED
@@ -37,7 +37,7 @@ from rag.nlp import search
|
|
37 |
from io import BytesIO
|
38 |
import pandas as pd
|
39 |
|
40 |
-
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
|
41 |
|
42 |
from api.db import LLMType, ParserType
|
43 |
from api.db.services.document_service import DocumentService
|
@@ -48,7 +48,7 @@ from api.utils.file_utils import get_project_base_directory
|
|
48 |
BATCH_SIZE = 64
|
49 |
|
50 |
FACTORY = {
|
51 |
-
ParserType.
|
52 |
ParserType.PAPER.value: paper,
|
53 |
ParserType.BOOK.value: book,
|
54 |
ParserType.PRESENTATION.value: presentation,
|
|
|
37 |
from io import BytesIO
|
38 |
import pandas as pd
|
39 |
|
40 |
+
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive
|
41 |
|
42 |
from api.db import LLMType, ParserType
|
43 |
from api.db.services.document_service import DocumentService
|
|
|
48 |
BATCH_SIZE = 64
|
49 |
|
50 |
FACTORY = {
|
51 |
+
ParserType.NAIVE.value: naive,
|
52 |
ParserType.PAPER.value: paper,
|
53 |
ParserType.BOOK.value: book,
|
54 |
ParserType.PRESENTATION.value: presentation,
|