chongcb chongchuanbing Kevin Hu committed on
Commit
8de8827
1 Parent(s): 13bc594

Feature/feat1017 (#2872)


### What problem does this PR solve?

1. fix: mind map rendered incorrectly in the knowledge graph, caused by a
```@antv/g6``` version change
2. feat: configurable number of concurrent threads in the graph extractor
(see the environment-variable sketch after this list)
3. fix: used-token count not updated for the tenant
4. feat: configurable timeout for LLM calls
5. fix: regex error in the graph extractor
6. feat: Qwen rerank (```gte-rerank```) support
7. fix: timeout handling in the knowledge graph indexing process; chat now
streams its output internally, and this behavior is configurable
8. feat: ```qwen-long``` model configuration
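
All of the new knobs are read from environment variables. A minimal sketch of how a deployment might override them, with the in-code defaults shown in the diffs below (setting a variable is only needed to change its default):

```python
import os

# Knobs introduced in this PR; the values shown are the in-code defaults.
os.environ["GRAPH_EXTRACTOR_MAX_WORKERS"] = "50"    # thread pool in graphrag/index.py
os.environ["MINDMAP_EXTRACTOR_MAX_WORKERS"] = "12"  # thread pool in graphrag/mind_map_extractor.py
os.environ["LM_TIMEOUT_SECONDS"] = "600"            # OpenAI-client timeout in rag/llm/chat_model.py
os.environ["QWEN_CHAT_BY_STREAM"] = "true"          # QWenChat: serve blocking chat() via the stream path
```

These must be set before the corresponding modules construct their clients or thread pools.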

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: chongchuanbing <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>

api/db/services/llm_service.py CHANGED
```diff
@@ -167,11 +167,13 @@ class TenantLLMService(CommonService):
         else:
             assert False, "LLM type error"
 
+        llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
+
         num = 0
         try:
-            for u in cls.query(tenant_id=tenant_id, llm_name=mdlnm):
+            for u in cls.query(tenant_id=tenant_id, llm_name=llm_name):
                 num += cls.model.update(used_tokens=u.used_tokens + used_tokens)\
-                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
+                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name)\
                     .execute()
         except Exception as e:
             pass
@@ -207,7 +209,7 @@ class LLMBundle(object):
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+                "Can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
         return emd, used_tokens
 
     def encode_queries(self, query: str):
@@ -215,7 +217,7 @@ class LLMBundle(object):
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+                "Can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
         return emd, used_tokens
 
     def similarity(self, query: str, texts: list):
@@ -223,7 +225,7 @@ class LLMBundle(object):
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/RERANK".format(self.tenant_id))
+                "Can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
         return sim, used_tokens
 
     def describe(self, image, max_tokens=300):
@@ -231,7 +233,7 @@ class LLMBundle(object):
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
+                "Can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
     def transcription(self, audio):
@@ -239,7 +241,7 @@ class LLMBundle(object):
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/SEQUENCE2TXT".format(self.tenant_id))
+                "Can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
     def tts(self, text):
@@ -254,10 +256,10 @@ class LLMBundle(object):
 
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
-        if not TenantLLMService.increase_usage(
+        if isinstance(txt, int) and not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             database_logger.error(
-                "Can't update token usage for {}/CHAT".format(self.tenant_id))
+                "Can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
         return txt
 
     def chat_streamly(self, system, history, gen_conf):
@@ -266,6 +268,6 @@ class LLMBundle(object):
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, txt, self.llm_name):
             database_logger.error(
-                "Can't update token usage for {}/CHAT".format(self.tenant_id))
+                "Can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
             return
         yield txt
```
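
The token-usage fix above normalizes the stored model name before the lookup. A minimal sketch of the normalization, assuming model names can carry an `@factory` qualifier (the qualified example below is hypothetical):

```python
def normalize_llm_name(mdlnm: str) -> str:
    """Strip a trailing '@factory' qualifier so the DB lookup matches the bare name."""
    return mdlnm.split("@")[0] if "@" in mdlnm else mdlnm

assert normalize_llm_name("qwen-long@Tongyi-Qianwen") == "qwen-long"  # hypothetical qualified name
assert normalize_llm_name("qwen-long") == "qwen-long"                 # already bare: unchanged
```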
conf/llm_factories.json CHANGED
```diff
@@ -89,9 +89,15 @@
     {
         "name": "Tongyi-Qianwen",
         "logo": "",
-        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+        "tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,SPEECH2TEXT,MODERATION",
         "status": "1",
         "llm": [
+            {
+                "llm_name": "qwen-long",
+                "tags": "LLM,CHAT,10000K",
+                "max_tokens": 1000000,
+                "model_type": "chat"
+            },
             {
                 "llm_name": "qwen-turbo",
                 "tags": "LLM,CHAT,8K",
@@ -139,6 +145,12 @@
                 "tags": "LLM,CHAT,IMAGE2TEXT",
                 "max_tokens": 765,
                 "model_type": "image2text"
+            },
+            {
+                "llm_name": "gte-rerank",
+                "tags": "RE-RANK,4k",
+                "max_tokens": 4000,
+                "model_type": "rerank"
             }
         ]
     },
```
graphrag/graph_extractor.py CHANGED
```diff
@@ -164,6 +164,7 @@ class GraphExtractor:
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
         gen_conf = {"temperature": 0.3}
         response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
+        if response.find("**ERROR**") >= 0: raise Exception(response)
         token_count = num_tokens_from_string(text + response)
 
         results = response or ""
```
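
The chat wrappers in this codebase signal failure by returning a string that starts with `**ERROR**` rather than raising; the extractor now converts that sentinel back into an exception, so a failed chunk aborts instead of being parsed as graph output. A minimal sketch of the pattern, with `call_llm` as a hypothetical stand-in for `self._llm.chat`:

```python
def extract_with_error_check(call_llm, prompt: str) -> str:
    """Turn the '**ERROR**' string sentinel back into a real exception."""
    response = call_llm(prompt)
    if "**ERROR**" in response:   # sentinel emitted by the chat wrappers on failure
        raise Exception(response) # propagate instead of parsing garbage
    return response
```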
graphrag/index.py CHANGED
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 from concurrent.futures import ThreadPoolExecutor
 import json
 from functools import reduce
@@ -64,7 +65,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: List[str], callback, en
     texts, graphs = [], []
     cnt = 0
     threads = []
-    exe = ThreadPoolExecutor(max_workers=50)
+    max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
+    exe = ThreadPoolExecutor(max_workers=max_workers)
     for i in range(len(chunks)):
         tkn_cnt = num_tokens_from_string(chunks[i])
         if cnt+tkn_cnt >= left_token_count and texts:
```
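
This extractor and the mind-map extractor below now size their pools the same way: an env-var override with a hard-coded default. A minimal sketch of the pattern (the `work` function is hypothetical):

```python
import os
from concurrent.futures import ThreadPoolExecutor

def work(chunk: str) -> str:
    """Hypothetical per-chunk job standing in for an LLM extraction call."""
    return chunk.upper()

max_workers = int(os.environ.get("GRAPH_EXTRACTOR_MAX_WORKERS", 50))
with ThreadPoolExecutor(max_workers=max_workers) as exe:
    results = list(exe.map(work, ["chunk-1", "chunk-2"]))
```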
graphrag/mind_map_extractor.py CHANGED
```diff
@@ -16,6 +16,7 @@
 
 import collections
 import logging
+import os
 import re
 import logging
 import traceback
@@ -89,7 +90,8 @@ class MindMapExtractor:
         prompt_variables = {}
 
         try:
-            exe = ThreadPoolExecutor(max_workers=12)
+            max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
+            exe = ThreadPoolExecutor(max_workers=max_workers)
             threads = []
             token_count = max(self._llm.max_length * 0.8, self._llm.max_length-512)
             texts = []
```
rag/llm/__init__.py CHANGED
```diff
@@ -122,7 +122,8 @@ RerankModel = {
     "TogetherAI": TogetherAIRerank,
     "SILICONFLOW": SILICONFLOWRerank,
     "BaiduYiyan": BaiduYiyanRerank,
-    "Voyage AI": VoyageRerank
+    "Voyage AI": VoyageRerank,
+    "Tongyi-Qianwen": QWenRerank,
 }
 
 Seq2txtModel = {
```
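
With `QWenRerank` registered under the `Tongyi-Qianwen` factory name, callers can resolve it from the map. A hedged sketch of a typical lookup (the key argument is a placeholder):

```python
# Look up the reranker class by factory name and instantiate it.
cls = RerankModel["Tongyi-Qianwen"]                              # -> QWenRerank
mdl = cls(key="sk-your-dashscope-key", model_name="gte-rerank")  # hypothetical API key
```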
rag/llm/chat_model.py CHANGED
```diff
@@ -31,7 +31,8 @@ import asyncio
 
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
-        self.client = OpenAI(api_key=key, base_url=base_url)
+        timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
+        self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
@@ -216,28 +217,39 @@ class QWenChat(Base):
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
-        from http import HTTPStatus
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        response = Generation.call(
-            self.model_name,
-            messages=history,
-            result_format='message',
-            **gen_conf
-        )
-        ans = ""
-        tk_count = 0
-        if response.status_code == HTTPStatus.OK:
-            ans += response.output.choices[0]['message']['content']
-            tk_count += response.usage.total_tokens
-            if response.output.choices[0].get("finish_reason", "") == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因，回答被截断了，要继续吗？"
-            return ans, tk_count
+        stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
+        if not stream_flag:
+            from http import HTTPStatus
+            if system:
+                history.insert(0, {"role": "system", "content": system})
+
+            response = Generation.call(
+                self.model_name,
+                messages=history,
+                result_format='message',
+                **gen_conf
+            )
+            ans = ""
+            tk_count = 0
+            if response.status_code == HTTPStatus.OK:
+                ans += response.output.choices[0]['message']['content']
+                tk_count += response.usage.total_tokens
+                if response.output.choices[0].get("finish_reason", "") == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因，回答被截断了，要继续吗？"
+                return ans, tk_count
 
-        return "**ERROR**: " + response.message, tk_count
+            return "**ERROR**: " + response.message, tk_count
+        else:
+            g = self._chat_streamly(system, history, gen_conf, incremental_output=True)
+            result_list = list(g)
+            error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
+            if len(error_msg_list) > 0:
+                return "**ERROR**: " + "".join(error_msg_list), 0
+            else:
+                return "".join(result_list[:-1]), result_list[-1]
 
-    def chat_streamly(self, system, history, gen_conf):
+    def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
         from http import HTTPStatus
         if system:
             history.insert(0, {"role": "system", "content": system})
@@ -249,6 +261,7 @@ class QWenChat(Base):
             messages=history,
             result_format='message',
             stream=True,
+            incremental_output=incremental_output,
             **gen_conf
         )
         for resp in response:
@@ -267,6 +280,9 @@ class QWenChat(Base):
 
         yield tk_count
 
+    def chat_streamly(self, system, history, gen_conf):
+        return self._chat_streamly(system, history, gen_conf)
+
 
 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
```
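
The new `QWenChat.chat` drives the streaming path even for blocking calls: `_chat_streamly` yields text chunks and then yields the token count as its final item, so the collector treats `result_list[:-1]` as content and `result_list[-1]` as usage. A minimal sketch of that generator protocol, with `fake_stream` as a hypothetical stand-in:

```python
from typing import Iterator, Union

def fake_stream() -> Iterator[Union[str, int]]:
    """Mimics _chat_streamly: text chunks first, total token count last."""
    yield "Hello, "
    yield "world."
    yield 7  # final item is the token count, not text

chunks = list(fake_stream())
answer, used_tokens = "".join(str(c) for c in chunks[:-1]), chunks[-1]
assert answer == "Hello, world." and used_tokens == 7
```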
rag/llm/rerank_model.py CHANGED
```diff
@@ -390,3 +390,27 @@ class VoyageRerank(Base):
         for r in res.results:
             rank[r.index] = r.relevance_score
         return rank, res.total_tokens
+
+class QWenRerank(Base):
+    def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
+        import dashscope
+        self.api_key = key
+        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
+
+    def similarity(self, query: str, texts: list):
+        import dashscope
+        from http import HTTPStatus
+        resp = dashscope.TextReRank.call(
+            api_key=self.api_key,
+            model=self.model_name,
+            query=query,
+            documents=texts,
+            top_n=len(texts),
+            return_documents=False
+        )
+        rank = np.zeros(len(texts), dtype=float)
+        if resp.status_code == HTTPStatus.OK:
+            for r in resp.output.results:
+                rank[r.index] = r.relevance_score
+            return rank, resp.usage.total_tokens
+        return rank, 0
```
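
A minimal usage sketch for the new reranker (the API key and sample documents are hypothetical; `similarity` returns one score per input document plus the tokens billed, and requires `pip install dashscope`):

```python
import numpy as np
from rag.llm.rerank_model import QWenRerank

reranker = QWenRerank(key="sk-your-dashscope-key", model_name="gte-rerank")  # hypothetical key
scores, used_tokens = reranker.similarity(
    query="What is a knowledge graph?",
    texts=["A graph of entities and relations.", "A recipe for pancakes."],
)
best = int(np.argmax(scores))  # index of the most relevant document
```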