Kevin Hu committed on
Commit
0dbe613
1 Parent(s): e92c050

fix sequence2txt error and usage total token issue (#2961)

Browse files

### What problem does this PR solve?

#1363

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

api/apps/conversation_app.py CHANGED
@@ -26,7 +26,6 @@ from api.db.services.dialog_service import DialogService, ConversationService, c
26
  from api.db.services.knowledgebase_service import KnowledgebaseService
27
  from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
28
  from api.settings import RetCode, retrievaler
29
- from api.utils import get_uuid
30
  from api.utils.api_utils import get_json_result
31
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
32
  from graphrag.mind_map_extractor import MindMapExtractor
@@ -187,6 +186,7 @@ def completion():
187
  yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
188
  ConversationService.update_by_id(conv.id, conv.to_dict())
189
  except Exception as e:
 
190
  yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
191
  "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
192
  ensure_ascii=False) + "\n\n"
 
26
  from api.db.services.knowledgebase_service import KnowledgebaseService
27
  from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
28
  from api.settings import RetCode, retrievaler
 
29
  from api.utils.api_utils import get_json_result
30
  from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
31
  from graphrag.mind_map_extractor import MindMapExtractor
 
186
  yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
187
  ConversationService.update_by_id(conv.id, conv.to_dict())
188
  except Exception as e:
189
+ traceback.print_exc()
190
  yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
191
  "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
192
  ensure_ascii=False) + "\n\n"
api/db/services/llm_service.py CHANGED
@@ -133,7 +133,8 @@ class TenantLLMService(CommonService):
133
  if model_config["llm_factory"] not in Seq2txtModel:
134
  return
135
  return Seq2txtModel[model_config["llm_factory"]](
136
- model_config["api_key"], model_config["llm_name"], lang,
 
137
  base_url=model_config["api_base"]
138
  )
139
  if llm_type == LLMType.TTS:
 
133
  if model_config["llm_factory"] not in Seq2txtModel:
134
  return
135
  return Seq2txtModel[model_config["llm_factory"]](
136
+ key=model_config["api_key"], model_name=model_config["llm_name"],
137
+ lang=lang,
138
  base_url=model_config["api_base"]
139
  )
140
  if llm_type == LLMType.TTS:
api/utils/file_utils.py CHANGED
@@ -197,6 +197,7 @@ def thumbnail_img(filename, blob):
197
  pass
198
  return None
199
 
 
200
  def thumbnail(filename, blob):
201
  img = thumbnail_img(filename, blob)
202
  if img is not None:
@@ -205,6 +206,7 @@ def thumbnail(filename, blob):
205
  else:
206
  return ''
207
 
 
208
  def traversal_files(base):
209
  for root, ds, fs in os.walk(base):
210
  for f in fs:
 
197
  pass
198
  return None
199
 
200
+
201
  def thumbnail(filename, blob):
202
  img = thumbnail_img(filename, blob)
203
  if img is not None:
 
206
  else:
207
  return ''
208
 
209
+
210
  def traversal_files(base):
211
  for root, ds, fs in os.walk(base):
212
  for f in fs:
rag/llm/chat_model.py CHANGED
@@ -67,14 +67,16 @@ class Base(ABC):
67
  if not resp.choices[0].delta.content:
68
  resp.choices[0].delta.content = ""
69
  ans += resp.choices[0].delta.content
70
- total_tokens = (
71
- (
72
- total_tokens
73
- + num_tokens_from_string(resp.choices[0].delta.content)
74
- )
75
- if not hasattr(resp, "usage") or not resp.usage
76
- else resp.usage.get("total_tokens", total_tokens)
77
- )
 
 
78
  if resp.choices[0].finish_reason == "length":
79
  ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
80
  [ans]) else "······\n由于长度的原因，回答被截断了，请继续吗？"
 
67
  if not resp.choices[0].delta.content:
68
  resp.choices[0].delta.content = ""
69
  ans += resp.choices[0].delta.content
70
+ total_tokens += 1
71
+ if not hasattr(resp, "usage") or not resp.usage:
72
+ total_tokens = (
73
+ total_tokens
74
+ + num_tokens_from_string(resp.choices[0].delta.content)
75
+ )
76
+ elif isinstance(resp.usage, dict):
77
+ total_tokens = resp.usage.get("total_tokens", total_tokens)
78
+ else: total_tokens = resp.usage.total_tokens
79
+
80
  if resp.choices[0].finish_reason == "length":
81
  ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
82
  [ans]) else "······\n由于长度的原因，回答被截断了，请继续吗？"
rag/llm/sequence2txt_model.py CHANGED
@@ -87,7 +87,7 @@ class AzureSeq2txt(Base):
87
 
88
 
89
  class XinferenceSeq2txt(Base):
90
- def __init__(self,key,model_name="whisper-small",**kwargs):
91
  self.base_url = kwargs.get('base_url', None)
92
  self.model_name = model_name
93
  self.key = key
 
87
 
88
 
89
  class XinferenceSeq2txt(Base):
90
+ def __init__(self, key, model_name="whisper-small", **kwargs):
91
  self.base_url = kwargs.get('base_url', None)
92
  self.model_name = model_name
93
  self.key = key