KevinHuSh commited on
Commit
e0e6518
·
1 Parent(s): 01b9866

add new model gpt-3-turbo (#352)

Browse files

### What problem does this PR solve?


Issue link:#351

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/conversation_app.py CHANGED
@@ -193,14 +193,14 @@ def chat(dialog, messages, **kwargs):
193
  embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
194
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
195
 
 
196
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
197
  # try to use sql if field mapping is good to go
198
  if field_map:
199
  chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
200
- ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
201
  if ans: return ans
202
 
203
- prompt_config = dialog.prompt_config
204
  for p in prompt_config["parameters"]:
205
  if p["key"] == "knowledge":
206
  continue
@@ -255,6 +255,7 @@ def chat(dialog, messages, **kwargs):
255
  d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
256
  if not recall_docs: recall_docs = kbinfos["doc_aggs"]
257
  kbinfos["doc_aggs"] = recall_docs
 
258
  for c in kbinfos["chunks"]:
259
  if c.get("vector"):
260
  del c["vector"]
@@ -263,7 +264,7 @@ def chat(dialog, messages, **kwargs):
263
  return {"answer": answer, "reference": kbinfos}
264
 
265
 
266
- def use_sql(question, field_map, tenant_id, chat_mdl):
267
  sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
268
  user_promt = """
269
  表名:{};
@@ -353,12 +354,16 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
353
  # compose markdown table
354
  clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
355
  tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
 
356
  line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
357
  ("|------|" if docid_idx and docid_idx else "")
 
358
  rows = ["|" +
359
  "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
360
  "|" for r in tbl["rows"]]
361
- rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
 
 
362
  rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
363
 
364
  if not docid_idx or not docnm_idx:
 
193
  embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
194
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
195
 
196
+ prompt_config = dialog.prompt_config
197
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
198
  # try to use sql if field mapping is good to go
199
  if field_map:
200
  chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
201
+ ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
202
  if ans: return ans
203
 
 
204
  for p in prompt_config["parameters"]:
205
  if p["key"] == "knowledge":
206
  continue
 
255
  d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
256
  if not recall_docs: recall_docs = kbinfos["doc_aggs"]
257
  kbinfos["doc_aggs"] = recall_docs
258
+
259
  for c in kbinfos["chunks"]:
260
  if c.get("vector"):
261
  del c["vector"]
 
264
  return {"answer": answer, "reference": kbinfos}
265
 
266
 
267
+ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
268
  sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
269
  user_promt = """
270
  表名:{};
 
354
  # compose markdown table
355
  clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
356
  tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
357
+
358
  line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
359
  ("|------|" if docid_idx and docid_idx else "")
360
+
361
  rows = ["|" +
362
  "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
363
  "|" for r in tbl["rows"]]
364
+ if quota:
365
+ rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
366
+ else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
367
  rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
368
 
369
  if not docid_idx or not docnm_idx:
api/db/init_data.py CHANGED
@@ -159,6 +159,12 @@ def init_llm_factory():
159
  "max_tokens": 8191,
160
  "model_type": LLMType.CHAT.value
161
  }, {
 
 
 
 
 
 
162
  "fid": factory_infos[0]["name"],
163
  "llm_name": "gpt-4-32k",
164
  "tags": "LLM,CHAT,32K",
 
159
  "max_tokens": 8191,
160
  "model_type": LLMType.CHAT.value
161
  }, {
162
+ "fid": factory_infos[0]["name"],
163
+ "llm_name": "gpt-4-turbo",
164
+ "tags": "LLM,CHAT,8K",
165
+ "max_tokens": 8191,
166
+ "model_type": LLMType.CHAT.value
167
+ },{
168
  "fid": factory_infos[0]["name"],
169
  "llm_name": "gpt-4-32k",
170
  "tags": "LLM,CHAT,32K",