LiuHua Feiue Kevin Hu commited on
Commit
cf772f7
·
1 Parent(s): 4d0a7c7

complete implementation of dataset SDK (#2147)

Browse files

### What problem does this PR solve?

Complete implementation of dataset SDK.
#1102

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>

api/apps/sdk/dataset.py CHANGED
@@ -15,82 +15,156 @@
15
  #
16
  from flask import request
17
 
18
- from api.db import StatusEnum
19
- from api.db.db_models import APIToken
 
 
 
20
  from api.db.services.knowledgebase_service import KnowledgebaseService
21
  from api.db.services.user_service import TenantService
22
  from api.settings import RetCode
23
  from api.utils import get_uuid
24
- from api.utils.api_utils import get_data_error_result
25
- from api.utils.api_utils import get_json_result
26
 
27
 
28
  @manager.route('/save', methods=['POST'])
29
- def save():
 
30
  req = request.json
31
- token = request.headers.get('Authorization').split()[1]
32
- objs = APIToken.query(token=token)
33
- if not objs:
34
- return get_json_result(
35
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
36
- tenant_id = objs[0].tenant_id
37
  e, t = TenantService.get_by_id(tenant_id)
38
- if not e:
39
- return get_data_error_result(retmsg="Tenant not found.")
40
  if "id" not in req:
 
 
 
 
 
 
41
  req['id'] = get_uuid()
42
  req["name"] = req["name"].strip()
43
  if req["name"] == "":
44
  return get_data_error_result(
45
- retmsg="Name is not empty")
46
- if KnowledgebaseService.query(name=req["name"]):
47
  return get_data_error_result(
48
- retmsg="Duplicated knowledgebase name")
49
  req["tenant_id"] = tenant_id
50
  req['created_by'] = tenant_id
51
  req['embd_id'] = t.embd_id
52
  if not KnowledgebaseService.save(**req):
53
- return get_data_error_result(retmsg="Data saving error")
54
- req.pop('created_by')
55
- keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
56
- 'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
57
- for old_key,new_key in keys_to_rename.items():
58
- if old_key in req:
59
- req[new_key]=req.pop(old_key)
60
  return get_json_result(data=req)
61
  else:
62
- if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id:
63
- return get_data_error_result(
64
- retmsg="Can't change tenant_id or embedding_model")
 
65
 
66
- e, kb = KnowledgebaseService.get_by_id(req["id"])
67
- if not e:
68
- return get_data_error_result(
69
- retmsg="Can't find this knowledgebase!")
70
 
71
  if not KnowledgebaseService.query(
72
  created_by=tenant_id, id=req["id"]):
73
  return get_json_result(
74
- data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
75
  retcode=RetCode.OPERATING_ERROR)
76
 
77
- if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
78
- return get_data_error_result(
79
- retmsg="Can't change document_count or chunk_count ")
80
 
81
- if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
82
- return get_data_error_result(
83
- retmsg="if chunk count is not 0, parser method is not changable. ")
 
84
 
 
 
 
 
85
 
86
- if req["name"].lower() != kb.name.lower() \
87
- and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
88
- status=StatusEnum.VALID.value)) > 0:
89
- return get_data_error_result(
90
- retmsg="Duplicated knowledgebase name.")
 
 
 
 
 
91
 
92
  del req["id"]
93
- req['created_by'] = tenant_id
94
  if not KnowledgebaseService.update_by_id(kb.id, req):
