Spaces: Running

martin committed
Commit b007bca
Parent(s): 930f36f

update app
Files changed:
- .gitattributes +5 -0
- app.py +204 -83
- assets/request_rap_ZH.wav +3 -0
- assets/tell_me_a_short_story_EN.wav +3 -0
- assets/yuewen.jpeg +3 -0
- speakers/{闫雨婷RAP_prompt.wav → TingtingRAP_prompt.wav} +0 -0
- speakers/{闫雨婷VOCAL_prompt.wav → TingtingVOCAL_prompt.wav} +0 -0
- speakers/{闫雨婷_prompt.wav → Tingting_prompt.wav} +0 -0
- speakers/speakers_info.json +3 -3
- start_app.sh +1 -2
- stepaudio.py +0 -95
- tts.py +13 -0
- yuewen_api.py +54 -0
.gitattributes
CHANGED
@@ -5,3 +5,8 @@ assets/assistant.png filter=lfs diff=lfs merge=lfs -text
 speakers/闫雨婷_prompt.wav filter=lfs diff=lfs merge=lfs -text
 speakers/闫雨婷RAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
 speakers/闫雨婷VOCAL_prompt.wav filter=lfs diff=lfs merge=lfs -text
+speakers/Tingting_prompt.wav filter=lfs diff=lfs merge=lfs -text
+speakers/TingtingRAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
+speakers/TingtingVOCAL_prompt.wav filter=lfs diff=lfs merge=lfs -text
+assets/yuewen.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/request_rap_zh.wav filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -1,61 +1,97 @@
+import base64
+from copy import deepcopy
 import gradio as gr
 import time
 from pathlib import Path
-import torchaudio
-from stepaudio import StepAudio

+from tokenizer import StepAudioTokenizer
+from tts import StepAudioTTS
+from yuewen_api import call_audiochat, call_asr

 CACHE_DIR = "/tmp/gradio/"
+CACHE_CLEAN_AGE = 864000

-        model=model_name,
-        vad_model="fsmn-vad",
-        vad_kwargs={"max_single_segment_time": 30000},
-        device=device,
-    )
+CHINESE_PROMPT_CONTENT = """你是一个为对话而设计的人工智能模型,目前无法连接到互联网。
+当你需要唱歌或说唱时,请以(RAP)开头。当你需要快速说话时,请以(快速)开头。当你需要慢速说话时,请以(慢速)开头。
+现在,你需要倾听用户的语音内容,并以礼貌、简洁、口语化的文本进行回复。你需要尽量用户的语种进行回复。"""
+
+ENGLISH_PROMPT_CONTENT = """You are an AI designed for conversation, currently unable to connect to the internet.
+when you need to sing or rap, start your response with (RAP). when you need to speak fast, you start your response with (fast). when you need to speak fast, you start your response with (slow)
+Now, you need to listen to the user's voice content and respond with politely, concise, conversational text. Respond in accordance with the user's language."""
+
+PROMPT_TEMPLATE = {
+    "English Prompt": ENGLISH_PROMPT_CONTENT,
+    "Chinese Prompt": CHINESE_PROMPT_CONTENT,
+}
+
+
+def make_system_prompt(prompt):
+    return {"role": "system", "content": prompt}
+
+
+BASE_DIR = Path(__file__).resolve().parent
+CHAT_EXAMPLES = [
+    ["Input audio only", None, f"{BASE_DIR}/assets/tell_me_a_short_story_EN.wav"],
+    [
+        "Input text only",
+        "What did the speaker mean when they said, it's raining cats and dogs?",
+        None,
+    ],
+    [
+        "Text and audio mixed input",
+        "Answer the following query in English",
+        f"{BASE_DIR}/assets/request_rap_ZH.wav",
+    ],
+]
+
+
+def file_to_base64(filename):
+    with open(filename, "rb") as f:
+        file_content = f.read()
+    return base64.b64encode(file_content).decode("utf-8")
+
+
+def get_audio_format(audio_filename):
+    extension = Path(audio_filename).suffix.lower()
+    if extension == ".mp3":
+        return "mp3"
+    elif extension == ".wav":
+        return "wav"
+    else:
+        return "other"


 def add_message(chatbot, history, mic, text):
     if not mic and not text:
         return chatbot, history, "Input is empty"

+    content = []
+
     if text:
         chatbot.append({"role": "user", "content": text})
+        content.append({"type": "text", "text": text})
+        print(f"add message {text=}")
+
+    if mic and Path(mic).exists():
         chatbot.append({"role": "user", "content": {"path": mic}})
+        content.append(
+            {
+                "type": "input_audio",
+                "input_audio": {
+                    "data": file_to_base64(mic),
+                    "format": get_audio_format(mic),
+                },
+            }
+        )
+        print(f"add message {mic=}")

+    history.append({"role": "user", "content": content})
     return chatbot, history, None


-def reset_state():
-    """Reset the chat history."""
-    return [], [system_promtp]
-
-
 def save_tmp_audio(audio, sr):
     import tempfile
+    import torchaudio

     with tempfile.NamedTemporaryFile(
         dir=CACHE_DIR, delete=False, suffix=".wav"
@@ -66,90 +102,176 @@ def save_tmp_audio(audio, sr):
     return temp_audio.name


-def predict(chatbot, history,
+def predict(chatbot, history, tts_model, user_prompt, enable_asr):
     """Generate a response from the model."""
+    start_time = time.time()
     try:
+        messages = [make_system_prompt(user_prompt)] + history
+        if enable_asr:
+            asr_text = None
+            if not isinstance(chatbot[-1]["content"], str):
+                user_audio_path = chatbot[-1]["content"]["path"]
+                print(f"do asr {user_audio_path=}")
+                asr_text = call_asr(user_audio_path)
+                print(f"{asr_text=}")
+            if asr_text:
+                last_input = deepcopy(history[-1])
+                last_input["content"].append({"type": "text", "text": asr_text})
+                messages = (
+                    [make_system_prompt(user_prompt)] + history[:-1] + [last_input]
+                )
+
+        text = call_audiochat(messages)
         print(f"predict {text=}")
+        audio, sr = tts_model(text, "Tingting")
         audio_path = save_tmp_audio(audio, sr)
-        if is_input_audio:
-            asr_text = asr_model.run(user_audio_path)
-            chatbot.append({"role": "user", "content": asr_text})
-            history[-1]["content"] = asr_text
-            print(f"{asr_text=}")
-        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
+        print(f"save_tmp_audio {audio_path=}")
         chatbot.append({"role": "assistant", "content": text})
+        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
         history.append({"role": "assistant", "content": text})
     except Exception as e:
+        import traceback
+
+        traceback.print_exc()
+        print(f"get a excption: {e=}")
+        gr.Warning(f"get a excption: {str(e)}, try regenerate")
+    print(f"predict cost {(time.time() - start_time):.2f}s")
     return chatbot, history


-def _launch_demo(args,
-    with gr.Blocks(delete_cache=(86400,
+def _launch_demo(args, tts_model):
+    with gr.Blocks(delete_cache=(86400, CACHE_CLEAN_AGE)) as demo:
+        # 保存 chat 历史,不需要每次再重新拼格式
+        history = gr.State([])
         gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
+        with gr.Row():
+            with gr.Column(scale=3):
+                gr.Markdown(
+                    """<font size=4>This preview demonstrates core functionalities. To unlock the complete real-time voice conversation system with end-to-end encryption and advanced features, download the Yuewen APP.</font>"""
+                )
+            with gr.Column(scale=1):
+                gr.Image(
+                    height=180,
+                    width=180,
+                    value="assets/yuewen.jpeg",
+                    interactive=False,
+                    show_label=False,
+                    show_download_button=False,
+                    show_fullscreen_button=False,
+                )
+        with gr.Accordion(
+            label="The performance of English prompts is not as stable as that of Chinese prompts. You can click here to change sys prompt.", open=False
+        ):
+            prompt_choice = gr.Radio(
+                choices=list(PROMPT_TEMPLATE.keys()),
+                value=list(PROMPT_TEMPLATE.keys())[0],
+                label="Select prompt template",
+            )
+            user_prompt = gr.Textbox(
+                show_label=False,
+                value=list(PROMPT_TEMPLATE.values())[0],
+                lines=6,
+            )
+            prompt_choice.change(
+                fn=lambda choice: PROMPT_TEMPLATE[choice],
+                inputs=prompt_choice,
+                outputs=user_prompt,
+            )
         chatbot = gr.Chatbot(
             elem_id="chatbot",
             avatar_images=["assets/user.png", "assets/assistant.png"],
             min_height=800,
             type="messages",
+            show_share_button=True,
         )
-        # 保存 chat 历史,不需要每次再重新拼格式
-        history = gr.State([system_promtp])
-        mic = gr.Audio(type="filepath")
-        text = gr.Textbox(placeholder="Enter message ...")
+        mic = gr.Audio(
+            label="Only MP3 and WAV formats are supported for audio uploads.",
+            sources=["microphone", "upload"],
+            type="filepath",
+        )
+        text = gr.Textbox(
+            placeholder="Enter message ...",
+            label="Input text message",
+            show_label=False,
+        )
         with gr.Row():
+            enable_asr = gr.Checkbox(
+                value=True,
+                label="Enhance understanding capability by ASR",
+            )
+        with gr.Row():
+            clear_btn = gr.Button("🧹 Clear History (清除历史)")
             regen_btn = gr.Button("🤔️ Regenerate (重试)")
+            gen_btn = gr.Button("🚀 Generate Response")

+        example_comment = gr.Textbox(label="Input data type", visible=False)
+        example_text = gr.Textbox(label="Input text message", visible=False)
+        example_audio = gr.Audio(
+            label="Input Audio",
+            type="filepath",
+            visible=False,
+        )
+
+        def update_examples(_, text, mic, user_prompt, enable_asr):
+            chatbot = []
+            history = []
+            chatbot, history, error = add_message(chatbot, history, mic, text)
+            if error:
+                gr.Warning(error)
+                print(f"update_examples error")
+                return chatbot, history
+            else:
+                chatbot, history = predict(chatbot, history, tts_model, user_prompt, enable_asr)
+                print(f"update_examples done")
+                return chatbot, history
+
+        with gr.Row():
+            gr.Examples(
+                fn=update_examples,
+                examples=CHAT_EXAMPLES,
+                inputs=[example_comment, example_text, example_audio, user_prompt, enable_asr],
+                outputs=[chatbot, history],
+                run_on_click=True,
             )
+
+        def on_submit(chatbot, history, mic, text, user_prompt, enable_asr):
+            chatbot, history, error = add_message(chatbot, history, mic, text)
             if error:
+                gr.Warning(error)
                 return chatbot, history, None, None
             else:
-                chatbot, history = predict(chatbot, history,
+                chatbot, history = predict(chatbot, history, tts_model, user_prompt, enable_asr)
                 return chatbot, history, None, None

+        gen_btn.click(
             fn=on_submit,
-            inputs=[chatbot, history, mic, text],
+            inputs=[chatbot, history, mic, text, user_prompt, enable_asr],
             outputs=[chatbot, history, mic, text],
+            show_progress=True,
-            concurrency_id="gpu_queue",
         )
+
+        def clear_chat_history():
+            return [], []
+
+        clear_btn.click(
+            clear_chat_history,
             outputs=[chatbot, history],
             show_progress=True,
         )

-        def regenerate(chatbot, history):
+        def regenerate(chatbot, history, user_prompt, enable_asr):
             while chatbot and chatbot[-1]["role"] == "assistant":
                 chatbot.pop()
             while history and history[-1]["role"] == "assistant":
                 print(f"discard {history[-1]}")
                 history.pop()
-            return predict(chatbot, history,
+            return predict(chatbot, history, tts_model, user_prompt, enable_asr)

         regen_btn.click(
             regenerate,
-            [chatbot, history],
+            [chatbot, history, user_prompt, enable_asr],
             [chatbot, history],
             show_progress=True,
-            concurrency_id="gpu_queue",
         )

         demo.queue().launch(
@@ -164,7 +286,7 @@ if __name__ == "__main__":
     import os

     parser = ArgumentParser()
-    parser.add_argument("--model-path", type=str,
+    parser.add_argument("--model-path", type=str, help="Tokenizer and TTS model path.")
     parser.add_argument(
         "--server-port", type=int, default=7860, help="Demo server port."
     )
@@ -172,11 +294,10 @@ if __name__ == "__main__":
         "--server-name", type=str, default="0.0.0.0", help="Demo server name."
     )
     args = parser.parse_args()
-    _launch_demo(args, audio_model, asr_model)
+    tokenizer = StepAudioTokenizer(
+        os.path.join(args.model_path, "Step-Audio-Tokenizer")
+    )
+    tts_model = StepAudioTTS(
+        os.path.join(args.model_path, "Step-Audio-TTS-3B"), tokenizer
+    )
+    _launch_demo(args, tts_model)
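For context on what the rewritten app.py does with these pieces: add_message now packs each user turn as a list of text and input_audio parts (the audio is base64-encoded), and predict prepends the selected system prompt before handing the message list to call_audiochat from the new yuewen_api module. The following is a rough, self-contained sketch of that flow outside of Gradio; the system prompt text and the wav path are placeholders, not values taken from the commit, and the API key environment expected by yuewen_api must already be set.

# Sketch only: build one mixed text+audio user turn in the same content format
# that add_message()/predict() assemble, then send it through call_audiochat().
# The wav path and the prompt string below are placeholders.
import base64

from yuewen_api import call_audiochat


def encode_wav(path):
    # same idea as file_to_base64() in app.py
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


messages = [
    {"role": "system", "content": "You are an AI designed for conversation."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Answer the following query in English"},
            {
                "type": "input_audio",
                "input_audio": {
                    "data": encode_wav("assets/request_rap_ZH.wav"),
                    "format": "wav",
                },
            },
        ],
    },
]

print(call_audiochat(messages))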
assets/request_rap_ZH.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f4507403248696260e6afb51459a39b6f8f413f2dd9f3f12568547928494c04
+size 229476
assets/tell_me_a_short_story_EN.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66328ecbf0b612977e4ea7661c5b7707e7183d7500e1482fc7b6e012604903a5
+size 80144
assets/yuewen.jpeg
ADDED
Binary image file (stored with Git LFS).
speakers/{闫雨婷RAP_prompt.wav → TingtingRAP_prompt.wav}
RENAMED
File without changes
speakers/{闫雨婷VOCAL_prompt.wav → TingtingVOCAL_prompt.wav}
RENAMED
File without changes
speakers/{闫雨婷_prompt.wav → Tingting_prompt.wav}
RENAMED
File without changes
speakers/speakers_info.json
CHANGED
@@ -1,5 +1,5 @@
 {
+    "TingtingRAP": "(RAP)远远甩开的笑他是陆行龟 他曾跌倒也曾吃过灰 他说有福的人才会多吃亏 他的爸爸让他小心交友可他偏偏钻进个垃圾堆 他说他明白How to play",
+    "TingtingVOCAL": "(哼唱)你从一座叫 我 的小镇经过 刚好屋顶的雪化成雨飘落",
+    "Tingting": "那等我们到海洋馆之后,给妈妈买个礼物,好不好呀?"
 }
start_app.sh
CHANGED
@@ -18,7 +18,6 @@ BASE_REPO_URL="https://${HF_USER_NAME}:${HF_USER_TOKEN}@huggingface.co/stepfun-a
 REPOSITORIES=(
     "Step-Audio-Tokenizer"
     "Step-Audio-TTS-3B"
-    "Step-Audio-Chat"
 )

 # 定义本地存放仓库的目录,默认为当前目录
@@ -55,5 +54,5 @@ for repo in "${REPOSITORIES[@]}"; do
 done

 echo "所有仓库已成功下载!"
-
+export LD_LIBRARY_PATH=/usr/local/lib/python3.10/dist-packages/nvidia/cuda_nvrtc/lib:$LD_LIBRARY_PATH
 python app.py --model $LOCAL_DIR
stepaudio.py
DELETED
@@ -1,95 +0,0 @@
-import torch
-import torchaudio
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-from tokenizer import StepAudioTokenizer
-from tts import StepAudioTTS
-from utils import load_audio, speech_adjust, volumn_adjust
-
-
-class StepAudio:
-    def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str):
-        self.llm_tokenizer = AutoTokenizer.from_pretrained(
-            llm_path, trust_remote_code=True
-        )
-        self.encoder = StepAudioTokenizer(tokenizer_path)
-        self.decoder = StepAudioTTS(tts_path, self.encoder)
-        self.llm = AutoModelForCausalLM.from_pretrained(
-            llm_path,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-
-    def __call__(
-        self,
-        messages: list,
-        speaker_id: str,
-        speed_ratio: float = 1.0,
-        volumn_ratio: float = 1.0,
-    ):
-        text_with_audio = self.apply_chat_template(messages)
-        token_ids = self.llm_tokenizer.encode(text_with_audio, return_tensors="pt")
-        outputs = self.llm.generate(
-            token_ids, max_new_tokens=2048, temperature=0.7, top_p=0.9, do_sample=True
-        )
-        output_token_ids = outputs[:, token_ids.shape[-1] : -1].tolist()[0]
-        output_text = self.llm_tokenizer.decode(output_token_ids)
-        output_audio, sr = self.decoder(output_text, speaker_id)
-        if speed_ratio != 1.0:
-            output_audio = speech_adjust(output_audio, sr, speed_ratio)
-        if volumn_ratio != 1.0:
-            output_audio = volumn_adjust(output_audio, volumn_ratio)
-        return output_text, output_audio, sr
-
-    def encode_audio(self, audio_path):
-        audio_wav, sr = load_audio(audio_path)
-        audio_tokens = self.encoder(audio_wav, sr)
-        return audio_tokens
-
-    def apply_chat_template(self, messages: list):
-        text_with_audio = ""
-        for msg in messages:
-            role = msg["role"]
-            content = msg["content"]
-            if role == "user":
-                role = "human"
-            if isinstance(content, str):
-                text_with_audio += f"<|BOT|>{role}\n{content}<|EOT|>"
-            elif isinstance(content, dict):
-                if content["type"] == "text":
-                    text_with_audio += f"<|BOT|>{role}\n{content['text']}<|EOT|>"
-                elif content["type"] == "audio":
-                    audio_tokens = self.encode_audio(content["audio"])
-                    text_with_audio += f"<|BOT|>{role}\n{audio_tokens}<|EOT|>"
-            elif content is None:
-                text_with_audio += f"<|BOT|>{role}\n"
-            else:
-                raise ValueError(f"Unsupported content type: {type(content)}")
-        if not text_with_audio.endswith("<|BOT|>assistant\n"):
-            text_with_audio += "<|BOT|>assistant\n"
-        return text_with_audio
-
-
-if __name__ == "__main__":
-    model = StepAudio(
-        encoder_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-encoder",
-        decoder_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-decoder",
-        llm_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-v18",
-    )
-
-    text, audio, sr = model(
-        [{"role": "user", "content": "你好,我是你的朋友,我叫小明,你叫什么名字?"}],
-        "闫雨婷",
-    )
-    torchaudio.save("output/output_e2e_tqta.wav", audio, sr)
-    text, audio, sr = model(
-        [
-            {
-                "role": "user",
-                "content": {"type": "audio", "audio": "output/output_e2e_tqta.wav"},
-            }
-        ],
-        "闫雨婷",
-    )
-    torchaudio.save("output/output_e2e_aqta.wav", audio, sr)
tts.py
CHANGED
@@ -37,6 +37,19 @@ class StepAudioTTS:
         model_path,
         encoder,
     ):
+        # load optimus_ths for flash attention, make sure LD_LIBRARY_PATH has `nvidia/cuda_nvrtc/lib`
+        # if not, please manually set LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib
+        try:
+            if torch.__version__ >= "2.5":
+                torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so'))
+            elif torch.__version__ >= "2.3":
+                torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so'))
+            elif torch.__version__ >= "2.2":
+                torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so'))
+            print("Load optimus_ths successfully and flash attn would be enabled")
+        except Exception as err:
+            print(f"Fail to load optimus_ths and flash attn is disabled: {err}")
+
         self.llm = AutoModelForCausalLM.from_pretrained(
             model_path,
             torch_dtype=torch.bfloat16,
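The hunk above version-gates the optional flash-attention extension: it picks the prebuilt liboptimus_ths shared object matching the installed torch version, loads it with torch.ops.load_library, and falls back to running without flash attention if loading fails, which is also why start_app.sh now puts the nvidia cuda_nvrtc lib directory on LD_LIBRARY_PATH. A standalone sketch of the same dispatch pattern; the model directory below is a placeholder, not a path from the commit.

# Sketch only: pick the prebuilt custom-op library that matches the running
# torch version and load it as a torch extension.
import os
import torch

model_path = "/path/to/Step-Audio-TTS-3B"  # placeholder
candidates = [
    ("2.5", "lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so"),
    ("2.3", "lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so"),
    ("2.2", "lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so"),
]

try:
    for min_version, rel_path in candidates:
        # string comparison mirrors the check used in the commit
        if torch.__version__ >= min_version:
            torch.ops.load_library(os.path.join(model_path, rel_path))
            print(f"loaded {rel_path} for torch {torch.__version__}")
            break
except Exception as err:
    # same fallback as tts.py: continue without flash attention
    print(f"optimus_ths not loaded, flash attn disabled: {err}")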
yuewen_api.py
ADDED
@@ -0,0 +1,54 @@
+from openai import OpenAI, APIStatusError
+import os
+
+AUDIO_CHAT_MODEL = os.getenv("STEP_AUDIO_CHAT", "step-1o-audio")
+TTS_MODEL = os.getenv("STEP_AUDIO_TTS", "step-tts-mini")
+ASR_MODEL = os.getenv("STEP_AUDIO_ASR", "step-asr")
+STEP_BASE_URL = os.getenv("STEP_BASE_URL", "https://api.stepfun.com/v1")
+
+client = OpenAI(
+    base_url=STEP_BASE_URL,
+)
+
+
+def call_audiochat(messages):
+    try:
+        completion = client.chat.completions.create(
+            model=AUDIO_CHAT_MODEL,
+            messages=messages,
+            presence_penalty=1,
+        )
+        return completion.choices[0].message.content
+    except APIStatusError as e:
+        print(e)
+        raise RuntimeError(e)
+    except Exception as e:
+        raise e
+
+
+def call_tts(text, audio_path, voice="qinqienvsheng"):
+    response = client.audio.speech.create(model=TTS_MODEL, voice=voice, input=text)
+    response.stream_to_file(audio_path)
+    return True
+
+
+def call_asr(audio_path):
+    with open(audio_path, "rb") as audio_file:
+        response = client.audio.transcriptions.create(
+            model=ASR_MODEL,
+            file=audio_file,
+            response_format="json",
+        )
+    return response.text
+
+
+if __name__ == "__main__":
+    messages = [{"role": "user", "content": "介绍下你自己"}]
+    res = call_audiochat(messages)
+    print("call audiochat: ", res)
+
+    audio_path = "test.mp3"
+    text = "hello, 阶跃"
+    call_tts(text, audio_path)
+
+    res = call_asr(audio_path)
+    print("call asr:", res)
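One note on authentication: the OpenAI client above is constructed with only base_url, so the standard openai SDK fallback applies and the API key is read from the OPENAI_API_KEY environment variable; the Space presumably provides a StepFun key there as a secret. A small usage sketch under that assumption, with the key value and output file name as placeholders.

# Sketch only: exercise the helpers above against the StepFun endpoint.
# The key value and file name are placeholders; the environment must be set
# before yuewen_api is imported, because the client is created at import time.
import os

os.environ.setdefault("OPENAI_API_KEY", "<your StepFun API key>")  # placeholder
os.environ.setdefault("STEP_BASE_URL", "https://api.stepfun.com/v1")

from yuewen_api import call_asr, call_tts

call_tts("hello, 阶跃", "demo.mp3")   # synthesize speech to demo.mp3
print(call_asr("demo.mp3"))           # transcribe it back to text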