Commit 34b2ab3 by KevinHuSh
Parent(s): 484e5ab

Test APIs and fix bugs (#41)

api/apps/chunk_app.py CHANGED
@@ -214,7 +214,7 @@ def retrieval_test():
     question = req["question"]
     kb_id = req["kb_id"]
     doc_ids = req.get("doc_ids", [])
-    similarity_threshold = float(req.get("similarity_threshold", 0.4))
+    similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top", 1024))
     try:
api/apps/conversation_app.py CHANGED
@@ -170,7 +170,7 @@ def chat(dialog, messages, **kwargs):
         if p["key"] not in kwargs:
             prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
 
-    model_config = TenantLLMService.get_api_key(dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id)
+    model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
     if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
 
     question = messages[-1]["content"]
@@ -186,10 +186,10 @@ def chat(dialog, messages, **kwargs):
     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting[dialog.llm_setting_type]
     msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
-    used_token_count = message_fit_in(msg, int(llm.max_tokens * 0.97))
+    used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
-    mdl = ChatModel[model_config.llm_factory](model_config["api_key"], dialog.llm_id)
+    mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
     answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
 
     answer = retrievaler.insert_citations(answer,
@@ -198,4 +198,6 @@ def chat(dialog, messages, **kwargs):
                                           embd_mdl,
                                           tkweight=1-dialog.vector_similarity_weight,
                                           vtweight=dialog.vector_similarity_weight)
+    for c in kbinfos["chunks"]:
+        if c.get("vector"): del c["vector"]
     return {"answer": answer, "retrieval": kbinfos}
api/apps/document_app.py CHANGED
@@ -11,7 +11,8 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
+# limitations under the License
+#
 #
 import base64
 import pathlib
@@ -65,7 +66,7 @@ def upload():
     while MINIO.obj_exist(kb_id, location):
         location += "_"
     blob = request.files['file'].read()
-    MINIO.put(kb_id, filename, blob)
+    MINIO.put(kb_id, location, blob)
     doc = DocumentService.insert({
         "id": get_uuid(),
         "kb_id": kb.id,
@@ -188,7 +189,10 @@ def rm():
     e, doc = DocumentService.get_by_id(req["doc_id"])
     if not e:
         return get_data_error_result(retmsg="Document not found!")
-    ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id))
+    tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+    if not tenant_id:
+        return get_data_error_result(retmsg="Tenant not found!")
+    ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
 
     DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
     if not DocumentService.delete_by_id(req["doc_id"]):
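The `upload()` fix matters because the loop above it appends "_" to `location` until the key is free; storing the blob under `filename` meant the deduplicated key was never actually written. A minimal sketch of the corrected pattern, with an in-memory dict standing in for MINIO (the `store` and `put_unique` names are illustrative):

    store = {}  # stand-in for the object store: {(bucket, key): blob}

    def put_unique(bucket, filename, blob):
        location = filename
        while (bucket, location) in store:   # mirrors the obj_exist loop in upload()
            location += "_"
        store[(bucket, location)] = blob     # store under the deduplicated key
        return location

    put_unique("kb1", "a.pdf", b"v1")   # stored as "a.pdf"
    put_unique("kb1", "a.pdf", b"v2")   # stored as "a.pdf_", nothing overwritten

The `rm()` hunk is related plumbing: the Elasticsearch index is named per tenant, not per knowledge base, so the document's tenant has to be resolved before `deleteByQuery` can target the right index.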
api/apps/llm_app.py CHANGED
@@ -75,7 +75,7 @@ def list():
     llms = LLMService.get_all()
     llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
     for m in llms:
-        m["available"] = m.llm_name in mdlnms
+        m["available"] = m["llm_name"] in mdlnms
 
     res = {}
     for m in llms:
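The one-line change fixes an `AttributeError`: after the `to_dict()` comprehension each `m` is a plain dict, which supports only subscript access. A two-line reproduction of the bug:

    m = {"llm_name": "gpt-3.5-turbo"}
    m["llm_name"]   # OK
    m.llm_name      # AttributeError: 'dict' object has no attribute 'llm_name'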
api/db/db_models.py CHANGED
@@ -469,7 +469,7 @@ class Knowledgebase(DataBaseModel):
     doc_num = IntegerField(default=0)
     token_num = IntegerField(default=0)
     chunk_num = IntegerField(default=0)
-    similarity_threshold = FloatField(default=0.4)
+    similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
 
     parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
@@ -521,7 +521,7 @@ class Dialog(DataBaseModel):
     prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
                                                    "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
 
-    similarity_threshold = FloatField(default=0.4)
+    similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
     top_n = IntegerField(default=6)
 
api/db/services/llm_service.py CHANGED
@@ -63,7 +63,7 @@ class TenantLLMService(CommonService):
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
         if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
-        model_config = model_config[0].to_dict()
+        model_config = model_config.to_dict()
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel: return
             return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
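Taken together with the `conversation_app.py` hunk, `get_api_key` is now treated as returning a single ORM record rather than a list, so both the `[0]` index here and the `model_config["api_key"]` subscript at the call site had to go. A sketch of the assumed return shape (`SimpleNamespace` stands in for the Peewee record):

    from types import SimpleNamespace

    # Illustrative stand-in for the single record get_api_key now returns.
    model_config = SimpleNamespace(llm_factory="OpenAI",
                                   api_key="YOUR_KEY",
                                   llm_name="gpt-3.5-turbo")

    model_config.api_key    # attribute access, as used in conversation_app.py
    vars(model_config)      # dict form, roughly what .to_dict() yields
    # model_config[0]       # the old list-style indexing would raise TypeError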
api/utils/file_utils.py CHANGED
@@ -143,7 +143,7 @@ def filename_type(filename):
     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value
 
-    if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
+    if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
         return FileType.DOC.value
 
     if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
rag/llm/chat_model.py CHANGED
@@ -19,31 +19,39 @@ import os
 
 
 class Base(ABC):
+    def __init__(self, key, model_name):
+        pass
+
     def chat(self, system, history, gen_conf):
         raise NotImplementedError("Please implement encode method!")
 
 
 class GptTurbo(Base):
-    def __init__(self):
-        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+    def __init__(self, key, model_name="gpt-3.5-turbo"):
+        self.client = OpenAI(api_key=key)
+        self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
         history.insert(0, {"role": "system", "content": system})
         res = self.client.chat.completions.create(
-            model="gpt-3.5-turbo",
+            model=self.model_name,
             messages=history,
             **gen_conf)
         return res.choices[0].message.content.strip()
 
 
+from dashscope import Generation
 class QWenChat(Base):
+    def __init__(self, key, model_name=Generation.Models.qwen_turbo):
+        import dashscope
+        dashscope.api_key = key
+        self.model_name = model_name
+
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
-        from dashscope import Generation
-        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
         history.insert(0, {"role": "system", "content": system})
         response = Generation.call(
-            Generation.Models.qwen_turbo,
+            self.model_name,
             messages=history,
             result_format='message'
         )
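With every subclass now accepting `(key, model_name)`, the call site in `conversation_app.py` can instantiate any backend uniformly through the factory table. A minimal sketch of that dispatch, assuming a registry mapping factory names to these classes (the key strings and `make_chat_model` helper are illustrative):

    # Assumed registry shape; the real mapping lives in rag.llm.
    ChatModel = {"OpenAI": GptTurbo, "QWen": QWenChat}

    def make_chat_model(llm_factory, api_key, llm_id):
        # Uniform construction works because every subclass takes (key, model_name).
        return ChatModel[llm_factory](api_key, llm_id)

    mdl = make_chat_model("OpenAI", "YOUR_KEY", "gpt-3.5-turbo")
    answer = mdl.chat("You are a helpful assistant.",
                      [{"role": "user", "content": "hello"}],
                      {"temperature": 0.7})

Moving `dashscope.api_key = key` into `QWenChat.__init__` also removes the old reliance on the DASHSCOPE_API_KEY environment variable, consistent with keys now coming from per-tenant configuration.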
rag/llm/cv_model.py CHANGED
@@ -28,6 +28,8 @@ class Base(ABC):
         raise NotImplementedError("Please implement encode method!")
 
     def image2base64(self, image):
+        if isinstance(image, bytes):
+            return base64.b64encode(image).decode("utf-8")
         if isinstance(image, BytesIO):
             return base64.b64encode(image.getvalue()).decode("utf-8")
         buffered = BytesIO()
@@ -59,7 +61,7 @@ class Base(ABC):
 
 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview"):
-        self.client = OpenAI(key)
+        self.client = OpenAI(api_key=key)
         self.model_name = model_name
 
     def describe(self, image, max_tokens=300):
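The new `bytes` branch lets callers pass raw blobs (e.g., what `request.files['file'].read()` or an object-store fetch yields) without wrapping them in `BytesIO` first. A standalone check that the two fast paths agree:

    import base64
    from io import BytesIO

    blob = b"\x89PNG..."   # stand-in for a raw image blob

    raw = base64.b64encode(blob).decode("utf-8")                          # bytes path
    wrapped = base64.b64encode(BytesIO(blob).getvalue()).decode("utf-8")  # BytesIO path
    assert raw == wrapped

The `GptV4` fix is the same keyword-argument correction as in `chat_model.py`: `OpenAI()` takes `api_key` by keyword, and passing the key positionally does not set it.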
rag/nlp/search.py CHANGED
@@ -187,9 +187,10 @@ class Dealer:
             if len(t) < 5: continue
             idx.append(i)
             pieces_.append(t)
+        es_logger.info("{} => {}".format(answer, pieces_))
         if not pieces_: return answer
 
-        ans_v = embd_mdl.encode(pieces_)
+        ans_v, c = embd_mdl.encode(pieces_)
         assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
             len(ans_v[0]), len(chunk_v[0]))
 
@@ -219,7 +220,7 @@ class Dealer:
             Dealer.trans2floats(
                 sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
         if not ins_embd:
-            return []
+            return [], [], []
         ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
@@ -235,6 +236,8 @@ class Dealer:
 
     def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
+        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
+        if not question: return ranks
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
                "question": question, "vector": True,
                "similarity": similarity_threshold}
@@ -243,7 +246,7 @@ class Dealer:
         sim, tsim, vsim = self.rerank(
             sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
         idx = np.argsort(sim * -1)
-        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
+
         dim = len(sres.query_vector)
         start_idx = (page - 1) * page_size
         for i in idx:
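Both similarity changes here are tuple-shape fixes: `embd_mdl.encode` evidently returns a `(vectors, token_count)` pair, and `rerank` is consumed as `sim, tsim, vsim = self.rerank(...)`, so the empty case must mirror the three-value shape. A minimal reproduction of the crash that `return [], [], []` prevents:

    def rerank_old(ids):
        if not ids: return []

    def rerank_new(ids):
        if not ids: return [], [], []

    sim, tsim, vsim = rerank_new([])      # fine: three empty sequences
    # sim, tsim, vsim = rerank_old([])    # ValueError: not enough values to unpack

Hoisting `ranks` to the top of `retrieval` and adding `if not question: return ranks` likewise keeps an empty query from ever reaching Elasticsearch, returning the same empty structure callers already expect.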
rag/svr/parse_user_docs.py CHANGED
@@ -78,6 +78,7 @@ def chuck_doc(name, binary, cvmdl=None):
         field = TextChunker.Fields()
         field.text_chunks = [(txt, binary)]
         field.table_chunks = []
+        return field
 
     return TextChunker()(binary)
 
@@ -161,9 +162,9 @@ def build(row, cvmdl):
     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
     output_buffer = BytesIO()
     docs = []
-    md5 = hashlib.md5()
     for txt, img in obj.text_chunks:
         d = copy.deepcopy(doc)
+        md5 = hashlib.md5()
         md5.update((txt + str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["content_ltks"] = huqie.qie(txt)
@@ -186,6 +187,7 @@ def build(row, cvmdl):
         for i, txt in enumerate(arr):
             d = copy.deepcopy(doc)
             d["content_ltks"] = huqie.qie(txt)
+            md5 = hashlib.md5()
             md5.update((txt + str(d["doc_id"])).encode("utf-8"))
             d["_id"] = md5.hexdigest()
             if not img:
@@ -226,9 +228,6 @@ def embedding(docs, mdl):
 
 
 def main(comm, mod):
-    global model
-    from rag.llm import HuEmbedding
-    model = HuEmbedding()
     tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
     tm = findMaxTm(tm_fnm)
     rows = collect(comm, mod, tm)
@@ -260,13 +259,14 @@ def main(comm, mod):
             set_progress(r["id"], random.randint(70, 95) / 100.,
                          "Finished embedding! Start to build index!")
             init_kb(r)
+            chunk_count = len(set([c["_id"] for c in cks]))
             es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
             if es_r:
                 set_progress(r["id"], -1, "Index failure!")
                 cron_logger.error(str(es_r))
             else:
                 set_progress(r["id"], 1., "Done!")
-            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
+            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, chunk_count, timer()-st_tm)
             cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
 
             tmf.write(str(r["update_time"]) + "\n")
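The md5 relocation is the substantive fix in `build()`: a `hashlib.md5()` object accumulates every `update()` call, so one hasher shared across the loop made each chunk's `_id` depend on all chunks hashed before it, and re-parsing the same document yielded different IDs. A standalone demonstration:

    import hashlib

    chunks = ["alpha", "beta"]

    # Old: one shared hasher; the second digest covers "alpha" + "beta".
    md5 = hashlib.md5()
    shared = []
    for txt in chunks:
        md5.update(txt.encode("utf-8"))
        shared.append(md5.hexdigest())

    # New: a fresh hasher per chunk; each digest covers only its own text.
    fresh = [hashlib.md5(txt.encode("utf-8")).hexdigest() for txt in chunks]

    assert fresh[1] == hashlib.md5(b"beta").hexdigest()
    assert shared[1] != fresh[1]   # the shared hasher chained the inputs

It also motivates the new `chunk_count = len(set(...))`: chunks sharing an `_id` collapse into one Elasticsearch document on bulk indexing, so counting unique IDs keeps `increment_chunk_num` consistent with what was actually indexed.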