95
- return get_data_error_result(retmsg="Data update error ")
96
  return get_json_result(data=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  #
16
  from flask import request
17
 
18
+ from api.db import StatusEnum, FileSource
19
+ from api.db.db_models import File
20
+ from api.db.services.document_service import DocumentService
21
+ from api.db.services.file2document_service import File2DocumentService
22
+ from api.db.services.file_service import FileService
23
  from api.db.services.knowledgebase_service import KnowledgebaseService
24
  from api.db.services.user_service import TenantService
25
  from api.settings import RetCode
26
  from api.utils import get_uuid
27
+ from api.utils.api_utils import get_json_result, token_required, get_data_error_result
 
28
 
29
 
30
  @manager.route('/save', methods=['POST'])
31
+ @token_required
32
+ def save(tenant_id):
33
  req = request.json
 
 
 
 
 
 
34
  e, t = TenantService.get_by_id(tenant_id)
 
 
35
  if "id" not in req:
36
+ if "tenant_id" in req or "embd_id" in req:
37
+ return get_data_error_result(
38
+ retmsg="Tenant_id or embedding_model must not be provided")
39
+ if "name" not in req:
40
+ return get_data_error_result(
41
+ retmsg="Name is not empty!")
42
  req['id'] = get_uuid()
43
  req["name"] = req["name"].strip()
44
  if req["name"] == "":
45
  return get_data_error_result(
46
+ retmsg="Name is not empty string!")
47
+ if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
48
  return get_data_error_result(
49
+ retmsg="Duplicated knowledgebase name in creating dataset.")
50
  req["tenant_id"] = tenant_id
51
  req['created_by'] = tenant_id
52
  req['embd_id'] = t.embd_id
53
  if not KnowledgebaseService.save(**req):
54
+ return get_data_error_result(retmsg="Create dataset error.(Database error)")
 
 
 
 
 
 
55
  return get_json_result(data=req)
56
  else:
57
+ if "tenant_id" in req:
58
+ if req["tenant_id"] != tenant_id:
59
+ return get_data_error_result(
60
+ retmsg="Can't change tenant_id.")
61
 
62
+ if "embd_id" in req:
63
+ if req["embd_id"] != t.embd_id:
64
+ return get_data_error_result(
65
+ retmsg="Can't change embedding_model.")
66
 
67
  if not KnowledgebaseService.query(
68
  created_by=tenant_id, id=req["id"]):
69
  return get_json_result(
70
+ data=False, retmsg='You do not own the dataset.',
71
  retcode=RetCode.OPERATING_ERROR)
72
 
73
+ e, kb = KnowledgebaseService.get_by_id(req["id"])
 
 
74
 
75
+ if "chunk_num" in req:
76
+ if req["chunk_num"] != kb.chunk_num:
77
+ return get_data_error_result(
78
+ retmsg="Can't change chunk_count.")
79
 
80
+ if "doc_num" in req:
81
+ if req['doc_num'] != kb.doc_num:
82
+ return get_data_error_result(
83
+ retmsg="Can't change document_count.")
84
 
85
+ if "parser_id" in req:
86
+ if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
87
+ return get_data_error_result(
88
+ retmsg="if chunk count is not 0, parse method is not changable.")
89
+ if "name" in req:
90
+ if req["name"].lower() != kb.name.lower() \
91
+ and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
92
+ status=StatusEnum.VALID.value)) > 0:
93
+ return get_data_error_result(
94
+ retmsg="Duplicated knowledgebase name in updating dataset.")
95
 
96
  del req["id"]
 
97
  if not KnowledgebaseService.update_by_id(kb.id, req):
98
+ return get_data_error_result(retmsg="Update dataset error.(Database error)")
99
  return get_json_result(data=True)
