KevinHuSh commited on
Commit
c372afe
·
1 Parent(s): 6b8fc2c

change licence (#28)

Browse files

* add front end code

* change licence

rag/llm/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  #
2
- # Copyright 2019 The FATE Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
1
  #
2
+ # Copyright 2019 The RAG Flow Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
rag/llm/chat_model.py CHANGED
@@ -1,5 +1,5 @@
1
  #
2
- # Copyright 2019 The FATE Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
1
  #
2
+ # Copyright 2019 The RAG Flow Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
rag/llm/cv_model.py CHANGED
@@ -1,5 +1,5 @@
1
  #
2
- # Copyright 2019 The FATE Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
1
  #
2
+ # Copyright 2019 The RAG Flow Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
rag/llm/embedding_model.py CHANGED
@@ -1,5 +1,5 @@
1
  #
2
- # Copyright 2019 The FATE Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
1
  #
2
+ # Copyright 2019 The RAG Flow Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
rag/nlp/search.py CHANGED
@@ -1,8 +1,11 @@
1
  # -*- coding: utf-8 -*-
 
2
  import re
3
  from elasticsearch_dsl import Q, Search, A
4
  from typing import List, Optional, Tuple, Dict, Union
5
  from dataclasses import dataclass
 
 
6
  from rag.utils import rmSpace
7
  from rag.nlp import huqie, query
8
  import numpy as np
@@ -34,30 +37,30 @@ class Dealer:
34
  group_docs: List[List] = None
35
 
36
  def _vector(self, txt, sim=0.8, topk=10):
 
37
  return {
38
- "field": "q_vec",
39
  "k": topk,
40
  "similarity": sim,
41
  "num_candidates": 1000,
42
- "query_vector": self.emb_mdl.encode_queries(txt)
43
  }
44
 
45
  def search(self, req, idxnm, tks_num=3):
46
- keywords = []
47
  qst = req.get("question", "")
48
-
49
  bqry, keywords = self.qryr.question(qst)
50
  if req.get("kb_ids"):
51
  bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
52
- bqry.filter.append(Q("exists", field="q_tks"))
 
53
  bqry.boost = 0.05
54
- print(bqry)
55
 
56
  s = Search()
57
  pg = int(req.get("page", 1)) - 1
58
  ps = int(req.get("size", 1000))
59
- src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
60
- "image_id", "doc_id", "q_vec"])
 
61
 
62
  s = s.query(bqry)[pg * ps:(pg + 1) * ps]
63
  s = s.highlight("content_ltks")
@@ -66,22 +69,24 @@ class Dealer:
66
  s = s.sort(
67
  {"create_time": {"order": "desc", "unmapped_type": "date"}})
68
 
69
- s = s.highlight_options(
70
- fragment_size=120,
71
- number_of_fragments=5,
72
- boundary_scanner_locale="zh-CN",
73
- boundary_scanner="SENTENCE",
74
- boundary_chars=",./;:\\!(),。?:!……()——、"
75
- )
 
76
  s = s.to_dict()
77
  q_vec = []
78
  if req.get("vector"):
79
  s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
80
  s["knn"]["filter"] = bqry.to_dict()
81
- del s["highlight"]
82
  q_vec = s["knn"]["query_vector"]
 
83
  res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
84
- print("TOTAL: ", self.es.getTotal(res))
85
  if self.es.getTotal(res) == 0 and "knn" in s:
86
  bqry, _ = self.qryr.question(qst, min_match="10%")
87
  if req.get("kb_ids"):
@@ -109,8 +114,7 @@ class Dealer:
109
  query_vector=q_vec,
110
  aggregation=aggs,
111
  highlight=self.getHighlight(res),
112
- field=self.getFields(res, ["docnm_kwd", "content_ltks",
113
- "kb_id", "image_id", "doc_id", "q_vec"]),
114
  keywords=list(kwds)
115
  )
116
 
@@ -237,14 +241,4 @@ class Dealer:
237
  return sim
238
 
239
 
