martin committed
Commit b007bca · 1 Parent(s): 930f36f

update app
.gitattributes CHANGED
@@ -5,3 +5,8 @@ assets/assistant.png filter=lfs diff=lfs merge=lfs -text
  speakers/闫雨婷_prompt.wav filter=lfs diff=lfs merge=lfs -text
  speakers/闫雨婷RAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
  speakers/闫雨婷VOCAL_prompt.wav filter=lfs diff=lfs merge=lfs -text
+ speakers/Tingting_prompt.wav filter=lfs diff=lfs merge=lfs -text
+ speakers/TingtingRAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
+ speakers/TingtingVOCAL_prompt.wav filter=lfs diff=lfs merge=lfs -text
+ assets/yuewen.jpeg filter=lfs diff=lfs merge=lfs -text
+ assets/request_rap_zh.wav filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,61 +1,97 @@
  import gradio as gr
  import time
  from pathlib import Path
- import torchaudio
- from stepaudio import StepAudio

- from funasr import AutoModel
- from funasr.utils.postprocess_utils import rich_transcription_postprocess

  CACHE_DIR = "/tmp/gradio/"
- system_promtp = {"role": "system", "content": "适配用户的语言,用简短口语化的文字回答"}


- class CustomAsr:
-     def __init__(self, model_name="iic/SenseVoiceSmall", device="cuda"):
-         self.model = AutoModel(
-             model=model_name,
-             vad_model="fsmn-vad",
-             vad_kwargs={"max_single_segment_time": 30000},
-             device=device,
-         )

-     def run(self, audio_path):
-         res = self.model.generate(
-             input=audio_path,
-             cache={},
-             language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
-             use_itn=True,
-             batch_size_s=60,
-             merge_vad=True,  #
-             merge_length_s=15,
-         )
-         text = rich_transcription_postprocess(res[0]["text"])
-         return text


  def add_message(chatbot, history, mic, text):
      if not mic and not text:
          return chatbot, history, "Input is empty"

      if text:
          chatbot.append({"role": "user", "content": text})
-         history.append({"role": "user", "content": text})
-     elif mic and Path(mic).exists():
          chatbot.append({"role": "user", "content": {"path": mic}})
-         history.append({"role": "user", "content": {"type": "audio", "audio": mic}})

-     print(f"{history=}")
      return chatbot, history, None


- def reset_state():
-     """Reset the chat history."""
-     return [], [system_promtp]
-
-
  def save_tmp_audio(audio, sr):
      import tempfile

      with tempfile.NamedTemporaryFile(
          dir=CACHE_DIR, delete=False, suffix=".wav"
@@ -66,90 +102,176 @@ def save_tmp_audio(audio, sr):
      return temp_audio.name


- def predict(chatbot, history, audio_model, asr_model):
      """Generate a response from the model."""
      try:
-         is_input_audio = False
-         user_audio_path = None
-         # check whether the user input is audio or text
-         if isinstance(history[-1]["content"], dict):
-             is_input_audio = True
-             user_audio_path = history[-1]["content"]["audio"]
-         text, audio, sr = audio_model(history, "闫雨婷")
          print(f"predict {text=}")
          audio_path = save_tmp_audio(audio, sr)
-         # cache the ASR text of the user audio to speed up the next inference
-         if is_input_audio:
-             asr_text = asr_model.run(user_audio_path)
-             chatbot.append({"role": "user", "content": asr_text})
-             history[-1]["content"] = asr_text
-             print(f"{asr_text=}")
-         chatbot.append({"role": "assistant", "content": {"path": audio_path}})
          chatbot.append({"role": "assistant", "content": text})
          history.append({"role": "assistant", "content": text})
      except Exception as e:
-         print(e)
-         gr.Warning(f"Some error happend, retry submit")
      return chatbot, history


- def _launch_demo(args, audio_model, asr_model):
-     with gr.Blocks(delete_cache=(86400, 86400)) as demo:
          gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
          chatbot = gr.Chatbot(
              elem_id="chatbot",
              avatar_images=["assets/user.png", "assets/assistant.png"],
              min_height=800,
              type="messages",
          )
-         # keep the chat history in state so the message format does not have to be rebuilt every turn
-         history = gr.State([system_promtp])
-         mic = gr.Audio(type="filepath")
-         text = gr.Textbox(placeholder="Enter message ...")
-
          with gr.Row():
-             clean_btn = gr.Button("🧹 Clear History (清除历史)")
              regen_btn = gr.Button("🤔️ Regenerate (重试)")
-             submit_btn = gr.Button("🚀 Submit")

-         def on_submit(chatbot, history, mic, text):
-             chatbot, history, error = add_message(
-                 chatbot, history, mic, text
              )
              if error:
-                 gr.Warning(error)  # show a warning message
                  return chatbot, history, None, None
              else:
-                 chatbot, history = predict(chatbot, history, audio_model, asr_model)
                  return chatbot, history, None, None

-         submit_btn.click(
              fn=on_submit,
-             inputs=[chatbot, history, mic, text],
              outputs=[chatbot, history, mic, text],
-             concurrency_limit=4,
-             concurrency_id="gpu_queue",
          )
-         clean_btn.click(
-             reset_state,
              outputs=[chatbot, history],
              show_progress=True,
          )

-         def regenerate(chatbot, history):
              while chatbot and chatbot[-1]["role"] == "assistant":
                  chatbot.pop()
              while history and history[-1]["role"] == "assistant":
                  print(f"discard {history[-1]}")
                  history.pop()
-             return predict(chatbot, history, audio_model, asr_model)

          regen_btn.click(
              regenerate,
-             [chatbot, history],
              [chatbot, history],
              show_progress=True,
-             concurrency_id="gpu_queue",
          )

          demo.queue().launch(
@@ -164,7 +286,7 @@ if __name__ == "__main__":
      import os

      parser = ArgumentParser()
-     parser.add_argument("--model-path", type=str, required=True, help="Model path.")
      parser.add_argument(
          "--server-port", type=int, default=7860, help="Demo server port."
      )
@@ -172,11 +294,10 @@ if __name__ == "__main__":
          "--server-name", type=str, default="0.0.0.0", help="Demo server name."
      )
      args = parser.parse_args()
-
-     audio_model = StepAudio(
-         tokenizer_path=os.path.join(args.model_path, "Step-Audio-Tokenizer"),
-         tts_path=os.path.join(args.model_path, "Step-Audio-TTS-3B"),
-         llm_path=os.path.join(args.model_path, "Step-Audio-Chat"),
      )
-     asr_model = CustomAsr()
-     _launch_demo(args, audio_model, asr_model)
 
+ import base64
+ from copy import deepcopy
  import gradio as gr
  import time
  from pathlib import Path

+ from tokenizer import StepAudioTokenizer
+ from tts import StepAudioTTS
+ from yuewen_api import call_audiochat, call_asr

  CACHE_DIR = "/tmp/gradio/"
+ CACHE_CLEAN_AGE = 864000

+ CHINESE_PROMPT_CONTENT = """你是一个为对话而设计的人工智能模型,目前无法连接到互联网。
+ 当你需要唱歌或说唱时,请以(RAP)开头。当你需要快速说话时,请以(快速)开头。当你需要慢速说话时,请以(慢速)开头。
+ 现在,你需要倾听用户的语音内容,并以礼貌、简洁、口语化的文本进行回复。你需要尽量用户的语种进行回复。"""

+ ENGLISH_PROMPT_CONTENT = """You are an AI designed for conversation, currently unable to connect to the internet.
+ when you need to sing or rap, start your response with (RAP). when you need to speak fast, you start your response with (fast). when you need to speak fast, you start your response with (slow)
+ Now, you need to listen to the user's voice content and respond with politely, concise, conversational text. Respond in accordance with the user's language."""

+ PROMPT_TEMPLATE = {
+     "English Prompt": ENGLISH_PROMPT_CONTENT,
+     "Chinese Prompt": CHINESE_PROMPT_CONTENT,
+ }
+
+
+ def make_system_prompt(prompt):
+     return {"role": "system", "content": prompt}
+
+
+ BASE_DIR = Path(__file__).resolve().parent
+ CHAT_EXAMPLES = [
+     ["Input audio only", None, f"{BASE_DIR}/assets/tell_me_a_short_story_EN.wav"],
+     [
+         "Input text only",
+         "What did the speaker mean when they said, it's raining cats and dogs?",
+         None,
+     ],
+     [
+         "Text and audio mixed input",
+         "Answer the following query in English",
+         f"{BASE_DIR}/assets/request_rap_ZH.wav",
+     ],
+ ]
+
+
+ def file_to_base64(filename):
+     with open(filename, "rb") as f:
+         file_content = f.read()
+     return base64.b64encode(file_content).decode("utf-8")
+
+
+ def get_audio_format(audio_filename):
+     extension = Path(audio_filename).suffix.lower()
+     if extension == ".mp3":
+         return "mp3"
+     elif extension == ".wav":
+         return "wav"
+     else:
+         return "other"


  def add_message(chatbot, history, mic, text):
      if not mic and not text:
          return chatbot, history, "Input is empty"

+     content = []
+
      if text:
          chatbot.append({"role": "user", "content": text})
+         content.append({"type": "text", "text": text})
+         print(f"add message {text=}")
+
+     if mic and Path(mic).exists():
          chatbot.append({"role": "user", "content": {"path": mic}})
+         content.append(
+             {
+                 "type": "input_audio",
+                 "input_audio": {
+                     "data": file_to_base64(mic),
+                     "format": get_audio_format(mic),
+                 },
+             }
+         )
+         print(f"add message {mic=}")

+     history.append({"role": "user", "content": content})
      return chatbot, history, None


  def save_tmp_audio(audio, sr):
      import tempfile
+     import torchaudio

      with tempfile.NamedTemporaryFile(
          dir=CACHE_DIR, delete=False, suffix=".wav"

      return temp_audio.name


+ def predict(chatbot, history, tts_model, user_prompt, enable_asr):
      """Generate a response from the model."""
+     start_time = time.time()
      try:
+         messages = [make_system_prompt(user_prompt)] + history
+         if enable_asr:
+             asr_text = None
+             if not isinstance(chatbot[-1]["content"], str):
+                 user_audio_path = chatbot[-1]["content"]["path"]
+                 print(f"do asr {user_audio_path=}")
+                 asr_text = call_asr(user_audio_path)
+                 print(f"{asr_text=}")
+             if asr_text:
+                 last_input = deepcopy(history[-1])
+                 last_input["content"].append({"type": "text", "text": asr_text})
+                 messages = (
+                     [make_system_prompt(user_prompt)] + history[:-1] + [last_input]
+                 )
+
+         text = call_audiochat(messages)
          print(f"predict {text=}")
+         audio, sr = tts_model(text, "Tingting")
          audio_path = save_tmp_audio(audio, sr)
+         print(f"save_tmp_audio {audio_path=}")
          chatbot.append({"role": "assistant", "content": text})
+         chatbot.append({"role": "assistant", "content": {"path": audio_path}})
          history.append({"role": "assistant", "content": text})
      except Exception as e:
+         import traceback
+
+         traceback.print_exc()
+         print(f"get a excption: {e=}")
+         gr.Warning(f"get a excption: {str(e)}, try regenerate")
+     print(f"predict cost {(time.time() - start_time):.2f}s")
      return chatbot, history


+ def _launch_demo(args, tts_model):
+     with gr.Blocks(delete_cache=(86400, CACHE_CLEAN_AGE)) as demo:
+         # keep the chat history in state so the message format does not have to be rebuilt every turn
+         history = gr.State([])
          gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
+         with gr.Row():
+             with gr.Column(scale=3):
+                 gr.Markdown(
+                     """<font size=4>This preview demonstrates core functionalities. To unlock the complete real-time voice conversation system with end-to-end encryption and advanced features, download the Yuewen APP.</font>"""
+                 )
+             with gr.Column(scale=1):
+                 gr.Image(
+                     height=180,
+                     width=180,
+                     value="assets/yuewen.jpeg",
+                     interactive=False,
+                     show_label=False,
+                     show_download_button=False,
+                     show_fullscreen_button=False,
+                 )
+         with gr.Accordion(
+             label="The performance of English prompts is not as stable as that of Chinese prompts. You can click here to change sys prompt.", open=False
+         ):
+             prompt_choice = gr.Radio(
+                 choices=list(PROMPT_TEMPLATE.keys()),
+                 value=list(PROMPT_TEMPLATE.keys())[0],
+                 label="Select prompt template",
+             )
+             user_prompt = gr.Textbox(
+                 show_label=False,
+                 value=list(PROMPT_TEMPLATE.values())[0],
+                 lines=6,
+             )
+             prompt_choice.change(
+                 fn=lambda choice: PROMPT_TEMPLATE[choice],
+                 inputs=prompt_choice,
+                 outputs=user_prompt,
+             )
          chatbot = gr.Chatbot(
              elem_id="chatbot",
              avatar_images=["assets/user.png", "assets/assistant.png"],
              min_height=800,
              type="messages",
+             show_share_button=True,
+         )
+         mic = gr.Audio(
+             label="Only MP3 and WAV formats are supported for audio uploads.",
+             sources=["microphone", "upload"],
+             type="filepath",
+         )
+         text = gr.Textbox(
+             placeholder="Enter message ...",
+             label="Input text message",
+             show_label=False,
          )
          with gr.Row():
+             enable_asr = gr.Checkbox(
+                 value=True,
+                 label="Enhance understanding capability by ASR",
+             )
+         with gr.Row():
+             clear_btn = gr.Button("🧹 Clear History (清除历史)")
              regen_btn = gr.Button("🤔️ Regenerate (重试)")
+             gen_btn = gr.Button("🚀 Generate Response")

+         example_comment = gr.Textbox(label="Input data type", visible=False)
+         example_text = gr.Textbox(label="Input text message", visible=False)
+         example_audio = gr.Audio(
+             label="Input Audio",
+             type="filepath",
+             visible=False,
+         )
+
+         def update_examples(_, text, mic, user_prompt, enable_asr):
+             chatbot = []
+             history = []
+             chatbot, history, error = add_message(chatbot, history, mic, text)
+             if error:
+                 gr.Warning(error)
+                 print(f"update_examples error")
+                 return chatbot, history
+             else:
+                 chatbot, history = predict(chatbot, history, tts_model, user_prompt, enable_asr)
+                 print(f"update_examples done")
+                 return chatbot, history
+
+         with gr.Row():
+             gr.Examples(
+                 fn=update_examples,
+                 examples=CHAT_EXAMPLES,
+                 inputs=[example_comment, example_text, example_audio, user_prompt, enable_asr],
+                 outputs=[chatbot, history],
+                 run_on_click=True,
              )
+
+         def on_submit(chatbot, history, mic, text, user_prompt, enable_asr):
+             chatbot, history, error = add_message(chatbot, history, mic, text)
              if error:
+                 gr.Warning(error)
                  return chatbot, history, None, None
              else:
+                 chatbot, history = predict(chatbot, history, tts_model, user_prompt, enable_asr)
                  return chatbot, history, None, None

+         gen_btn.click(
              fn=on_submit,
+             inputs=[chatbot, history, mic, text, user_prompt, enable_asr],
              outputs=[chatbot, history, mic, text],
+             show_progress=True,
          )
+
+         def clear_chat_history():
+             return [], []
+
+         clear_btn.click(
+             clear_chat_history,
              outputs=[chatbot, history],
              show_progress=True,
          )

+         def regenerate(chatbot, history, user_prompt, enable_asr):
              while chatbot and chatbot[-1]["role"] == "assistant":
                  chatbot.pop()
              while history and history[-1]["role"] == "assistant":
                  print(f"discard {history[-1]}")
                  history.pop()
+             return predict(chatbot, history, tts_model, user_prompt, enable_asr)

          regen_btn.click(
              regenerate,
+             [chatbot, history, user_prompt, enable_asr],
              [chatbot, history],
              show_progress=True,
          )

          demo.queue().launch(

      import os

      parser = ArgumentParser()
+     parser.add_argument("--model-path", type=str, help="Tokenizer and TTS model path.")
      parser.add_argument(
          "--server-port", type=int, default=7860, help="Demo server port."
      )

          "--server-name", type=str, default="0.0.0.0", help="Demo server name."
      )
      args = parser.parse_args()
+     tokenizer = StepAudioTokenizer(
+         os.path.join(args.model_path, "Step-Audio-Tokenizer")
+     )
+     tts_model = StepAudioTTS(
+         os.path.join(args.model_path, "Step-Audio-TTS-3B"), tokenizer
      )
+ _launch_demo(args, tts_model)
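
For reference, a minimal sketch of the message structure the updated app.py now builds: add_message() collects the user turn as a list of typed parts (text plus base64-encoded input_audio), and predict() prepends make_system_prompt(user_prompt) before calling call_audiochat() and voicing the reply with tts_model(text, "Tingting"). The snippet below is illustrative only and not part of the commit; "example.wav" is a placeholder path.

# Illustrative only: shape of the user turn appended to `history` by add_message().
user_turn = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Answer the following query in English"},
        {
            "type": "input_audio",
            "input_audio": {
                "data": "<base64 string from file_to_base64('example.wav')>",
                "format": "wav",  # from get_audio_format('example.wav')
            },
        },
    ],
}
# predict() then sends [make_system_prompt(user_prompt), user_turn] to call_audiochat().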
 
assets/request_rap_ZH.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f4507403248696260e6afb51459a39b6f8f413f2dd9f3f12568547928494c04
+ size 229476
assets/tell_me_a_short_story_EN.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:66328ecbf0b612977e4ea7661c5b7707e7183d7500e1482fc7b6e012604903a5
+ size 80144
assets/yuewen.jpeg ADDED

Git LFS Details

  • SHA256: a10d8611679dcf7c1341a46b4bb041927ef4832c20155385d653bc644e6b2f4c
  • Pointer size: 130 Bytes
  • Size of remote file: 57.7 kB
speakers/{闫雨婷RAP_prompt.wav → TingtingRAP_prompt.wav} RENAMED
File without changes
speakers/{闫雨婷VOCAL_prompt.wav → TingtingVOCAL_prompt.wav} RENAMED
File without changes
speakers/{闫雨婷_prompt.wav → Tingting_prompt.wav} RENAMED
File without changes
speakers/speakers_info.json CHANGED
@@ -1,5 +1,5 @@
  {
-     "闫雨婷RAP": "(RAP)远远甩开的笑他是陆行龟 他曾跌倒也曾吃过灰 他说有福的人才会多吃亏 他的爸爸让他小心交友可他偏偏钻进个垃圾堆 他说他明白How to play",
-     "闫雨婷VOCAL": "(哼唱)你从一座叫 我 的小镇经过 刚好屋顶的雪化成雨飘落",
-     "闫雨婷": "那等我们到海洋馆之后,给妈妈买个礼物,好不好呀?"
+     "TingtingRAP": "(RAP)远远甩开的笑他是陆行龟 他曾跌倒也曾吃过灰 他说有福的人才会多吃亏 他的爸爸让他小心交友可他偏偏钻进个垃圾堆 他说他明白How to play",
+     "TingtingVOCAL": "(哼唱)你从一座叫 我 的小镇经过 刚好屋顶的雪化成雨飘落",
+     "Tingting": "那等我们到海洋馆之后,给妈妈买个礼物,好不好呀?"
  }
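
The rename keeps each speaker key aligned with its prompt file name (speakers/&lt;name&gt;_prompt.wav). A small illustrative loader, not part of the commit, showing how the renamed entries can be enumerated:

# Illustrative only: enumerate the renamed speaker prompts committed above.
import json
from pathlib import Path

speakers_dir = Path("speakers")
with open(speakers_dir / "speakers_info.json", encoding="utf-8") as f:
    speakers = json.load(f)  # keys: Tingting, TingtingRAP, TingtingVOCAL

for name, prompt_text in speakers.items():
    prompt_wav = speakers_dir / f"{name}_prompt.wav"  # e.g. speakers/Tingting_prompt.wav
    print(name, prompt_wav.exists(), prompt_text)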
start_app.sh CHANGED
@@ -18,7 +18,6 @@ BASE_REPO_URL="https://${HF_USER_NAME}:${HF_USER_TOKEN}@huggingface.co/stepfun-a
  REPOSITORIES=(
      "Step-Audio-Tokenizer"
      "Step-Audio-TTS-3B"
-     "Step-Audio-Chat"
  )

  # Directory where the repositories are stored locally; defaults to the current directory
@@ -55,5 +54,5 @@ for repo in "${REPOSITORIES[@]}"; do
  done

  echo "所有仓库已成功下载!"
-
+ export LD_LIBRARY_PATH=/usr/local/lib/python3.10/dist-packages/nvidia/cuda_nvrtc/lib:$LD_LIBRARY_PATH
  python app.py --model $LOCAL_DIR
stepaudio.py DELETED
@@ -1,95 +0,0 @@
- import torch
- import torchaudio
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- from tokenizer import StepAudioTokenizer
- from tts import StepAudioTTS
- from utils import load_audio, speech_adjust, volumn_adjust
-
-
- class StepAudio:
-     def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str):
-         self.llm_tokenizer = AutoTokenizer.from_pretrained(
-             llm_path, trust_remote_code=True
-         )
-         self.encoder = StepAudioTokenizer(tokenizer_path)
-         self.decoder = StepAudioTTS(tts_path, self.encoder)
-         self.llm = AutoModelForCausalLM.from_pretrained(
-             llm_path,
-             torch_dtype=torch.bfloat16,
-             device_map="auto",
-             trust_remote_code=True,
-         )
-
-     def __call__(
-         self,
-         messages: list,
-         speaker_id: str,
-         speed_ratio: float = 1.0,
-         volumn_ratio: float = 1.0,
-     ):
-         text_with_audio = self.apply_chat_template(messages)
-         token_ids = self.llm_tokenizer.encode(text_with_audio, return_tensors="pt")
-         outputs = self.llm.generate(
-             token_ids, max_new_tokens=2048, temperature=0.7, top_p=0.9, do_sample=True
-         )
-         output_token_ids = outputs[:, token_ids.shape[-1] : -1].tolist()[0]
-         output_text = self.llm_tokenizer.decode(output_token_ids)
-         output_audio, sr = self.decoder(output_text, speaker_id)
-         if speed_ratio != 1.0:
-             output_audio = speech_adjust(output_audio, sr, speed_ratio)
-         if volumn_ratio != 1.0:
-             output_audio = volumn_adjust(output_audio, volumn_ratio)
-         return output_text, output_audio, sr
-
-     def encode_audio(self, audio_path):
-         audio_wav, sr = load_audio(audio_path)
-         audio_tokens = self.encoder(audio_wav, sr)
-         return audio_tokens
-
-     def apply_chat_template(self, messages: list):
-         text_with_audio = ""
-         for msg in messages:
-             role = msg["role"]
-             content = msg["content"]
-             if role == "user":
-                 role = "human"
-             if isinstance(content, str):
-                 text_with_audio += f"<|BOT|>{role}\n{content}<|EOT|>"
-             elif isinstance(content, dict):
-                 if content["type"] == "text":
-                     text_with_audio += f"<|BOT|>{role}\n{content['text']}<|EOT|>"
-                 elif content["type"] == "audio":
-                     audio_tokens = self.encode_audio(content["audio"])
-                     text_with_audio += f"<|BOT|>{role}\n{audio_tokens}<|EOT|>"
-             elif content is None:
-                 text_with_audio += f"<|BOT|>{role}\n"
-             else:
-                 raise ValueError(f"Unsupported content type: {type(content)}")
-         if not text_with_audio.endswith("<|BOT|>assistant\n"):
-             text_with_audio += "<|BOT|>assistant\n"
-         return text_with_audio
-
-
- if __name__ == "__main__":
-     model = StepAudio(
-         encoder_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-encoder",
-         decoder_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-decoder",
-         llm_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-v18",
-     )
-
-     text, audio, sr = model(
-         [{"role": "user", "content": "你好,我是你的朋友,我叫小明,你叫什么名字?"}],
-         "闫雨婷",
-     )
-     torchaudio.save("output/output_e2e_tqta.wav", audio, sr)
-     text, audio, sr = model(
-         [
-             {
-                 "role": "user",
-                 "content": {"type": "audio", "audio": "output/output_e2e_tqta.wav"},
-             }
-         ],
-         "闫雨婷",
-     )
-     torchaudio.save("output/output_e2e_aqta.wav", audio, sr)
tts.py CHANGED
@@ -37,6 +37,19 @@ class StepAudioTTS:
          model_path,
          encoder,
      ):
+         # load optimus_ths for flash attention, make sure LD_LIBRARY_PATH has `nvidia/cuda_nvrtc/lib`
+         # if not, please manually set LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib
+         try:
+             if torch.__version__ >= "2.5":
+                 torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so'))
+             elif torch.__version__ >= "2.3":
+                 torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so'))
+             elif torch.__version__ >= "2.2":
+                 torch.ops.load_library(os.path.join(model_path, 'lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so'))
+             print("Load optimus_ths successfully and flash attn would be enabled")
+         except Exception as err:
+             print(f"Fail to load optimus_ths and flash attn is disabled: {err}")
+
          self.llm = AutoModelForCausalLM.from_pretrained(
              model_path,
  torch_dtype=torch.bfloat16,
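
The comment above expects nvidia/cuda_nvrtc/lib on LD_LIBRARY_PATH, and start_app.sh hard-codes one dist-packages location. A hedged helper for other environments, assuming the pip wheel layout nvidia/cuda_nvrtc/lib under site-packages (illustrative, not part of the commit):

# Illustrative only: locate nvidia/cuda_nvrtc/lib in the current environment and
# print the export line to run before `python app.py`; prints a hint if absent.
import os
import sysconfig

lib_dir = os.path.join(sysconfig.get_paths()["purelib"], "nvidia", "cuda_nvrtc", "lib")
if os.path.isdir(lib_dir):
    print(f"export LD_LIBRARY_PATH={lib_dir}:$LD_LIBRARY_PATH")
else:
    print("nvidia/cuda_nvrtc/lib not found; set LD_LIBRARY_PATH manually")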
yuewen_api.py ADDED
@@ -0,0 +1,54 @@
+ from openai import OpenAI, APIStatusError
+ import os
+
+ AUDIO_CHAT_MODEL = os.getenv("STEP_AUDIO_CHAT", "step-1o-audio")
+ TTS_MODEL = os.getenv("STEP_AUDIO_TTS", "step-tts-mini")
+ ASR_MODEL = os.getenv("STEP_AUDIO_ASR", "step-asr")
+ STEP_BASE_URL = os.getenv("STEP_BASE_URL", "https://api.stepfun.com/v1")
+
+ client = OpenAI(
+     base_url=STEP_BASE_URL,
+ )
+
+
+ def call_audiochat(messages):
+     try:
+         completion = client.chat.completions.create(
+             model=AUDIO_CHAT_MODEL,
+             messages=messages,
+             presence_penalty=1,
+         )
+         return completion.choices[0].message.content
+     except APIStatusError as e:
+         print(e)
+         raise RuntimeError(e)
+     except Exception as e:
+         raise e
+
+
+ def call_tts(text, audio_path, voice="qinqienvsheng"):
+     response = client.audio.speech.create(model=TTS_MODEL, voice=voice, input=text)
+     response.stream_to_file(audio_path)
+     return True
+
+
+ def call_asr(audio_path):
+     with open(audio_path, "rb") as audio_file:
+         response = client.audio.transcriptions.create(
+             model=ASR_MODEL,
+             file=audio_file,
+             response_format="json",
+         )
+     return response.text
+
+ if __name__ == "__main__":
+     messages = [{"role": "user", "content": "介绍下你自己"}]
+     res = call_audiochat(messages)
+     print("call audiochat: ", res)
+
+     audio_path = "test.mp3"
+     text = "hello, 阶跃"
+     call_tts(text, audio_path)
+
+     res = call_asr(audio_path)
+ print("call asr:", res)
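
Note that OpenAI(base_url=STEP_BASE_URL) picks up its key from the OPENAI_API_KEY environment variable by default, so that variable must be set before launching the demo. A hedged variant that passes a StepFun key explicitly; the STEP_API_KEY name is an assumption, not part of this commit:

# Illustrative only: build the client with an explicit key instead of relying on
# the OPENAI_API_KEY environment variable; STEP_API_KEY is a hypothetical name.
import os
from openai import OpenAI

client = OpenAI(
    base_url=os.getenv("STEP_BASE_URL", "https://api.stepfun.com/v1"),
    api_key=os.environ["STEP_API_KEY"],  # assumed to be exported before launch
)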