Kevin Hu
committed on
Commit · 0dbe613
1 Parent(s): e92c050
fix sequence2txt error and usage total token issue (#2961)
### What problem does this PR solve?
#1363
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/apps/conversation_app.py +1 -1
- api/db/services/llm_service.py +2 -1
- api/utils/file_utils.py +2 -0
- rag/llm/chat_model.py +10 -8
- rag/llm/sequence2txt_model.py +1 -1
api/apps/conversation_app.py
CHANGED
@@ -26,7 +26,6 @@ from api.db.services.dialog_service import DialogService, ConversationService, c
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
 from api.settings import RetCode, retrievaler
-from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from graphrag.mind_map_extractor import MindMapExtractor
@@ -187,6 +186,7 @@ def completion():
                 yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
             ConversationService.update_by_id(conv.id, conv.to_dict())
         except Exception as e:
+            traceback.print_exc()
             yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                         "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                        ensure_ascii=False) + "\n\n"
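For context, the second hunk sits inside an SSE-style streaming generator: the client only ever receives the flattened `str(e)` in the error payload, so the added `traceback.print_exc()` is what keeps the full stack trace in the server log. A minimal, self-contained sketch of that pattern (the function name `stream_answers` is illustrative, not from the repo):

```python
import json
import traceback


def stream_answers(answers):
    # Sketch of the SSE generator this hunk touches: stream each answer chunk,
    # and on failure log the traceback server-side before yielding the error.
    try:
        for ans in answers:
            yield "data:" + json.dumps(
                {"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False
            ) + "\n\n"
    except Exception as e:
        # Without print_exc() the stack trace is lost; the client only sees str(e).
        traceback.print_exc()
        yield "data:" + json.dumps(
            {"retcode": 500, "retmsg": str(e),
             "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
            ensure_ascii=False
        ) + "\n\n"
```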
api/db/services/llm_service.py
CHANGED
@@ -133,7 +133,8 @@ class TenantLLMService(CommonService):
             if model_config["llm_factory"] not in Seq2txtModel:
                 return
             return Seq2txtModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"],
+                key=model_config["api_key"], model_name=model_config["llm_name"],
+                lang=lang,
                 base_url=model_config["api_base"]
             )
         if llm_type == LLMType.TTS:
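Why keyword arguments: the old positional call assumed every Seq2txtModel constructor takes `(key, model_name, ...)` in that exact order, which breaks for implementations with different signatures. Passing `key=`, `model_name=`, `lang=`, and `base_url=` by name keeps the single factory call compatible with all of them. A rough standalone sketch, with a made-up class and placeholder config values (not repo code):

```python
# Hypothetical Seq2txt implementation; only the parameter names mirror the diff.
class WhisperLikeSeq2txt:
    def __init__(self, key, model_name="whisper-small", lang="Chinese", **kwargs):
        self.key = key
        self.model_name = model_name
        self.lang = lang
        self.base_url = kwargs.get("base_url", None)


# Example config shaped like the model_config dict used in TenantLLMService.
model_config = {
    "api_key": "sk-xxx",
    "llm_name": "whisper-small",
    "api_base": "http://localhost:9997",
}

# Keyword arguments mirror the fixed call site, so each value reaches the right
# parameter regardless of the constructor's positional order.
mdl = WhisperLikeSeq2txt(
    key=model_config["api_key"], model_name=model_config["llm_name"],
    lang="Chinese",
    base_url=model_config["api_base"],
)
```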
api/utils/file_utils.py
CHANGED
@@ -197,6 +197,7 @@ def thumbnail_img(filename, blob):
             pass
     return None
 
+
 def thumbnail(filename, blob):
     img = thumbnail_img(filename, blob)
     if img is not None:
@@ -205,6 +206,7 @@ def thumbnail(filename, blob):
     else:
         return ''
 
+
 def traversal_files(base):
     for root, ds, fs in os.walk(base):
         for f in fs:
rag/llm/chat_model.py
CHANGED
@@ -67,14 +67,16 @@ class Base(ABC):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
-                total_tokens
-
-
-
-
-
-
-
+                total_tokens += 1
+                if not hasattr(resp, "usage") or not resp.usage:
+                    total_tokens = (
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
+                    )
+                elif isinstance(resp.usage, dict):
+                    total_tokens = resp.usage.get("total_tokens", total_tokens)
+                else: total_tokens = resp.usage.total_tokens
+
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                         [ans]) else "······\n由于长度的原因，回答被截断了，请继续吗？"
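The new lines change how streamed token usage is totalled: count one token per received chunk, estimate from the delta text when the provider reports no usage, and otherwise trust the reported `total_tokens`, whether usage arrives as a plain dict or as an SDK object. A condensed sketch of that rule (the helper name and the token-counting callable are placeholders, not repo identifiers):

```python
def accumulate_total_tokens(total_tokens, resp, count_tokens):
    # One token per received chunk, matching the `total_tokens += 1` in the diff.
    total_tokens += 1
    delta = resp.choices[0].delta.content or ""
    if not hasattr(resp, "usage") or not resp.usage:
        # Provider sent no usage info: fall back to counting the delta locally.
        total_tokens += count_tokens(delta)
    elif isinstance(resp.usage, dict):
        # Some providers return usage as a plain dict.
        total_tokens = resp.usage.get("total_tokens", total_tokens)
    else:
        # OpenAI-style SDK object exposing usage.total_tokens.
        total_tokens = resp.usage.total_tokens
    return total_tokens
```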
rag/llm/sequence2txt_model.py
CHANGED
@@ -87,7 +87,7 @@ class AzureSeq2txt(Base):
 
 
 class XinferenceSeq2txt(Base):
-    def __init__(self,key,model_name="whisper-small"
+    def __init__(self, key, model_name="whisper-small", **kwargs):
         self.base_url = kwargs.get('base_url', None)
         self.model_name = model_name
         self.key = key
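With the signature repaired, the constructor accepts the keyword call made from `TenantLLMService.model_instance`: `base_url` is read from `**kwargs` explicitly, and extras such as `lang` are absorbed rather than raising a `TypeError`. A standalone illustration of the repaired class in use (the base class is omitted and the values are placeholders):

```python
class XinferenceSeq2txt:
    def __init__(self, key, model_name="whisper-small", **kwargs):
        # **kwargs absorbs optional keywords such as lang, so the factory's
        # keyword call no longer fails on unexpected arguments.
        self.base_url = kwargs.get('base_url', None)
        self.model_name = model_name
        self.key = key


mdl = XinferenceSeq2txt(key="xxx", model_name="whisper-small",
                        lang="Chinese", base_url="http://localhost:9997")
assert mdl.base_url == "http://localhost:9997"
```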