240
- if __name__ == "__main__":
241
- from util import es_conn
242
- SE = Dealer(es_conn.HuEs("infiniflow"))
243
- qs = [
244
- "胡凯",
245
- ""
246
- ]
247
- for q in qs:
248
- print(">>>>>>>>>>>>>>>>>>>>", q)
249
- print(SE.search(
250
- {"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
 
1
  # -*- coding: utf-8 -*-
2
+ import json
3
  import re
4
  from elasticsearch_dsl import Q, Search, A
5
  from typing import List, Optional, Tuple, Dict, Union
6
  from dataclasses import dataclass
7
+
8
+ from rag.settings import es_logger
9
  from rag.utils import rmSpace
10
  from rag.nlp import huqie, query
11
  import numpy as np
 
37
  group_docs: List[List] = None
38
 
39
  def _vector(self, txt, sim=0.8, topk=10):
40
+ qv = self.emb_mdl.encode_queries(txt)
41
  return {
42
+ "field": "q_%d_vec"%len(qv),
43
  "k": topk,
44
  "similarity": sim,
45
  "num_candidates": 1000,
46
+ "query_vector": qv
47
  }
48
 
49
  def search(self, req, idxnm, tks_num=3):
 
50
  qst = req.get("question", "")
 
51
  bqry, keywords = self.qryr.question(qst)
52
  if req.get("kb_ids"):
53
  bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
54
+ if req.get("doc_ids"):
55
+ bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
56
  bqry.boost = 0.05
 
57
 
58
  s = Search()
59
  pg = int(req.get("page", 1)) - 1
60
  ps = int(req.get("size", 1000))
61
+ src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id","img_id",
62
+ "image_id", "doc_id", "q_512_vec", "q_768_vec",
63
+ "q_1024_vec", "q_1536_vec"])
64
 
65
  s = s.query(bqry)[pg * ps:(pg + 1) * ps]
66
  s = s.highlight("content_ltks")
 
69
  s = s.sort(
70
  {"create_time": {"order": "desc", "unmapped_type": "date"}})
71
 
72
+ if qst:
73
+ s = s.highlight_options(
74
+ fragment_size=120,
75
+ number_of_fragments=5,
76
+ boundary_scanner_locale="zh-CN",
77
+ boundary_scanner="SENTENCE",
78
+ boundary_chars=",./;:\\!(),。?:!……()——、"
79
+ )
80
  s = s.to_dict()
81
  q_vec = []
82
  if req.get("vector"):
83
  s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
84
  s["knn"]["filter"] = bqry.to_dict()
85
+ if "highlight" in s: del s["highlight"]
86
  q_vec = s["knn"]["query_vector"]
87
+ es_logger.info("【Q】: {}".format(json.dumps(s)))
88
  res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
89
+ es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
90
  if self.es.getTotal(res) == 0 and "knn" in s:
91
  bqry, _ = self.qryr.question(qst, min_match="10%")
92
  if req.get("kb_ids"):
 
114
  query_vector=q_vec,
115
  aggregation=aggs,
116
  highlight=self.getHighlight(res),
117
+ field=self.getFields(res, src),
 
118
  keywords=list(kwds)
119
  )
120
 
 
241
  return sim
242
 
243
 
244
+
 
 
 
 
 
 
 
 
 
 
rag/nlp/term_weight.py CHANGED
@@ -62,7 +62,7 @@ class Dealer:
62
  return set(res.keys())
63
  return res
64
 
65
- fnm = os.path.join(get_project_base_directory(), "res")
66
  self.ne, self.df = {}, {}
67
  try:
68
  self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
 
62
  return set(res.keys())
63
  return res
64
 
65
+ fnm = os.path.join(get_project_base_directory(), "rag/res")
66
  self.ne, self.df = {}, {}
67
  try:
68
  self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
rag/settings.py CHANGED
@@ -1,5 +1,5 @@
1
  #
2
- # Copyright 2019 The FATE Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
1
  #
2
+ # Copyright 2019 The RAG Flow Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
rag/svr/parse_user_docs.py CHANGED
@@ -1,5 +1,5 @@
1
  #
2
- # Copyright 2019 The FATE Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import json
17
  import logging
18
  import os
@@ -108,17 +109,17 @@ def build(row, cvmdl):
108
  (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
109
  return []
110
 
111
- res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
112
- if ELASTICSEARCH.getTotal(res) > 0:
113
- ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
114
- scripts="""
115
- if(!ctx._source.kb_id.contains('%s'))
116
- ctx._source.kb_id.add('%s');
117
- """ % (str(row["kb_id"]), str(row["kb_id"])),
118
- idxnm=search.index_name(row["tenant_id"])
119
- )
120
- set_progress(row["id"], 1, "Done")
121
- return []
122
 
123
  random.seed(time.time())
124
  set_progress(row["id"], random.randint(0, 20) /
@@ -155,8 +156,7 @@ def build(row, cvmdl):
155
  "doc_id": row["id"],
156
  "kb_id": [str(row["kb_id"])],
157
  "docnm_kwd": os.path.split(row["location"])[-1],
158
- "title_tks": huqie.qie(row["name"]),
159
- "updated_at": str(row["update_time"]).replace("T", " ")[:19]
160
  }
161
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
162
  output_buffer = BytesIO()
@@ -179,6 +179,7 @@ def build(row, cvmdl):
179
 
180
  MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
181
  d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
 
182
  docs.append(d)
183
 
184
  for arr, img in obj.table_chunks:
@@ -193,6 +194,7 @@ def build(row, cvmdl):
193
  img.save(output_buffer, format='JPEG')
194
  MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
195
  d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
 
196
  docs.append(d)
197
  set_progress(row["id"], random.randint(60, 70) /
198
  100., "Continue embedding the content.")
