liuhua
liuhua
commited on
Commit
·
e6abe77
1
Parent(s):
7362294
SparkTTS (#2535)
Browse files### What problem does this PR solve?
SparkTTS
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
---------
Co-authored-by: liuhua <[email protected]>
- api/apps/llm_app.py +4 -1
- rag/llm/__init__.py +2 -1
- rag/llm/tts_model.py +118 -6
- requirements.txt +2 -0
- web/src/locales/en.ts +6 -0
- web/src/locales/zh-traditional.ts +6 -0
- web/src/locales/zh.ts +6 -0
- web/src/pages/user-setting/setting-model/spark-modal/index.tsx +56 -14
api/apps/llm_app.py
CHANGED
@@ -161,7 +161,10 @@ def add_llm():
|
|
161 |
|
162 |
elif factory =="XunFei Spark":
|
163 |
llm_name = req["llm_name"]
|
164 |
-
|
|
|
|
|
|
|
165 |
|
166 |
elif factory == "BaiduYiyan":
|
167 |
llm_name = req["llm_name"]
|
|
|
161 |
|
162 |
elif factory =="XunFei Spark":
|
163 |
llm_name = req["llm_name"]
|
164 |
+
if req["model_type"] == "chat":
|
165 |
+
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
|
166 |
+
elif req["model_type"] == "tts":
|
167 |
+
api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])
|
168 |
|
169 |
elif factory == "BaiduYiyan":
|
170 |
llm_name = req["llm_name"]
|
rag/llm/__init__.py
CHANGED
@@ -139,5 +139,6 @@ Seq2txtModel = {
|
|
139 |
TTSModel = {
|
140 |
"Fish Audio": FishAudioTTS,
|
141 |
"Tongyi-Qianwen": QwenTTS,
|
142 |
-
"OpenAI":OpenAITTS
|
|
|
143 |
}
|
|
|
139 |
TTSModel = {
|
140 |
"Fish Audio": FishAudioTTS,
|
141 |
"Tongyi-Qianwen": QwenTTS,
|
142 |
+
"OpenAI":OpenAITTS,
|
143 |
+
"XunFei Spark":SparkTTS
|
144 |
}
|
rag/llm/tts_model.py
CHANGED
@@ -14,16 +14,30 @@
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
|
17 |
-
import
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
from abc import ABC
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
import httpx
|
21 |
import ormsgpack
|
|
|
|
|
22 |
from pydantic import BaseModel, conint
|
|
|
23 |
from rag.utils import num_tokens_from_string
|
24 |
-
import json
|
25 |
-
import re
|
26 |
-
import time
|
27 |
|
28 |
|
29 |
class ServeReferenceAudio(BaseModel):
|
@@ -161,7 +175,7 @@ class QwenTTS(Base):
|
|
161 |
|
162 |
class OpenAITTS(Base):
|
163 |
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
|
164 |
-
if not base_url: base_url="https://api.openai.com/v1"
|
165 |
self.api_key = key
|
166 |
self.model_name = model_name
|
167 |
self.base_url = base_url
|
@@ -185,3 +199,101 @@ class OpenAITTS(Base):
|
|
185 |
for chunk in response.iter_content():
|
186 |
if chunk:
|
187 |
yield chunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
|
17 |
+
import _thread as thread
|
18 |
+
import base64
|
19 |
+
import datetime
|
20 |
+
import hashlib
|
21 |
+
import hmac
|
22 |
+
import json
|
23 |
+
import queue
|
24 |
+
import re
|
25 |
+
import ssl
|
26 |
+
import time
|
27 |
from abc import ABC
|
28 |
+
from datetime import datetime
|
29 |
+
from time import mktime
|
30 |
+
from typing import Annotated, Literal
|
31 |
+
from urllib.parse import urlencode
|
32 |
+
from wsgiref.handlers import format_date_time
|
33 |
+
|
34 |
import httpx
|
35 |
import ormsgpack
|
36 |
+
import requests
|
37 |
+
import websocket
|
38 |
from pydantic import BaseModel, conint
|
39 |
+
|
40 |
from rag.utils import num_tokens_from_string
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
class ServeReferenceAudio(BaseModel):
|
|
|
175 |
|
176 |
class OpenAITTS(Base):
|
177 |
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
|
178 |
+
if not base_url: base_url = "https://api.openai.com/v1"
|
179 |
self.api_key = key
|
180 |
self.model_name = model_name
|
181 |
self.base_url = base_url
|
|
|
199 |
for chunk in response.iter_content():
|
200 |
if chunk:
|
201 |
yield chunk
|
202 |
+
|
203 |
+
|
204 |
+
class SparkTTS:
|
205 |
+
STATUS_FIRST_FRAME = 0
|
206 |
+
STATUS_CONTINUE_FRAME = 1
|
207 |
+
STATUS_LAST_FRAME = 2
|
208 |
+
|
209 |
+
def __init__(self, key, model_name, base_url=""):
|
210 |
+
key = json.loads(key)
|
211 |
+
self.APPID = key.get("spark_app_id", "xxxxxxx")
|
212 |
+
self.APISecret = key.get("spark_api_secret", "xxxxxxx")
|
213 |
+
self.APIKey = key.get("spark_api_key", "xxxxxx")
|
214 |
+
self.model_name = model_name
|
215 |
+
self.CommonArgs = {"app_id": self.APPID}
|
216 |
+
self.audio_queue = queue.Queue()
|
217 |
+
|
218 |
+
# 用来存储音频数据
|
219 |
+
|
220 |
+
# 生成url
|
221 |
+
def create_url(self):
|
222 |
+
url = 'wss://tts-api.xfyun.cn/v2/tts'
|
223 |
+
now = datetime.now()
|
224 |
+
date = format_date_time(mktime(now.timetuple()))
|
225 |
+
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
|
226 |
+
signature_origin += "date: " + date + "\n"
|
227 |
+
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
|
228 |
+
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
229 |
+
digestmod=hashlib.sha256).digest()
|
230 |
+
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
231 |
+
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
|
232 |
+
self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
|
233 |
+
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
234 |
+
v = {
|
235 |
+
"authorization": authorization,
|
236 |
+
"date": date,
|
237 |
+
"host": "ws-api.xfyun.cn"
|
238 |
+
}
|
239 |
+
url = url + '?' + urlencode(v)
|
240 |
+
return url
|
241 |
+
|
242 |
+
def tts(self, text):
|
243 |
+
BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
|
244 |
+
Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
|
245 |
+
CommonArgs = {"app_id": self.APPID}
|
246 |
+
audio_queue = self.audio_queue
|
247 |
+
model_name = self.model_name
|
248 |
+
|
249 |
+
class Callback:
|
250 |
+
def __init__(self):
|
251 |
+
self.audio_queue = audio_queue
|
252 |
+
|
253 |
+
def on_message(self, ws, message):
|
254 |
+
message = json.loads(message)
|
255 |
+
code = message["code"]
|
256 |
+
sid = message["sid"]
|
257 |
+
audio = message["data"]["audio"]
|
258 |
+
audio = base64.b64decode(audio)
|
259 |
+
status = message["data"]["status"]
|
260 |
+
if status == 2:
|
261 |
+
ws.close()
|
262 |
+
if code != 0:
|
263 |
+
errMsg = message["message"]
|
264 |
+
raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
|
265 |
+
else:
|
266 |
+
self.audio_queue.put(audio)
|
267 |
+
|
268 |
+
def on_error(self, ws, error):
|
269 |
+
raise Exception(error)
|
270 |
+
|
271 |
+
def on_close(self, ws, close_status_code, close_msg):
|
272 |
+
self.audio_queue.put(None) # 放入 None 作为结束标志
|
273 |
+
|
274 |
+
def on_open(self, ws):
|
275 |
+
def run(*args):
|
276 |
+
d = {"common": CommonArgs,
|
277 |
+
"business": BusinessArgs,
|
278 |
+
"data": Data}
|
279 |
+
ws.send(json.dumps(d))
|
280 |
+
|
281 |
+
thread.start_new_thread(run, ())
|
282 |
+
|
283 |
+
wsUrl = self.create_url()
|
284 |
+
websocket.enableTrace(False)
|
285 |
+
a = Callback()
|
286 |
+
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
|
287 |
+
on_message=a.on_message)
|
288 |
+
status_code = 0
|
289 |
+
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
290 |
+
while True:
|
291 |
+
audio_chunk = self.audio_queue.get()
|
292 |
+
if audio_chunk is None:
|
293 |
+
if status_code == 0:
|
294 |
+
raise Exception(
|
295 |
+
f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
|
296 |
+
else:
|
297 |
+
break
|
298 |
+
status_code = 1
|
299 |
+
yield audio_chunk
|
requirements.txt
CHANGED
@@ -94,6 +94,8 @@ vertexai==1.64.0
|
|
94 |
volcengine==1.0.146
|
95 |
voyageai==0.2.3
|
96 |
webdriver_manager==4.0.1
|
|
|
|
|
97 |
Werkzeug==3.0.3
|
98 |
wikipedia==1.4.0
|
99 |
word2number==1.1
|
|
|
94 |
volcengine==1.0.146
|
95 |
voyageai==0.2.3
|
96 |
webdriver_manager==4.0.1
|
97 |
+
websocket==0.2.1
|
98 |
+
websocket-client==1.8.0
|
99 |
Werkzeug==3.0.3
|
100 |
wikipedia==1.4.0
|
101 |
word2number==1.1
|
web/src/locales/en.ts
CHANGED
@@ -551,6 +551,12 @@ The above is the content you need to summarize.`,
|
|
551 |
SparkModelNameMessage: 'Please select Spark model',
|
552 |
addSparkAPIPassword: 'Spark APIPassword',
|
553 |
SparkAPIPasswordMessage: 'please input your APIPassword',
|
|
|
|
|
|
|
|
|
|
|
|
|
554 |
yiyanModelNameMessage: 'Please input model name',
|
555 |
addyiyanAK: 'yiyan API KEY',
|
556 |
yiyanAKMessage: 'Please input your API KEY',
|
|
|
551 |
SparkModelNameMessage: 'Please select Spark model',
|
552 |
addSparkAPIPassword: 'Spark APIPassword',
|
553 |
SparkAPIPasswordMessage: 'please input your APIPassword',
|
554 |
+
addSparkAPPID: 'Spark APPID',
|
555 |
+
SparkAPPIDMessage: 'please input your APPID',
|
556 |
+
addSparkAPISecret: 'Spark APISecret',
|
557 |
+
SparkAPISecretMessage: 'please input your APISecret',
|
558 |
+
addSparkAPIKey: 'Spark APIKey',
|
559 |
+
SparkAPIKeyMessage: 'please input your APIKey',
|
560 |
yiyanModelNameMessage: 'Please input model name',
|
561 |
addyiyanAK: 'yiyan API KEY',
|
562 |
yiyanAKMessage: 'Please input your API KEY',
|
web/src/locales/zh-traditional.ts
CHANGED
@@ -512,6 +512,12 @@ export default {
|
|
512 |
SparkModelNameMessage: '請選擇星火模型!',
|
513 |
addSparkAPIPassword: '星火 APIPassword',
|
514 |
SparkAPIPasswordMessage: '請輸入 APIPassword',
|
|
|
|
|
|
|
|
|
|
|
|
|
515 |
yiyanModelNameMessage: '輸入模型名稱',
|
516 |
addyiyanAK: '一言 API KEY',
|
517 |
yiyanAKMessage: '請輸入 API KEY',
|
|
|
512 |
SparkModelNameMessage: '請選擇星火模型!',
|
513 |
addSparkAPIPassword: '星火 APIPassword',
|
514 |
SparkAPIPasswordMessage: '請輸入 APIPassword',
|
515 |
+
addSparkAPPID: '星火 APPID',
|
516 |
+
SparkAPPIDMessage: '請輸入 APPID',
|
517 |
+
addSparkAPISecret: '星火 APISecret',
|
518 |
+
SparkAPISecretMessage: '請輸入 APISecret',
|
519 |
+
addSparkAPIKey: '星火 APIKey',
|
520 |
+
SparkAPIKeyMessage: '請輸入 APIKey',
|
521 |
yiyanModelNameMessage: '輸入模型名稱',
|
522 |
addyiyanAK: '一言 API KEY',
|
523 |
yiyanAKMessage: '請輸入 API KEY',
|
web/src/locales/zh.ts
CHANGED
@@ -529,6 +529,12 @@ export default {
|
|
529 |
SparkModelNameMessage: '请选择星火模型!',
|
530 |
addSparkAPIPassword: '星火 APIPassword',
|
531 |
SparkAPIPasswordMessage: '请输入 APIPassword',
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
yiyanModelNameMessage: '请输入模型名称',
|
533 |
addyiyanAK: '一言 API KEY',
|
534 |
yiyanAKMessage: '请输入 API KEY',
|
|
|
529 |
SparkModelNameMessage: '请选择星火模型!',
|
530 |
addSparkAPIPassword: '星火 APIPassword',
|
531 |
SparkAPIPasswordMessage: '请输入 APIPassword',
|
532 |
+
addSparkAPPID: '星火 APPID',
|
533 |
+
SparkAPPIDMessage: '请输入 APPID',
|
534 |
+
addSparkAPISecret: '星火 APISecret',
|
535 |
+
SparkAPISecretMessage: '请输入 APISecret',
|
536 |
+
addSparkAPIKey: '星火 APIKey',
|
537 |
+
SparkAPIKeyMessage: '请输入 APIKey',
|
538 |
yiyanModelNameMessage: '请输入模型名称',
|
539 |
addyiyanAK: '一言 API KEY',
|
540 |
yiyanAKMessage: '请输入 API KEY',
|
web/src/pages/user-setting/setting-model/spark-modal/index.tsx
CHANGED
@@ -7,6 +7,9 @@ import omit from 'lodash/omit';
|
|
7 |
type FieldType = IAddLlmRequestBody & {
|
8 |
vision: boolean;
|
9 |
spark_api_password: string;
|
|
|
|
|
|
|
10 |
};
|
11 |
|
12 |
const { Option } = Select;
|
@@ -63,28 +66,67 @@ const SparkModal = ({
|
|
63 |
>
|
64 |
<Select placeholder={t('modelTypeMessage')}>
|
65 |
<Option value="chat">chat</Option>
|
|
|
66 |
</Select>
|
67 |
</Form.Item>
|
68 |
<Form.Item<FieldType>
|
69 |
label={t('modelName')}
|
70 |
name="llm_name"
|
71 |
-
initialValue={'Spark-Max'}
|
72 |
rules={[{ required: true, message: t('SparkModelNameMessage') }]}
|
73 |
>
|
74 |
-
<
|
75 |
-
<Option value="Spark-Max">Spark-Max</Option>
|
76 |
-
<Option value="Spark-Lite">Spark-Lite</Option>
|
77 |
-
<Option value="Spark-Pro">Spark-Pro</Option>
|
78 |
-
<Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
|
79 |
-
<Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
|
80 |
-
</Select>
|
81 |
</Form.Item>
|
82 |
-
<Form.Item
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
</Form.Item>
|
89 |
</Form>
|
90 |
</Modal>
|
|
|
7 |
type FieldType = IAddLlmRequestBody & {
|
8 |
vision: boolean;
|
9 |
spark_api_password: string;
|
10 |
+
spark_app_id: string;
|
11 |
+
spark_api_secret: string;
|
12 |
+
spark_api_key: string;
|
13 |
};
|
14 |
|
15 |
const { Option } = Select;
|
|
|
66 |
>
|
67 |
<Select placeholder={t('modelTypeMessage')}>
|
68 |
<Option value="chat">chat</Option>
|
69 |
+
<Option value="tts">tts</Option>
|
70 |
</Select>
|
71 |
</Form.Item>
|
72 |
<Form.Item<FieldType>
|
73 |
label={t('modelName')}
|
74 |
name="llm_name"
|
|
|
75 |
rules={[{ required: true, message: t('SparkModelNameMessage') }]}
|
76 |
>
|
77 |
+
<Input placeholder={t('modelNameMessage')} />
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
</Form.Item>
|
79 |
+
<Form.Item noStyle dependencies={['model_type']}>
|
80 |
+
{({ getFieldValue }) =>
|
81 |
+
getFieldValue('model_type') === 'chat' && (
|
82 |
+
<Form.Item<FieldType>
|
83 |
+
label={t('addSparkAPIPassword')}
|
84 |
+
name="spark_api_password"
|
85 |
+
rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
|
86 |
+
>
|
87 |
+
<Input placeholder={t('SparkAPIPasswordMessage')} />
|
88 |
+
</Form.Item>
|
89 |
+
)
|
90 |
+
}
|
91 |
+
</Form.Item>
|
92 |
+
<Form.Item noStyle dependencies={['model_type']}>
|
93 |
+
{({ getFieldValue }) =>
|
94 |
+
getFieldValue('model_type') === 'tts' && (
|
95 |
+
<Form.Item<FieldType>
|
96 |
+
label={t('addSparkAPPID')}
|
97 |
+
name="spark_app_id"
|
98 |
+
rules={[{ required: true, message: t('SparkAPPIDMessage') }]}
|
99 |
+
>
|
100 |
+
<Input placeholder={t('SparkAPPIDMessage')} />
|
101 |
+
</Form.Item>
|
102 |
+
)
|
103 |
+
}
|
104 |
+
</Form.Item>
|
105 |
+
<Form.Item noStyle dependencies={['model_type']}>
|
106 |
+
{({ getFieldValue }) =>
|
107 |
+
getFieldValue('model_type') === 'tts' && (
|
108 |
+
<Form.Item<FieldType>
|
109 |
+
label={t('addSparkAPISecret')}
|
110 |
+
name="spark_api_secret"
|
111 |
+
rules={[{ required: true, message: t('SparkAPISecretMessage') }]}
|
112 |
+
>
|
113 |
+
<Input placeholder={t('SparkAPISecretMessage')} />
|
114 |
+
</Form.Item>
|
115 |
+
)
|
116 |
+
}
|
117 |
+
</Form.Item>
|
118 |
+
<Form.Item noStyle dependencies={['model_type']}>
|
119 |
+
{({ getFieldValue }) =>
|
120 |
+
getFieldValue('model_type') === 'tts' && (
|
121 |
+
<Form.Item<FieldType>
|
122 |
+
label={t('addSparkAPIKey')}
|
123 |
+
name="spark_api_key"
|
124 |
+
rules={[{ required: true, message: t('SparkAPIKeyMessage') }]}
|
125 |
+
>
|
126 |
+
<Input placeholder={t('SparkAPIKeyMessage')} />
|
127 |
+
</Form.Item>
|
128 |
+
)
|
129 |
+
}
|
130 |
</Form.Item>
|
131 |
</Form>
|
132 |
</Modal>
|