KevinHuSh
commited on
Commit
·
e0e6518
1
Parent(s):
01b9866
add new model gpt-3-turbo (#352)
Browse files### What problem does this PR solve?
Issue link:#351
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/conversation_app.py +9 -4
- api/db/init_data.py +6 -0
api/apps/conversation_app.py
CHANGED
@@ -193,14 +193,14 @@ def chat(dialog, messages, **kwargs):
|
|
193 |
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
194 |
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
195 |
|
|
|
196 |
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
197 |
# try to use sql if field mapping is good to go
|
198 |
if field_map:
|
199 |
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
200 |
-
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
|
201 |
if ans: return ans
|
202 |
|
203 |
-
prompt_config = dialog.prompt_config
|
204 |
for p in prompt_config["parameters"]:
|
205 |
if p["key"] == "knowledge":
|
206 |
continue
|
@@ -255,6 +255,7 @@ def chat(dialog, messages, **kwargs):
|
|
255 |
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
256 |
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
257 |
kbinfos["doc_aggs"] = recall_docs
|
|
|
258 |
for c in kbinfos["chunks"]:
|
259 |
if c.get("vector"):
|
260 |
del c["vector"]
|
@@ -263,7 +264,7 @@ def chat(dialog, messages, **kwargs):
|
|
263 |
return {"answer": answer, "reference": kbinfos}
|
264 |
|
265 |
|
266 |
-
def use_sql(question, field_map, tenant_id, chat_mdl):
|
267 |
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
|
268 |
user_promt = """
|
269 |
表名:{};
|
@@ -353,12 +354,16 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
|
353 |
# compose markdown table
|
354 |
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
355 |
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
|
|
356 |
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
357 |
("|------|" if docid_idx and docid_idx else "")
|
|
|
358 |
rows = ["|" +
|
359 |
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
360 |
"|" for r in tbl["rows"]]
|
361 |
-
|
|
|
|
|
362 |
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
363 |
|
364 |
if not docid_idx or not docnm_idx:
|
|
|
193 |
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
194 |
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
195 |
|
196 |
+
prompt_config = dialog.prompt_config
|
197 |
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
198 |
# try to use sql if field mapping is good to go
|
199 |
if field_map:
|
200 |
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
201 |
+
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
|
202 |
if ans: return ans
|
203 |
|
|
|
204 |
for p in prompt_config["parameters"]:
|
205 |
if p["key"] == "knowledge":
|
206 |
continue
|
|
|
255 |
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
256 |
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
257 |
kbinfos["doc_aggs"] = recall_docs
|
258 |
+
|
259 |
for c in kbinfos["chunks"]:
|
260 |
if c.get("vector"):
|
261 |
del c["vector"]
|
|
|
264 |
return {"answer": answer, "reference": kbinfos}
|
265 |
|
266 |
|
267 |
+
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
268 |
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
|
269 |
user_promt = """
|
270 |
表名:{};
|
|
|
354 |
# compose markdown table
|
355 |
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
356 |
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
357 |
+
|
358 |
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
359 |
("|------|" if docid_idx and docid_idx else "")
|
360 |
+
|
361 |
rows = ["|" +
|
362 |
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
363 |
"|" for r in tbl["rows"]]
|
364 |
+
if quota:
|
365 |
+
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
366 |
+
else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
367 |
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
368 |
|
369 |
if not docid_idx or not docnm_idx:
|
api/db/init_data.py
CHANGED
@@ -159,6 +159,12 @@ def init_llm_factory():
|
|
159 |
"max_tokens": 8191,
|
160 |
"model_type": LLMType.CHAT.value
|
161 |
}, {
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
"fid": factory_infos[0]["name"],
|
163 |
"llm_name": "gpt-4-32k",
|
164 |
"tags": "LLM,CHAT,32K",
|
|
|
159 |
"max_tokens": 8191,
|
160 |
"model_type": LLMType.CHAT.value
|
161 |
}, {
|
162 |
+
"fid": factory_infos[0]["name"],
|
163 |
+
"llm_name": "gpt-4-turbo",
|
164 |
+
"tags": "LLM,CHAT,8K",
|
165 |
+
"max_tokens": 8191,
|
166 |
+
"model_type": LLMType.CHAT.value
|
167 |
+
},{
|
168 |
"fid": factory_infos[0]["name"],
|
169 |
"llm_name": "gpt-4-32k",
|
170 |
"tags": "LLM,CHAT,32K",
|