@@ -218,23 +220,11 @@ def embedding(docs, mdl):
218
  vects = 0.1 * tts + 0.9 * cnts
219
  assert len(vects) == len(docs)
220
  for i, d in enumerate(docs):
221
- d["q_vec"] = vects[i].tolist()
 
222
  return tk_count
223
 
224
 
225
- def model_instance(tenant_id, llm_type):
226
- model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
227
- if not model_config:
228
- model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
229
- else: model_config = model_config[0].to_dict()
230
- if llm_type == LLMType.EMBEDDING:
231
- if model_config["llm_factory"] not in EmbeddingModel: return
232
- return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
233
- if llm_type == LLMType.IMAGE2TEXT:
234
- if model_config["llm_factory"] not in CvModel: return
235
- return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
236
-
237
-
238
  def main(comm, mod):
239
  global model
240
  from rag.llm import HuEmbedding
@@ -247,12 +237,12 @@ def main(comm, mod):
247
 
248
  tmf = open(tm_fnm, "a+")
249
  for _, r in rows.iterrows():
250
- embd_mdl = model_instance(r["tenant_id"], LLMType.EMBEDDING)
251
  if not embd_mdl:
252
  set_progress(r["id"], -1, "Can't find embedding model!")
253
  cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
254
  continue
255
- cv_mdl = model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
256
  st_tm = timer()
257
  cks = build(r, cv_mdl)
258
  if not cks:
 
1
  #
2
+ # Copyright 2019 The RAG Flow Authors. All Rights Reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import datetime
17
  import json
18
  import logging
19
  import os
 
109
  (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
110
  return []
111
 
112
+ # res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
113
+ # if ELASTICSEARCH.getTotal(res) > 0:
114
+ # ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
115
+ # scripts="""
116
+ # if(!ctx._source.kb_id.contains('%s'))
117
+ # ctx._source.kb_id.add('%s');
118
+ # """ % (str(row["kb_id"]), str(row["kb_id"])),
119
+ # idxnm=search.index_name(row["tenant_id"])
120
+ # )
121
+ # set_progress(row["id"], 1, "Done")
122
+ # return []
123
 
124
  random.seed(time.time())
125
  set_progress(row["id"], random.randint(0, 20) /
 
156
  "doc_id": row["id"],
157
  "kb_id": [str(row["kb_id"])],
158
  "docnm_kwd": os.path.split(row["location"])[-1],
159
+ "title_tks": huqie.qie(row["name"])
 
160
  }
161
  doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
162
  output_buffer = BytesIO()
 
179
 
180
  MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
181
  d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
182
+ d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
183
  docs.append(d)
184
 
185
  for arr, img in obj.table_chunks:
 
194
  img.save(output_buffer, format='JPEG')
195
  MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
196
  d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
197
+ d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
198
  docs.append(d)
199
  set_progress(row["id"], random.randint(60, 70) /
200
  100., "Continue embedding the content.")
 
220
  vects = 0.1 * tts + 0.9 * cnts
221
  assert len(vects) == len(docs)
222
  for i, d in enumerate(docs):
223
+ v = vects[i].tolist()
224
+ d["q_%d_vec"%len(v)] = v
225
  return tk_count
226
 
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  def main(comm, mod):
229
  global model
230
  from rag.llm import HuEmbedding
 
237
 
238
  tmf = open(tm_fnm, "a+")
239
  for _, r in rows.iterrows():
240
+ embd_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.EMBEDDING)
241
  if not embd_mdl:
242
  set_progress(r["id"], -1, "Can't find embedding model!")
243
  cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
244
  continue
245
+ cv_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
246
  st_tm = timer()
247
  cks = build(r, cv_mdl)
248
  if not cks:
rag/utils/es_conn.py CHANGED
@@ -241,6 +241,26 @@ class HuEs:
241
  es_logger.error("ES search timeout for 3 times!")
242
  raise Exception("ES search timeout.")
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  def updateByQuery(self, q, d):
245
  ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
246
  scripts = ""
 
241
  es_logger.error("ES search timeout for 3 times!")
242
  raise Exception("ES search timeout.")
243
 
244
+ def get(self, doc_id, idxnm=None):
245
+ for i in range(3):
246
+ try:
247
+ res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
248
+ id=doc_id)
249
+ if str(res.get("timed_out", "")).lower() == "true":
250
+ raise Exception("Es Timeout.")
251
+ return res
252
+ except Exception as e:
253
+ es_logger.error(
254
+ "ES get exception: " +
255
+ str(e) +
256
+ "【Q】:" +
257
+ doc_id)
258
+ if str(e).find("Timeout") > 0:
259
+ continue
260
+ raise e
261
+ es_logger.error("ES search timeout for 3 times!")
262
+ raise Exception("ES search timeout.")
263
+
264
  def updateByQuery(self, q, d):
265
  ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
266
  scripts = ""