KevinHuSh committed
Commit d0db329 · 1 Parent(s): cdd9565

add llm API (#19)

* add llm API

* refine llm API

python/conf/mapping.json CHANGED
@@ -121,7 +121,6 @@
       "match": "*_vec",
       "mapping": {
         "type": "dense_vector",
-        "dims": 1024,
         "index": true,
         "similarity": "cosine"
       }
python/conf/sys.cnf CHANGED
@@ -1,10 +1,9 @@
 [infiniflow]
 es=http://es01:9200
-pgdb_usr=root
-pgdb_pwd=infiniflow_docgpt
-pgdb_host=postgres
-pgdb_port=5432
+postgres_user=root
+postgres_password=infiniflow_docgpt
+postgres_host=postgres
+postgres_port=5432
 minio_host=minio:9000
-minio_usr=infiniflow
-minio_pwd=infiniflow_docgpt
-
+minio_user=infiniflow
+minio_password=infiniflow_docgpt
python/llm/__init__.py CHANGED
@@ -1,2 +1,21 @@
-from .embedding_model import HuEmbedding
-from .chat_model import GptTurbo
+import os
+from .embedding_model import *
+from .chat_model import *
+from .cv_model import *
+
+EmbeddingModel = None
+ChatModel = None
+CvModel = None
+
+
+if os.environ.get("OPENAI_API_KEY"):
+    EmbeddingModel = GptEmbed()
+    ChatModel = GptTurbo()
+    CvModel = GptV4()
+
+elif os.environ.get("DASHSCOPE_API_KEY"):
+    EmbeddingModel = QWenEmbd()
+    ChatModel = QWenChat()
+    CvModel = QWenCV()
+else:
+    EmbeddingModel = HuEmbedding()
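For reference, a minimal usage sketch of the env-driven selection above; the key value is a hypothetical placeholder, and without any key ChatModel and CvModel stay None:

    import os
    os.environ["OPENAI_API_KEY"] = "sk-..."  # hypothetical placeholder; real calls need a valid key
    from llm import EmbeddingModel, ChatModel, CvModel  # the branch above picked GptEmbed/GptTurbo/GptV4

    # These names are instances, not classes, so they are used directly:
    vecs = EmbeddingModel.encode(["hello world"])
    reply = ChatModel.chat("You are terse.", [{"role": "user", "content": "hi"}], {})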
python/llm/chat_model.py CHANGED
@@ -1,7 +1,8 @@
 from abc import ABC
-import openapi
+from openai import OpenAI
 import os
 
+
 class Base(ABC):
     def chat(self, system, history, gen_conf):
         raise NotImplementedError("Please implement encode method!")
@@ -9,26 +10,27 @@ class Base(ABC):
 
 class GptTurbo(Base):
     def __init__(self):
-        openapi.api_key = os.environ["OPENAPI_KEY"]
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
 
     def chat(self, system, history, gen_conf):
        history.insert(0, {"role": "system", "content": system})
-        res = openapi.ChatCompletion.create(model="gpt-3.5-turbo",
-                                            messages=history,
-                                            **gen_conf)
+        res = self.client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=history,
+            **gen_conf)
         return res.choices[0].message.content.strip()
 
 
-class QWen(Base):
+class QWenChat(Base):
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
         from dashscope import Generation
-        from dashscope.api_entities.dashscope_response import Role
         # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        history.insert(0, {"role": "system", "content": system})
         response = Generation.call(
-            Generation.Models.qwen_turbo,
-            messages=messages,
-            result_format='message'
+            Generation.Models.qwen_turbo,
+            messages=history,
+            result_format='message'
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content']
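A hedged example of calling the refactored OpenAI path (the key is a placeholder; gen_conf is forwarded verbatim to chat.completions.create):

    import os
    os.environ["OPENAI_API_KEY"] = "sk-..."  # hypothetical placeholder
    from llm.chat_model import GptTurbo

    mdl = GptTurbo()
    print(mdl.chat("You are a concise assistant.",
                   [{"role": "user", "content": "Say hi."}],
                   {"temperature": 0.2, "max_tokens": 32}))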
python/llm/cv_model.py ADDED
@@ -0,0 +1,66 @@
+from abc import ABC
+from openai import OpenAI
+import os
+import base64
+from io import BytesIO
+
+
+class Base(ABC):
+    def describe(self, image, max_tokens=300):
+        raise NotImplementedError("Please implement describe method!")
+
+    def image2base64(self, image):
+        if isinstance(image, BytesIO):
+            return base64.b64encode(image.getvalue()).decode("utf-8")
+        buffered = BytesIO()
+        try:
+            image.save(buffered, format="JPEG")
+        except Exception as e:
+            image.save(buffered, format="PNG")
+        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+    def prompt(self, b64):
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{b64}"
+                        },
+                    },
+                ],
+            }
+        ]
+
+
+class GptV4(Base):
+    def __init__(self):
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+    def describe(self, image, max_tokens=300):
+        b64 = self.image2base64(image)
+
+        res = self.client.chat.completions.create(
+            model="gpt-4-vision-preview",
+            messages=self.prompt(b64),
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip()
+
+
+class QWenCV(Base):
+    def describe(self, image, max_tokens=300):
+        from http import HTTPStatus
+        from dashscope import MultiModalConversation
+        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
+                                               messages=self.prompt(self.image2base64(image)))
+        if response.status_code == HTTPStatus.OK:
+            return response.output.choices[0]['message']['content']
+        return response.message
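A sketch of the vision wrapper on a local image; the file name and key are assumptions. Note that image2base64 accepts either a PIL image or a BytesIO buffer:

    import os
    os.environ["OPENAI_API_KEY"] = "sk-..."  # hypothetical placeholder
    from PIL import Image
    from llm.cv_model import GptV4

    img = Image.open("photo.jpg")  # hypothetical local file
    print(GptV4().describe(img, max_tokens=300))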
python/llm/embedding_model.py CHANGED
@@ -1,8 +1,11 @@
 from abc import ABC
+from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
+import os
 import numpy as np
 
+
 class Base(ABC):
     def encode(self, texts: list, batch_size=32):
         raise NotImplementedError("Please implement encode method!")
@@ -22,11 +25,37 @@ class HuEmbedding(Base):
 
        """
        self.model = FlagModel("BAAI/bge-large-zh-v1.5",
-                              query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                              use_fp16=torch.cuda.is_available())
+                              query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                              use_fp16=torch.cuda.is_available())
 
     def encode(self, texts: list, batch_size=32):
         res = []
         for i in range(0, len(texts), batch_size):
-            res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
+            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
         return np.array(res)
+
+
+class GptEmbed(Base):
+    def __init__(self):
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts,
+                                            model="text-embedding-ada-002")
+        return [d.embedding for d in res.data]
+
+
+class QWenEmbd(Base):
+    def encode(self, texts: list, batch_size=32, text_type="document"):
+        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        import dashscope
+        from http import HTTPStatus
+        res = []
+        for txt in texts:
+            resp = dashscope.TextEmbedding.call(
+                model=dashscope.TextEmbedding.Models.text_embedding_v2,
+                input=txt[:2048],
+                text_type=text_type
+            )
+            res.append(resp["output"]["embeddings"][0]["embedding"])
+        return res
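The local default needs no API key, and its fixed output width is why the hard-coded dims could be dropped from mapping.json (bge-large-zh-v1.5 emits 1024-dim vectors, so Elasticsearch can infer the dimension at index time):

    from llm.embedding_model import HuEmbedding

    emb = HuEmbedding()  # downloads BAAI/bge-large-zh-v1.5 on first use
    vecs = emb.encode(["什么是向量检索?", "what is vector search?"])
    print(vecs.shape)    # (2, 1024)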
python/nlp/huchunk.py CHANGED
@@ -372,7 +372,9 @@ class PptChunker(HuChunker):
 
     def __call__(self, fnm):
         from pptx import Presentation
-        ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm))
+        ppt = Presentation(fnm) if isinstance(
+            fnm, str) else Presentation(
+            BytesIO(fnm))
         flds = self.Fields()
         flds.text_chunks = []
         for slide in ppt.slides:
@@ -398,7 +400,8 @@ class TextChunker(HuChunker):
         mime = magic.Magic(mime=True)
         if isinstance(file_path, str):
             file_type = mime.from_file(file_path)
-        else:file_type = mime.from_buffer(file_path)
+        else:
+            file_type = mime.from_buffer(file_path)
         if 'text' in file_type:
             return False
         else:
@@ -406,7 +409,8 @@ class TextChunker(HuChunker):
 
     def __call__(self, fnm):
         flds = self.Fields()
-        if self.is_binary_file(fnm):return flds
+        if self.is_binary_file(fnm):
+            return flds
         with open(fnm, "r") as f:
             txt = f.read()
         flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
python/nlp/search.py CHANGED
@@ -1,6 +1,6 @@
 import re
-from elasticsearch_dsl import Q,Search,A
-from typing import List, Optional, Tuple,Dict, Union
+from elasticsearch_dsl import Q, Search, A
+from typing import List, Optional, Tuple, Dict, Union
 from dataclasses import dataclass
 from util import setup_logging, rmSpace
 from nlp import huqie, query
@@ -9,18 +9,24 @@ from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
 import numpy as np
 from copy import deepcopy
 
-def index_name(uid):return f"docgpt_{uid}"
+
+def index_name(uid): return f"docgpt_{uid}"
+
 
 class Dealer:
     def __init__(self, es, emb_mdl):
         self.qryr = query.EsQueryer(es)
-        self.qryr.flds = ["title_tks^10", "title_sm_tks^5", "content_ltks^2", "content_sm_ltks"]
+        self.qryr.flds = [
+            "title_tks^10",
+            "title_sm_tks^5",
+            "content_ltks^2",
+            "content_sm_ltks"]
         self.es = es
         self.emb_mdl = emb_mdl
 
     @dataclass
     class SearchResult:
-        total:int
+        total: int
         ids: List[str]
         query_vector: List[float] = None
         field: Optional[Dict] = None
@@ -42,71 +48,78 @@ class Dealer:
         keywords = []
         qst = req.get("question", "")
 
-        bqry,keywords = self.qryr.question(qst)
-        if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
+        bqry, keywords = self.qryr.question(qst)
+        if req.get("kb_ids"):
+            bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
         bqry.filter.append(Q("exists", field="q_tks"))
         bqry.boost = 0.05
         print(bqry)
 
         s = Search()
-        pg = int(req.get("page", 1))-1
+        pg = int(req.get("page", 1)) - 1
         ps = int(req.get("size", 1000))
         src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
                                 "image_id", "doc_id", "q_vec"])
 
-        s = s.query(bqry)[pg*ps:(pg+1)*ps]
+        s = s.query(bqry)[pg * ps:(pg + 1) * ps]
         s = s.highlight("content_ltks")
         s = s.highlight("title_ltks")
-        if not qst: s = s.sort({"create_time":{"order":"desc", "unmapped_type":"date"}})
+        if not qst:
+            s = s.sort(
+                {"create_time": {"order": "desc", "unmapped_type": "date"}})
 
         s = s.highlight_options(
-            fragment_size = 120,
-            number_of_fragments=5,
-            boundary_scanner_locale="zh-CN",
-            boundary_scanner="SENTENCE",
-            boundary_chars=",./;:\\!(),。?:!……()——、"
-        )
+            fragment_size=120,
+            number_of_fragments=5,
+            boundary_scanner_locale="zh-CN",
+            boundary_scanner="SENTENCE",
+            boundary_chars=",./;:\\!(),。?:!……()——、"
+        )
         s = s.to_dict()
         q_vec = []
-        if req.get("vector"):
+        if req.get("vector"):
             s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
             s["knn"]["filter"] = bqry.to_dict()
             del s["highlight"]
             q_vec = s["knn"]["query_vector"]
-        res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
+        res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
         print("TOTAL: ", self.es.getTotal(res))
         if self.es.getTotal(res) == 0 and "knn" in s:
-            bqry,_ = self.qryr.question(qst, min_match="10%")
-            if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
+            bqry, _ = self.qryr.question(qst, min_match="10%")
+            if req.get("kb_ids"):
+                bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.7
-            res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
+            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
 
         kwds = set([])
         for k in keywords:
             kwds.add(k)
             for kk in huqie.qieqie(k).split(" "):
-                if len(kk) < 2:continue
-                if kk in kwds:continue
+                if len(kk) < 2:
+                    continue
+                if kk in kwds:
+                    continue
                 kwds.add(kk)
 
         aggs = self.getAggregation(res, "docnm_kwd")
 
         return self.SearchResult(
-            total = self.es.getTotal(res),
-            ids = self.es.getDocIds(res),
-            query_vector = q_vec,
-            aggregation = aggs,
-            highlight = self.getHighlight(res),
-            field = self.getFields(res, ["docnm_kwd", "content_ltks",
-                        "kb_id","image_id", "doc_id", "q_vec"]),
-            keywords = list(kwds)
+            total=self.es.getTotal(res),
+            ids=self.es.getDocIds(res),
+            query_vector=q_vec,
+            aggregation=aggs,
+            highlight=self.getHighlight(res),
+            field=self.getFields(res, ["docnm_kwd", "content_ltks",
                                       "kb_id", "image_id", "doc_id", "q_vec"]),
+            keywords=list(kwds)
         )
 
     def getAggregation(self, res, g):
-        if not "aggregations" in res or "aggs_"+g not in res["aggregations"]:return
-        bkts = res["aggregations"]["aggs_"+g]["buckets"]
+        if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
+            return
+        bkts = res["aggregations"]["aggs_" + g]["buckets"]
         return [(b["key"], b["doc_count"]) for b in bkts]
 
     def getHighlight(self, res):
@@ -114,8 +127,11 @@ class Dealer:
         eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
         r = []
         for t in line.split(" "):
-            if not t:continue
-            if len(r)>0 and len(t)>0 and r[-1][-1] in eng and t[0] in eng:r.append(" ")
+            if not t:
+                continue
+            if len(r) > 0 and len(
+                    t) > 0 and r[-1][-1] in eng and t[0] in eng:
+                r.append(" ")
             r.append(t)
         r = "".join(r)
         return r
@@ -123,66 +139,76 @@ class Dealer:
         ans = {}
         for d in res["hits"]["hits"]:
             hlts = d.get("highlight")
-            if not hlts:continue
+            if not hlts:
+                continue
             ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
         return ans
 
     def getFields(self, sres, flds):
         res = {}
-        if not flds:return {}
-        for d in self.es.getSource(sres):
-            m = {n:d.get(n) for n in flds if d.get(n) is not None}
-            for n,v in m.items():
-                if type(v) == type([]):
+        if not flds:
+            return {}
+        for d in self.es.getSource(sres):
+            m = {n: d.get(n) for n in flds if d.get(n) is not None}
+            for n, v in m.items():
+                if isinstance(v, type([])):
                     m[n] = "\t".join([str(vv) for vv in v])
                     continue
-                if type(v) != type(""):m[n] = str(m[n])
+                if not isinstance(v, type("")):
+                    m[n] = str(m[n])
                 m[n] = rmSpace(m[n])
 
-            if m:res[d["id"]] = m
+            if m:
+                res[d["id"]] = m
         return res
 
-
     @staticmethod
     def trans2floats(txt):
         return [float(t) for t in txt.split("\t")]
 
-    def insert_citations(self, ans, top_idx, sres, vfield = "q_vec", cfield="content_ltks"):
-
-        ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx]
-        ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
+    def insert_citations(self, ans, top_idx, sres,
+                         vfield="q_vec", cfield="content_ltks"):
+
+        ins_embd = [Dealer.trans2floats(
+            sres.field[sres.ids[i]][vfield]) for i in top_idx]
+        ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
         s = 0
         e = 0
         res = ""
+
         def citeit():
             nonlocal s, e, ans, res
-            if not ins_embd:return
+            if not ins_embd:
+                return
             embd = self.emb_mdl.encode(ans[s: e])
-            sim = self.qryr.hybrid_similarity(embd,
-                                              ins_embd,
+            sim = self.qryr.hybrid_similarity(embd,
                                               ins_embd,
                                               huqie.qie(ans[s:e]).split(" "),
                                               ins_tw)
             print(ans[s: e], sim)
-            mx = np.max(sim)*0.99
-            if mx < 0.55:return
-            cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] >mx]))[:4]
-            for i in cita: res += f"@?{i}?@"
+            mx = np.max(sim) * 0.99
+            if mx < 0.55:
+                return
+            cita = list(set([top_idx[i]
+                             for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
+            for i in cita:
+                res += f"@?{i}?@"
 
             return cita
 
         punct = set(";。?!!")
-        if not self.qryr.isChinese(ans):
+        if not self.qryr.isChinese(ans):
             punct.add("?")
             punct.add(".")
         while e < len(ans):
             if e - s < 12 or ans[e] not in punct:
                 e += 1
                 continue
-            if ans[e] == "." and e+1<len(ans) and re.match(r"[0-9]", ans[e+1]):
+            if ans[e] == "." and e + \
+                    1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
                 e += 1
                 continue
-            if ans[e] == "." and e-2>=0 and ans[e-2] == "\n":
+            if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
                 e += 1
                 continue
             res += ans[s: e]
@@ -191,33 +217,36 @@ class Dealer:
             e += 1
             s = e
 
-        if s< len(ans):
+        if s < len(ans):
             res += ans[s:]
             citeit()
 
         return res
 
-
-    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"):
-        ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids]
-        if not ins_embd: return []
-        ins_tw =[sres.field[i][cfield].split(" ") for i in sres.ids]
-        #return CosineSimilarity([sres.query_vector], ins_embd)[0]
-        sim = self.qryr.hybrid_similarity(sres.query_vector,
-                                          ins_embd,
+    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
+               vfield="q_vec", cfield="content_ltks"):
+        ins_embd = [
+            Dealer.trans2floats(
+                sres.field[i]["q_vec"]) for i in sres.ids]
+        if not ins_embd:
+            return []
+        ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
+        # return CosineSimilarity([sres.query_vector], ins_embd)[0]
+        sim = self.qryr.hybrid_similarity(sres.query_vector,
                                           ins_embd,
                                           huqie.qie(query).split(" "),
                                           ins_tw, tkweight, vtweight)
         return sim
 
 
-
-if __name__ == "__main__":
+if __name__ == "__main__":
     from util import es_conn
     SE = Dealer(es_conn.HuEs("infiniflow"))
     qs = [
        "胡凯",
        ""
    ]
-    for q in qs:
+    for q in qs:
         print(">>>>>>>>>>>>>>>>>>>>", q)
-        print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
+        print(SE.search(
+            {"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
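A minimal sketch wiring the model factory into Dealer (assumes a reachable ES cluster and that llm/__init__.py selected an embedding model; the uid is hypothetical):

    from util import es_conn
    from llm import EmbeddingModel
    from nlp import search

    dealer = search.Dealer(es_conn.HuEs("infiniflow"), EmbeddingModel)
    res = dealer.search({"question": "vector search", "vector": True},
                        search.index_name("some_uid"))  # hypothetical uid
    print(res.total, res.keywords)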
python/parser/excel_parser.py CHANGED
@@ -5,8 +5,10 @@ from io import BytesIO
 
 class HuExcelParser:
     def __call__(self, fnm):
-        if isinstance(fnm, str):wb = load_workbook(fnm)
-        else: wb = load_workbook(BytesIO(fnm))
+        if isinstance(fnm, str):
+            wb = load_workbook(fnm)
+        else:
+            wb = load_workbook(BytesIO(fnm))
         res = []
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
python/parser/pdf_parser.py CHANGED
@@ -53,7 +53,7 @@ class HuParser:
     def _y_dis(
             self, a, b):
         return (
-            b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
+            b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
 
     def _match_proj(self, b):
         proj_patt = [
@@ -76,9 +76,9 @@ class HuParser:
         tks_down = huqie.qie(down["text"][:LEN]).split(" ")
         tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
         tks_all = up["text"][-LEN:].strip() \
-            + (" " if re.match(r"[a-zA-Z0-9]+",
-                               up["text"][-1] + down["text"][0]) else "") \
-            + down["text"][:LEN].strip()
+            + (" " if re.match(r"[a-zA-Z0-9]+",
+                               up["text"][-1] + down["text"][0]) else "") \
+            + down["text"][:LEN].strip()
         tks_all = huqie.qie(tks_all).split(" ")
         fea = [
             up.get("R", -1) == down.get("R", -1),
@@ -100,7 +100,7 @@ class HuParser:
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[\((][^\))]+$", up["text"])
-            and re.search(r"[\))]", down["text"]) else False,
+            and re.search(r"[\))]", down["text"]) else False,
             self._match_proj(down),
             True if re.match(r"[A-Z]", down["text"]) else False,
             True if re.match(r"[A-Z]", up["text"][-1]) else False,
@@ -217,7 +217,7 @@ class HuParser:
         assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
             tp, btm, x0, x1, b)
         ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
-            x0 != 0 and btm - tp != 0 else 0
+            x0 != 0 and btm - tp != 0 else 0
         if ov > 0 and ratio:
             ov /= (x1 - x0) * (btm - tp)
         return ov
@@ -382,7 +382,7 @@ class HuParser:
                 continue
             for tb in tbls:  # for table
                 left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
-                    tb["x1"] + MARGIN, tb["bottom"] + MARGIN
+                    tb["x1"] + MARGIN, tb["bottom"] + MARGIN
                 left *= ZM
                 top *= ZM
                 right *= ZM
@@ -899,7 +899,7 @@ class HuParser:
             lst_r = rows[-1]
             if lst_r[-1].get("R", "") != b.get("R", "") \
                     or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
-                        ):  # new row
+                        ):  # new row
                 btm = b["bottom"]
                 b["rn"] += 1
                 rows.append([b])
@@ -949,9 +949,9 @@ class HuParser:
                     j += 1
                     continue
                 f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
-                     [j - 1][0].get("text")) or j == 0
+                     [j - 1][0].get("text")) or j == 0
                 ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
-                      [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
+                      [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
                 if f and ff:
                     j += 1
                     continue
@@ -1012,9 +1012,9 @@ class HuParser:
                     i += 1
                     continue
                 f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
-                     [jj][0].get("text")) or i == 0
+                     [jj][0].get("text")) or i == 0
                 ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
-                      [jj][0].get("text")) or i + 1 >= len(tbl)
+                      [jj][0].get("text")) or i + 1 >= len(tbl)
                 if f and ff:
                     i += 1
                     continue
@@ -1169,8 +1169,8 @@ class HuParser:
                             else "") + headers[j - 1][k]
                     else:
                         headers[j][k] = headers[j - 1][k] \
-                            + ("的" if headers[j - 1][k] else "") \
-                            + headers[j][k]
+                            + ("的" if headers[j - 1][k] else "") \
+                            + headers[j][k]
 
         logging.debug(
             f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
@@ -1247,7 +1247,7 @@ class HuParser:
                 i += 1
                 continue
             lout_no = str(self.boxes[i]["page_number"]) + \
-                "-" + str(self.boxes[i]["layoutno"])
+                "-" + str(self.boxes[i]["layoutno"])
             if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
                                                                                   "figure caption", "reference"]:
                 nomerge_lout_no.append(lst_lout_no)
@@ -1526,7 +1526,8 @@ class HuParser:
         return "\n\n".join(res)
 
     def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
-        self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
+        self.pdf = pdfplumber.open(fnm) if isinstance(
+            fnm, str) else pdfplumber.open(BytesIO(fnm))
         self.lefted_chars = []
         self.mean_height = []
         self.mean_width = []
@@ -1601,7 +1602,7 @@ class HuParser:
                 self.page_images[pns[0]].crop((left * ZM, top * ZM,
                                                right *
                                                ZM, min(
-                                                   bottom, self.page_images[pns[0]].size[1])
+                                                   bottom, self.page_images[pns[0]].size[1])
                                                ))
                 )
             bottom -= self.page_images[pns[0]].size[1]
python/svr/dialog_svr.py CHANGED
@@ -16,11 +16,12 @@ from io import BytesIO
 from util import config
 from timeit import default_timer as timer
 from collections import OrderedDict
+from llm import ChatModel, EmbeddingModel
 
 SE = None
 CFIELD="content_ltks"
-EMBEDDING = HuEmbedding()
-LLM = GptTurbo()
+EMBEDDING = EmbeddingModel
+LLM = ChatModel
 
 def get_QA_pairs(hists):
     pa = []
python/svr/parse_user_docs.py CHANGED
@@ -1,4 +1,4 @@
-import json, os, sys, hashlib, copy, time, random, re, logging, torch
+import json, os, sys, hashlib, copy, time, random, re
 from os.path import dirname, realpath
 sys.path.append(dirname(realpath(__file__)) + "/../")
 from util.es_conn import HuEs
@@ -7,10 +7,10 @@ from util.minio_conn import HuMinio
 from util import rmSpace, findMaxDt
 from FlagEmbedding import FlagModel
 from nlp import huchunk, huqie, search
-import base64, hashlib
 from io import BytesIO
 import pandas as pd
 from elasticsearch_dsl import Q
+from PIL import Image
 from parser import (
     PdfParser,
     DocxParser,
@@ -40,6 +40,15 @@ def chuck_doc(name, binary):
     if suff.find("doc") >= 0: return DOC(binary)
     if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
     if suff.find("ppt") >= 0: return PPT(binary)
+    if os.environ.get("PARSE_IMAGE") \
+            and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
+                          name.lower()):
+        from llm import CvModel
+        txt = CvModel.describe(binary)
+        field = TextChunker.Fields()
+        field.text_chunks = [(txt, binary)]
+        field.table_chunks = []
+        return field
 
     return TextChunker()(binary)
 
@@ -119,7 +128,6 @@ def build(row):
         set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
         return []
 
-    print(row["doc_name"], obj)
     if not obj.text_chunks and not obj.table_chunks:
         set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
         return []
@@ -146,7 +154,10 @@ def build(row):
         if not img:
             docs.append(d)
             continue
-        img.save(output_buffer, format='JPEG')
+
+        if isinstance(img, Image.Image): img.save(output_buffer, format='JPEG')
+        else: output_buffer = BytesIO(img)
+
         MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
                   output_buffer.getvalue())
         d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
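A sketch of the new image branch in chuck_doc (assumes PARSE_IMAGE and an API key are exported so llm.CvModel is not None; the file name is hypothetical):

    import os
    os.environ["PARSE_IMAGE"] = "1"
    with open("figure.png", "rb") as f:  # hypothetical file
        fields = chuck_doc("figure.png", f.read())
    print(fields.text_chunks[0][0])      # the vision model's description of the image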
python/util/__init__.py CHANGED
@@ -1,19 +1,24 @@
 import re
 
+
 def rmSpace(txt):
     txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
     return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
 
+
 def findMaxDt(fnm):
     m = "1970-01-01 00:00:00"
     try:
         with open(fnm, "r") as f:
             while True:
                 l = f.readline()
-                if not l:break
+                if not l:
+                    break
                 l = l.strip("\n")
-                if l == 'nan':continue
-                if l > m:m = l
+                if l == 'nan':
+                    continue
+                if l > m:
+                    m = l
     except Exception as e:
-        print("WARNING: can't find "+ fnm)
+        print("WARNING: can't find " + fnm)
     return m
python/util/config.py CHANGED
@@ -1,25 +1,31 @@
-from configparser import ConfigParser
-import os,inspect
+from configparser import ConfigParser
+import os
+import inspect
 
 CF = ConfigParser()
 __fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
-if not os.path.exists(__fnm):__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
-assert os.path.exists(__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
-if not os.path.exists(__fnm): __fnm = "./sys.cnf"
+if not os.path.exists(__fnm):
+    __fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
+assert os.path.exists(
+    __fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
+if not os.path.exists(__fnm):
+    __fnm = "./sys.cnf"
 
 CF.read(__fnm)
 
+
 class Config:
     def __init__(self, env):
         self.env = env
-        if env == "spark":CF.read("./cv.cnf")
+        if env == "spark":
+            CF.read("./cv.cnf")
 
     def get(self, key, default=None):
         global CF
-        return os.environ.get(key.upper(), \
-            CF[self.env].get(key, default)
-        )
+        return os.environ.get(key.upper(),
+                              CF[self.env].get(key, default)
+                              )
+
 
 def init(env):
     return Config(env)
-
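Config.get checks the process environment before sys.cnf, so a key like postgres_user can be overridden by exporting POSTGRES_USER; a short sketch:

    from util import config

    CFG = config.init("infiniflow")
    # An exported POSTGRES_USER env var wins over the sys.cnf value.
    print(CFG.get("postgres_user", "root"))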
 
python/util/db_conn.py CHANGED
@@ -3,6 +3,7 @@ import time
 from util import config
 import pandas as pd
 
+
 class Postgres(object):
     def __init__(self, env, dbnm):
         self.config = config.init(env)
@@ -13,36 +14,42 @@ class Postgres(object):
     def __open__(self):
         import psycopg2
         try:
-            if self.conn:self.__close__()
+            if self.conn:
+                self.__close__()
             del self.conn
         except Exception as e:
             pass
 
         try:
-            self.conn = psycopg2.connect(f"dbname={self.dbnm} user={self.config.get('pgdb_usr')} password={self.config.get('pgdb_pwd')} host={self.config.get('pgdb_host')} port={self.config.get('pgdb_port')}")
+            self.conn = psycopg2.connect(f"""dbname={self.dbnm}
+                                         user={self.config.get('postgres_user')}
+                                         password={self.config.get('postgres_password')}
+                                         host={self.config.get('postgres_host')}
+                                         port={self.config.get('postgres_port')}""")
         except Exception as e:
-            logging.error("Fail to connect %s "%self.config.get("pgdb_host") + str(e))
-
+            logging.error(
+                "Fail to connect %s " %
+                self.config.get("postgres_host") + str(e))
 
     def __close__(self):
         try:
             self.conn.close()
         except Exception as e:
-            logging.error("Fail to close %s "%self.config.get("pgdb_host") + str(e))
-
+            logging.error(
+                "Fail to close %s " %
+                self.config.get("postgres_host") + str(e))
 
     def select(self, sql):
         for _ in range(10):
             try:
                 return pd.read_sql(sql, self.conn)
             except Exception as e:
-                logging.error(f"Fail to exec {sql} "+str(e))
+                logging.error(f"Fail to exec {sql} " + str(e))
                 self.__open__()
                 time.sleep(1)
 
         return pd.DataFrame()
 
-
     def update(self, sql):
         for _ in range(10):
             try:
@@ -53,11 +60,11 @@ class Postgres(object):
                 cur.close()
                 return updated_rows
             except Exception as e:
-                logging.error(f"Fail to exec {sql} "+str(e))
+                logging.error(f"Fail to exec {sql} " + str(e))
                 self.__open__()
                 time.sleep(1)
         return 0
 
+
 if __name__ == "__main__":
     Postgres("infiniflow", "docgpt")
-
 
python/util/es_conn.py CHANGED
@@ -228,7 +228,8 @@ class HuEs:
         return False
 
     def search(self, q, idxnm=None, src=False, timeout="2s"):
-        if not isinstance(q, dict): q = Search().query(q).to_dict()
+        if not isinstance(q, dict):
+            q = Search().query(q).to_dict()
         for i in range(3):
             try:
                 res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
@@ -274,9 +275,10 @@ class HuEs:
 
         return False
 
-
     def updateScriptByQuery(self, q, scripts, idxnm=None):
-        ubq = UpdateByQuery(index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
+        ubq = UpdateByQuery(
+            index=self.idxnm if not idxnm else idxnm).using(
+            self.es).query(q)
         ubq = ubq.script(source=scripts)
         ubq = ubq.params(refresh=True)
         ubq = ubq.params(slices=5)
@@ -294,7 +296,6 @@ class HuEs:
 
         return False
 
-
     def deleteByQuery(self, query, idxnm=""):
         for i in range(3):
             try:
@@ -392,7 +393,7 @@ class HuEs:
         return rr
 
     def scrollIter(self, pagesize=100, scroll_time='2m', q={
-        "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
+            "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
         for _ in range(100):
             try:
                 page = self.es.search(
python/util/minio_conn.py CHANGED
@@ -4,6 +4,7 @@ from util import config
 from minio import Minio
 from io import BytesIO
 
+
 class HuMinio(object):
     def __init__(self, env):
         self.config = config.init(env)
@@ -12,64 +13,62 @@ class HuMinio(object):
 
     def __open__(self):
         try:
-            if self.conn:self.__close__()
+            if self.conn:
+                self.__close__()
         except Exception as e:
             pass
 
         try:
             self.conn = Minio(self.config.get("minio_host"),
-                              access_key=self.config.get("minio_usr"),
-                              secret_key=self.config.get("minio_pwd"),
+                              access_key=self.config.get("minio_user"),
+                              secret_key=self.config.get("minio_password"),
                               secure=False
-            )
+                              )
         except Exception as e:
-            logging.error("Fail to connect %s "%self.config.get("minio_host") + str(e))
-
+            logging.error(
+                "Fail to connect %s " %
+                self.config.get("minio_host") + str(e))
 
     def __close__(self):
         del self.conn
         self.conn = None
 
-
     def put(self, bucket, fnm, binary):
         for _ in range(10):
             try:
                 if not self.conn.bucket_exists(bucket):
                     self.conn.make_bucket(bucket)
 
-                r = self.conn.put_object(bucket, fnm,
+                r = self.conn.put_object(bucket, fnm,
                                          BytesIO(binary),
                                          len(binary)
-                )
+                                         )
                 return r
             except Exception as e:
-                logging.error(f"Fail put {bucket}/{fnm}: "+str(e))
+                logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
 
-
     def get(self, bucket, fnm):
         for _ in range(10):
             try:
                 r = self.conn.get_object(bucket, fnm)
                 return r.read()
             except Exception as e:
-                logging.error(f"fail get {bucket}/{fnm}: "+str(e))
+                logging.error(f"fail get {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
-        return
-
+        return
 
     def get_presigned_url(self, bucket, fnm, expires):
         for _ in range(10):
             try:
                 return self.conn.get_presigned_url("GET", bucket, fnm, expires)
             except Exception as e:
-                logging.error(f"fail get {bucket}/{fnm}: "+str(e))
+                logging.error(f"fail get {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
-        return
-
+        return
 
 
 if __name__ == "__main__":
@@ -78,9 +77,8 @@ if __name__ == "__main__":
     from PIL import Image
     img = Image.open(fnm)
     buff = BytesIO()
-    img.save(buff, format='JPEG')
+    img.save(buff, format='JPEG')
     print(conn.put("test", "11-408.jpg", buff.getvalue()))
     bts = conn.get("test", "11-408.jpg")
     img = Image.open(BytesIO(bts))
     img.save("test.jpg")
-