100
+
101
+
102
+ @manager.route('/delete', methods=['DELETE'])
103
+ @token_required
104
+ def delete(tenant_id):
105
+ req = request.args
106
+ kbs = KnowledgebaseService.query(
107
+ created_by=tenant_id, id=req["id"])
108
+ if not kbs:
109
+ return get_json_result(
110
+ data=False, retmsg='You do not own the dataset',
111
+ retcode=RetCode.OPERATING_ERROR)
112
+
113
+ for doc in DocumentService.query(kb_id=req["id"]):
114
+ if not DocumentService.remove_document(doc, kbs[0].tenant_id):
115
+ return get_data_error_result(
116
+ retmsg="Remove document error.(Database error)")
117
+ f2d = File2DocumentService.get_by_document_id(doc.id)
118
+ FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
119
+ File2DocumentService.delete_by_document_id(doc.id)
120
+
121
+ if not KnowledgebaseService.delete_by_id(req["id"]):
122
+ return get_data_error_result(
123
+ retmsg="Delete dataset error.(Database error)")
124
+ return get_json_result(data=True)
125
+
126
+
127
+ @manager.route('/list', methods=['GET'])
128
+ @token_required
129
+ def list_datasets(tenant_id):
130
+ page_number = int(request.args.get("page", 1))
131
+ items_per_page = int(request.args.get("page_size", 1024))
132
+ orderby = request.args.get("orderby", "create_time")
133
+ desc = bool(request.args.get("desc", True))
134
+ tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
135
+ kbs = KnowledgebaseService.get_by_tenant_ids(
136
+ [m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
137
+ return get_json_result(data=kbs)
138
+
139
+
140
+ @manager.route('/detail', methods=['GET'])
141
+ @token_required
142
+ def detail(tenant_id):
143
+ req = request.args
144
+ if "id" in req:
145
+ id = req["id"]
146
+ kb = KnowledgebaseService.query(created_by=tenant_id, id=req["id"])
147
+ if not kb:
148
+ return get_json_result(
149
+ data=False, retmsg='You do not own the dataset',
150
+ retcode=RetCode.OPERATING_ERROR)
151
+ if "name" in req:
152
+ name = req["name"]
153
+ if kb[0].name != name:
154
+ return get_json_result(
155
+ data=False, retmsg='You do not own the dataset',
156
+ retcode=RetCode.OPERATING_ERROR)
157
+ e, k = KnowledgebaseService.get_by_id(id)
158
+ return get_json_result(data=k.to_dict())
159
+ else:
160
+ if "name" in req:
161
+ name = req["name"]
162
+ e, k = KnowledgebaseService.get_by_name(kb_name=name, tenant_id=tenant_id)
163
+ if not e:
164
+ return get_json_result(
165
+ data=False, retmsg='You do not own the dataset',
166
+ retcode=RetCode.OPERATING_ERROR)
167
+ return get_json_result(data=k.to_dict())
168
+ else:
169
+ return get_data_error_result(
170
+ retmsg="At least one of `id` or `name` must be provided.")
api/utils/api_utils.py CHANGED
@@ -13,30 +13,32 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import json
17
  import random
18
  import time
 
19
  from functools import wraps
 
20
  from io import BytesIO
 
 
 
 
21
  from flask import (
22
  Response, jsonify, send_file, make_response,
23
  request as flask_request,
24
  )
25
  from werkzeug.http import HTTP_STATUS_CODES
26
 
27
- from api.utils import json_dumps
28
- from api.settings import RetCode
29
  from api.settings import (
30
  REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
31
  stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
32
  )
33
- import requests
34
- import functools
35
  from api.utils import CustomJSONEncoder
36
- from uuid import uuid1
37
- from base64 import b64encode
38
- from hmac import HMAC
39
- from urllib.parse import quote, urlencode
40
 
41
  requests.models.complexjson.dumps = functools.partial(
42
  json.dumps, cls=CustomJSONEncoder)
@@ -96,7 +98,6 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
96
 
97
  def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
98
  data=None, job_id=None, meta=None):
99
- import re
100
  result_dict = {
101
  "retcode": retcode,
102
  "retmsg": retmsg,
@@ -145,7 +146,8 @@ def server_error_response(e):
145
  return get_json_result(
146
  retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
147
  if repr(e).find("index_not_found_exception") >= 0:
148
- return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.")
 
149
 
150
  return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
151
 
@@ -190,7 +192,9 @@ def validate_request(*args, **kwargs):
190
  return get_json_result(
191
  retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
192
  return func(*_args, **_kwargs)
 
193
  return decorated_function
 
194
  return wrapper
195
 
196
 
@@ -217,7 +221,7 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
217
 
218
 
219
  def construct_response(retcode=RetCode.SUCCESS,
220
- retmsg='success', data=None, auth=None):
221
  result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
222
  response_dict = {}
223
  for key, value in result_dict.items():
@@ -235,6 +239,7 @@ def construct_response(retcode=RetCode.SUCCESS,
235
  response.headers["Access-Control-Expose-Headers"] = "Authorization"
236
  return response
237
 
 
238
  def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
239
  import re
240
  result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
@@ -263,7 +268,23 @@ def construct_error_response(e):
263
  pass
264
  if len(e.args) > 1:
265
  return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
266
- if repr(e).find("index_not_found_exception") >=0:
267
- return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
 
268
 
269
  return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import functools
17
  import json
18
  import random
19
  import time
20
+ from base64 import b64encode
21
  from functools import wraps
22
+ from hmac import HMAC
23
  from io import BytesIO
24
+ from urllib.parse import quote, urlencode
25
+ from uuid import uuid1
26
+
27
+ import requests
28
  from flask import (
29
  Response, jsonify, send_file, make_response,
30
  request as flask_request,
31
  )
32
  from werkzeug.http import HTTP_STATUS_CODES
33
 
34
+ from api.db.db_models import APIToken
 
35
  from api.settings import (
36
  REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
37
  stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
38
  )
39
+ from api.settings import RetCode
 
40
  from api.utils import CustomJSONEncoder
41
+ from api.utils import json_dumps
 
 
 
42
 
43
  requests.models.complexjson.dumps = functools.partial(
44
  json.dumps, cls=CustomJSONEncoder)
 
98
 
99
  def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
100
  data=None, job_id=None, meta=None):
 
101
  result_dict = {
102
  "retcode": retcode,
103
  "retmsg": retmsg,
 
146
  return get_json_result(
147
  retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
148
  if repr(e).find("index_not_found_exception") >= 0:
149
+ return get_json_result(retcode=RetCode.EXCEPTION_ERROR,
150
+ retmsg="No chunk found, please upload file and parse it.")
151
 
152
  return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
153
 
 
192
  return get_json_result(
193
  retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
194
  return func(*_args, **_kwargs)
195
+
196
  return decorated_function
197
+
198
  return wrapper
199
 
200
 
 
221
 
222
 
223
  def construct_response(retcode=RetCode.SUCCESS,
224
+ retmsg='success', data=None, auth=None):
225
  result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
226
  response_dict = {}
227
  for key, value in result_dict.items():
 
239
  response.headers["Access-Control-Expose-Headers"] = "Authorization"
240
  return response
241
 
242
+
243
  def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
244
  import re
245
  result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
 
268
  pass
269
  if len(e.args) > 1:
270
  return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
271
+ if repr(e).find("index_not_found_exception") >= 0:
272
+ return construct_json_result(code=RetCode.EXCEPTION_ERROR,
273
+ message="No chunk found, please upload file and parse it.")
274
 
275
  return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
276
+
277
+
278
+ def token_required(func):
279
+ @wraps(func)
280
+ def decorated_function(*args, **kwargs):
281
+ token = flask_request.headers.get('Authorization').split()[1]
282
+ objs = APIToken.query(token=token)
283
+ if not objs:
284
+ return get_json_result(
285
+ data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
286
+ )
287
+ kwargs['tenant_id'] = objs[0].tenant_id
288
+ return func(*args, **kwargs)
289
+
290
+ return decorated_function
sdk/python/ragflow/modules/base.py CHANGED
@@ -18,13 +18,17 @@ class Base(object):
18
  pr[name] = value
19
  return pr
20
 
21
-
22
  def post(self, path, param):
23
- res = self.rag.post(path,param)
24
  return res
25
 
26
- def get(self, path, params=''):
27
- res = self.rag.get(path,params)
28
  return res
29
 
 
 
 
30
 
 
 
 
18
  pr[name] = value
19
  return pr
20
 
 
21
  def post(self, path, param):
22
+ res = self.rag.post(path, param)
23
  return res
24
 
25
+ def get(self, path, params):
26
+ res = self.rag.get(path, params)
27
  return res
28
 
29
+ def rm(self, path, params):
30
+ res = self.rag.delete(path, params)
31
+ return res
32
 
33
+ def __str__(self):
34
+ return str(self.to_json())
sdk/python/ragflow/modules/dataset.py CHANGED
@@ -21,18 +21,36 @@ class DataSet(Base):
21
  self.permission = "me"
22
  self.document_count = 0
23
  self.chunk_count = 0
24
- self.parser_method = "naive"
25
  self.parser_config = None
 
 
 
 
 
 
 
 
 
 
 
26
  super().__init__(rag, res_dict)
27
 
28
- def save(self):
29
  res = self.post('/dataset/save',
30
  {"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
31
  "description": self.description, "language": self.language, "embd_id": self.embedding_model,
32
  "permission": self.permission,
33
- "doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method,
34
  "parser_config": self.parser_config.to_json()
35
  })
36
  res = res.json()
37
- if not res.get("retmsg"): return True
38
- raise Exception(res["retmsg"])
 
 
 
 
 
 
 
 
21
  self.permission = "me"
22
  self.document_count = 0
23
  self.chunk_count = 0
24
+ self.parse_method = "naive"
25
  self.parser_config = None
26
+ for k in list(res_dict.keys()):
27
+ if k == "embd_id":
28
+ res_dict["embedding_model"] = res_dict[k]
29
+ if k == "parser_id":
30
+ res_dict['parse_method'] = res_dict[k]
31
+ if k == "doc_num":
32
+ res_dict["document_count"] = res_dict[k]
33
+ if k == "chunk_num":
34
+ res_dict["chunk_count"] = res_dict[k]
35
+ if k not in self.__dict__:
36
+ res_dict.pop(k)
37
  super().__init__(rag, res_dict)
38
 
39
+ def save(self) -> bool:
40
  res = self.post('/dataset/save',
41
  {"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
42
  "description": self.description, "language": self.language, "embd_id": self.embedding_model,
43
  "permission": self.permission,
44
+ "doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method,
45
  "parser_config": self.parser_config.to_json()
46
  })
47
  res = res.json()
48
+ if res.get("retmsg") == "success": return True
49
+ raise Exception(res["retmsg"])
50
+
51
+ def delete(self) -> bool:
52
+ res = self.rm('/dataset/delete',
53
+ {"id": self.id})
54
+ res = res.json()
55
+ if res.get("retmsg") == "success": return True
56
+ raise Exception(res["retmsg"])
sdk/python/ragflow/ragflow.py CHANGED
@@ -13,6 +13,8 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
 
 
16
  import requests
17
 
18
  from .modules.dataset import DataSet
@@ -25,30 +27,54 @@ class RAGFlow:
25
  """
26
  self.user_key = user_key
27
  self.api_url = f"{base_url}/api/{version}"
28
- self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)}
29
 
30
  def post(self, path, param):
31
  res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
32
  return res
33
 
34
- def get(self, path, params=''):
35
- res = requests.get(self.api_url + path, params=params, headers=self.authorization_header)
 
 
 
 
36
  return res
37
 
38
- def create_dataset(self, name:str,avatar:str="",description:str="",language:str="English",permission:str="me",
39
- document_count:int=0,chunk_count:int=0,parser_method:str="naive",
40
- parser_config:DataSet.ParserConfig=None):
 
41
  if parser_config is None:
42
- parser_config = DataSet.ParserConfig(self, {"chunk_token_count":128,"layout_recognize": True, "delimiter":"\n!?。;!?","task_page_size":12})
43
- parser_config=parser_config.to_json()
44
- res=self.post("/dataset/save",{"name":name,"avatar":avatar,"description":description,"language":language,"permission":permission,
45
- "doc_num": document_count,"chunk_num":chunk_count,"parser_id":parser_method,
46
- "parser_config":parser_config
47
- }
48
- )
 
 
 
49
  res = res.json()
50
- if not res.get("retmsg"):
51
  return DataSet(self, res["data"])
52
  raise Exception(res["retmsg"])
53
 
 
 
 
 
 
 
 
 
 
 
54
 
 
 
 
 
 
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
+ from typing import List
17
+
18
  import requests
19
 
20
  from .modules.dataset import DataSet
 
27
  """
28
  self.user_key = user_key
29
  self.api_url = f"{base_url}/api/{version}"
30
+ self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
31
 
32
  def post(self, path, param):
33
  res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
34
  return res
35
 
36
+ def get(self, path, params=None):
37
+ res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
38
+ return res
39
+
40
+ def delete(self, path, params):
41
+ res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
42
  return res
43
 
44
+ def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
45
+ permission: str = "me",
46
+ document_count: int = 0, chunk_count: int = 0, parse_method: str = "naive",
47
+ parser_config: DataSet.ParserConfig = None) -> DataSet:
48
  if parser_config is None:
49
+ parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True,
50
+ "delimiter": "\n!?。;!?", "task_page_size": 12})
51
+ parser_config = parser_config.to_json()
52
+ res = self.post("/dataset/save",
53
+ {"name": name, "avatar": avatar, "description": description, "language": language,
54
+ "permission": permission,
55
+ "doc_num": document_count, "chunk_num": chunk_count, "parser_id": parse_method,
56
+ "parser_config": parser_config
57
+ }
58
+ )
59
  res = res.json()
60
+ if res.get("retmsg") == "success":
61
  return DataSet(self, res["data"])
62
  raise Exception(res["retmsg"])
63
 
64
+ def list_datasets(self, page: int = 1, page_size: int = 150, orderby: str = "create_time", desc: bool = True) -> \
65
+ List[DataSet]:
66
+ res = self.get("/dataset/list", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
67
+ res = res.json()
68
+ result_list = []
69
+ if res.get("retmsg") == "success":
70
+ for data in res['data']:
71
+ result_list.append(DataSet(self, data))
72
+ return result_list
73
+ raise Exception(res["retmsg"])
74
 
75
+ def get_dataset(self, id: str = None, name: str = None) -> DataSet:
76
+ res = self.get("/dataset/detail", {"id": id, "name": name})
77
+ res = res.json()
78
+ if res.get("retmsg") == "success":
79
+ return DataSet(self, res['data'])
80
+ raise Exception(res["retmsg"])
sdk/python/test/t_dataset.py CHANGED
@@ -7,7 +7,7 @@ from test_sdkbase import TestSdk
7
  class TestDataset(TestSdk):
8
  def test_create_dataset_with_success(self):
9
  """
10
- Test creating dataset with success
11
  """
12
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
13
  ds = rag.create_dataset("God")
@@ -18,15 +18,46 @@ class TestDataset(TestSdk):
18
 
19
  def test_update_dataset_with_success(self):
20
  """
21
- Test updating dataset with success.
22
  """
23
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
24
  ds = rag.create_dataset("ABC")
25
  if isinstance(ds, DataSet):
26
- assert ds.name == "ABC", "Name does not match."
27
  ds.name = 'DEF'
28
  res = ds.save()
29
- assert res is True, f"Failed to update dataset, error: {res}"
 
 
30
 
 
 
 
 
 
 
 
 
 
 
31
  else:
32
- assert False, f"Failed to create dataset, error: {ds}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  class TestDataset(TestSdk):
8
  def test_create_dataset_with_success(self):
9
  """
10
+ Test creating a dataset with success
11
  """
12
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
13
  ds = rag.create_dataset("God")
 
18
 
19
  def test_update_dataset_with_success(self):
20
  """
21
+ Test updating a dataset with success.
22
  """
23
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
24
  ds = rag.create_dataset("ABC")
25
  if isinstance(ds, DataSet):
26
+ assert ds.name == "ABC", "Name does not match."
27
  ds.name = 'DEF'
28
  res = ds.save()
29
+ assert res is True, f"Failed to update dataset, error: {res}"
30
+ else:
31
+ assert False, f"Failed to create dataset, error: {ds}"
32
 
33
+ def test_delete_dataset_with_success(self):
34
+ """
35
+ Test deleting a dataset with success
36
+ """
37
+ rag = RAGFlow(API_KEY, HOST_ADDRESS)
38
+ ds = rag.create_dataset("MA")
39
+ if isinstance(ds, DataSet):
40
+ assert ds.name == "MA", "Name does not match."
41
+ res = ds.delete()
42
+ assert res is True, f"Failed to delete dataset, error: {res}"
43
  else:
44
+ assert False, f"Failed to create dataset, error: {ds}"
45
+
46
+ def test_list_datasets_with_success(self):
47
+ """
48
+ Test listing datasets with success
49
+ """
50
+ rag = RAGFlow(API_KEY, HOST_ADDRESS)
51
+ list_datasets = rag.list_datasets()
52
+ assert len(list_datasets) > 0, "Do not exist any dataset"
53
+ for ds in list_datasets:
54
+ assert isinstance(ds, DataSet), "Existence type is not dataset."
55
+
56
+ def test_get_detail_dataset_with_success(self):
57
+ """
58
+ Test getting a dataset's detail with success
59
+ """
60
+ rag = RAGFlow(API_KEY, HOST_ADDRESS)
61
+ ds = rag.get_dataset(name="God")
62
+ assert isinstance(ds, DataSet), f"Failed to get dataset, error: {ds}."
63
+ assert ds.name == "God", "Name does not match"