Steveeeeeeen (HF Staff) committed
Commit a9a5df1 · verified · 1 Parent(s): 6daeb9a

Update app.py

Files changed (1):
  1. app.py (+85 −69)
app.py CHANGED
@@ -1,31 +1,37 @@
 import os
 import shlex
 import subprocess
-import threading
 import tempfile
 import traceback
 from pathlib import Path

+# --- Install / fetch runtime deps & assets ---
 os.system("pip install -r requirements.txt")
+
+# Download token2wav assets
 os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/campplus.onnx -P token2wav")
 os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/flow.pt -P token2wav")
 os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/flow.yaml -P token2wav")
 os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/hift.pt -P token2wav")

-# HF token passthrough
+# Hugging Face token (optional)
 hf_token = os.getenv("HF_TOKEN", None)
-os.environ["HF_TOKEN"] = hf_token
+if hf_token is not None:
+    os.environ["HF_TOKEN"] = hf_token

 import spaces
 import gradio as gr

-def save_tmp_audio(audio_bytes, cache_dir):
+def save_tmp_audio(audio_bytes: bytes, cache_dir: str) -> str:
+    """Save raw wav bytes to a temporary file and return its path."""
     os.makedirs(cache_dir, exist_ok=True)
     with tempfile.NamedTemporaryFile(dir=cache_dir, delete=False, suffix=".wav") as temp_audio:
         temp_audio.write(audio_bytes)
-    return temp_audio.name
+    return temp_audio.name
+

 def add_message(chatbot, history, mic, text):
+    """Append user text or audio to the chat and the model history."""
     if not mic and not text:
         return chatbot, history, "Input is empty"

@@ -36,49 +42,47 @@ def add_message(chatbot, history, mic, text):
         chatbot.append({"role": "user", "content": {"path": mic}})
         history.append({"role": "human", "content": [{"type": "audio", "audio": mic}]})

+    print(f"{history=}")
     return chatbot, history, None

-def reset_state(system_prompt):
+
+def reset_state(system_prompt: str):
+    """Reset the chat to a single system message."""
     return [], [{"role": "system", "content": system_prompt}]

-_AUDIO_MODEL = None
-_TOKEN2WAV = None
-_INIT_LOCK = threading.Lock()

-def _ensure_models(model_path: str, token2wav_dir: str):
+_MODEL = None
+_TOK2WAV = None
+
+def _get_models(model_path: str):
     """
-    Create heavy, non-picklable objects *inside* the worker process exactly once.
+    Lazily load heavy, non-picklable models INSIDE the worker process
+    and cache them in module globals for reuse.
     """
-    global _AUDIO_MODEL, _TOKEN2WAV
-    if _AUDIO_MODEL is None or _TOKEN2WAV is None:
-        with _INIT_LOCK:
-            if _AUDIO_MODEL is None or _TOKEN2WAV is None:
-                # Import here to avoid importing before process fork
-                from stepaudio2 import StepAudio2
-                from token2wav import Token2wav
-                # Create non-picklable instances
-                _AUDIO_MODEL = StepAudio2(model_path)
-                _TOKEN2WAV = Token2wav(token2wav_dir)
-
-    return _AUDIO_MODEL, _TOKEN2WAV
+    global _MODEL, _TOK2WAV
+    if _MODEL is None or _TOK2WAV is None:
+        # Import here so the objects are constructed in the worker
+        from stepaudio2 import StepAudio2
+        from token2wav import Token2wav
+        _MODEL = StepAudio2(model_path)
+        _TOK2WAV = Token2wav("token2wav")
+    return _MODEL, _TOK2WAV

 @spaces.GPU
-def predict(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir):
+def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mini"):
     """
-    IMPORTANT: All parameters are simple strings/lists (picklable).
-    Heavy objects are created inside via _ensure_models(...).
+    Run generation on the GPU worker. All args must be picklable (strings, lists, dicts).
+    Heavy models are created via _get_models() inside this process.
     """
     try:
-        audio_model, token2wav = _ensure_models(model_path, token2wav_dir)
+        audio_model, token2wav = _get_models(model_path)

-        # Stream start marker
         history.append({
             "role": "assistant",
             "content": [{"type": "text", "text": "<tts_start>"}],
             "eot": False
         })

-        # Your original generation call
         tokens, text, audio_tokens = audio_model(
             history,
             max_new_tokens=4096,
@@ -86,21 +90,23 @@ def predict(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir):
             repetition_penalty=1.05,
             do_sample=True
         )
+        print(f"predict text={text!r}")

-        # Convert tokens -> wav bytes
+        # Convert tokens -> waveform bytes using token2wav
         audio_bytes = token2wav(audio_tokens, prompt_wav)

-        # Save to temp file for gradio Chatbot
+        # Persist to a temp .wav for the UI
         audio_path = save_tmp_audio(audio_bytes, cache_dir)
-        chatbot.append({"role": "assistant", "content": {"path": audio_path}})

-        # Finish the assistant turn
+        # Append the assistant audio message
+        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
         history[-1]["content"].append({"type": "token", "token": tokens})
         history[-1]["eot"] = True

     except Exception:
         print(traceback.format_exc())
-        gr.Warning("Some error happened, please try again.")
+        gr.Warning("Some error happened, please try again.")
+
     return chatbot, history

 def _launch_demo(args):
@@ -112,71 +118,85 @@ def _launch_demo(args):
             label="System Prompt",
             value=(
                 "你的名字叫做小跃,是由阶跃星辰公司训练出来的语音大模型。\n"
-                "你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n"
+                "你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,"
+                "说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n"
                 "今天是2025年8月29日,星期五\n"
                 "请用默认女声与用户交流。"
             ),
-            lines=2
+            lines=2,
         )

-        chatbot = gr.Chatbot(elem_id="chatbot", min_height=800, type="messages")
-        # Initialize history with the *string* value of the prompt
-        history = gr.State([{"role": "system", "content": system_prompt.value}])
+        chatbot = gr.Chatbot(
+            elem_id="chatbot",
+            min_height=800,
+            type="messages",
+        )

-        # Inputs
-        mic = gr.Audio(type="filepath")
-        text = gr.Textbox(placeholder="Enter message ...")
+        # Initialize history with the current system prompt value
+        history = gr.State([{"role": "system", "content": system_prompt.value}])

-        # Serializable configuration inputs (STRINGS ONLY)
-        model_path = "Step-Audio-2-mini"
-        token2wav_dir = "token2wav"
-        prompt_wav = "assets/default_female.wav"
-        cache_dir = "/tmp/stepaudio2"
+        mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
+        text = gr.Textbox(placeholder="Enter message ...", label="💬 Text")

         with gr.Row():
             clean_btn = gr.Button("🧹 Clear History (清除历史)")
             regen_btn = gr.Button("🤔️ Regenerate (重试)")
             submit_btn = gr.Button("🚀 Submit")

-        # --- event functions (now only use serializable args) ---
-        def on_submit(chatbot, history, mic, text, prompt_wav, cache_dir, model_path, token2wav_dir):
-            chatbot, history, error = add_message(chatbot, history, mic, text)
+        def on_submit(chatbot_val, history_val, mic_val, text_val):
+            chatbot2, history2, error = add_message(chatbot_val, history_val, mic_val, text_val)
             if error:
                 gr.Warning(error)
-                return chatbot, history, None, None
-            chatbot, history = predict(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir)
-            return chatbot, history, None, None
+                return chatbot2, history2, None, None
+            # Run GPU inference with only picklable args
+            chatbot2, history2 = predict(
+                chatbot2, history2,
+                args.prompt_wav, args.cache_dir,
+                model_path=args.model_path
+            )
+            return chatbot2, history2, None, None

         submit_btn.click(
             fn=on_submit,
-            inputs=[chatbot, history, mic, text, prompt_wav, cache_dir, model_path, token2wav_dir],
+            inputs=[chatbot, history, mic, text],
             outputs=[chatbot, history, mic, text],
             concurrency_limit=4,
             concurrency_id="gpu_queue",
         )

+        def on_clean(system_prompt_text):
+            return reset_state(system_prompt_text)
+
         clean_btn.click(
-            fn=reset_state,
+            fn=on_clean,
             inputs=[system_prompt],
             outputs=[chatbot, history],
         )

-        def on_regen(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir):
-            # drop last assistant turn so we can re-run
-            while chatbot and chatbot[-1]["role"] == "assistant":
-                chatbot.pop()
-            while history and history[-1]["role"] == "assistant":
-                history.pop()
-            return predict(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir)
+        def on_regenerate(chatbot_val, history_val):
+            # Drop the last assistant turn(s) so the reply can be regenerated
+            while chatbot_val and chatbot_val[-1]["role"] == "assistant":
+                chatbot_val.pop()
+            while history_val and history_val[-1]["role"] == "assistant":
+                print(f"discard {history_val[-1]}")
+                history_val.pop()
+            return predict(
+                chatbot_val, history_val,
+                args.prompt_wav, args.cache_dir,
+                model_path=args.model_path
+            )

         regen_btn.click(
-            fn=on_regen,
-            inputs=[chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir],
+            fn=on_regenerate,
+            inputs=[chatbot, history],
             outputs=[chatbot, history],
             concurrency_id="gpu_queue",
         )

-    demo.queue().launch(server_port=args.server_port, server_name=args.server_name)
+    demo.queue().launch(
+        server_port=args.server_port,
+        server_name=args.server_name,
+    )

 if __name__ == "__main__":
     from argparse import ArgumentParser
@@ -190,8 +210,4 @@ if __name__ == "__main__":
     args = parser.parse_args()

     os.environ["GRADIO_TEMP_DIR"] = args.cache_dir
-    os.makedirs(args.cache_dir, exist_ok=True)
-
-    # NOTE: Do NOT instantiate heavy models here.
-    # They will be created lazily inside predict() via _ensure_models(...).
  _launch_demo(args)
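
The key pattern in this commit is _get_models(): the heavy, non-picklable objects (StepAudio2, Token2wav) are built lazily inside the worker process and cached in module globals, while the @spaces.GPU entry point predict() only ever receives plain strings, lists and dicts. Below is a minimal, self-contained sketch of that pattern, with a hypothetical HeavyAsset class standing in for the real models (not part of this repo):

# Sketch of the lazy per-process cache pattern used by _get_models() above.
# HeavyAsset is a hypothetical stand-in for StepAudio2 / Token2wav.
_ASSET = None

class HeavyAsset:
    def __init__(self, path: str):
        self.path = path  # imagine large weights being loaded here

    def run(self, prompt: str) -> str:
        return f"[{self.path}] {prompt}"

def _get_asset(path: str = "Step-Audio-2-mini") -> HeavyAsset:
    """Build the heavy object once per process, then reuse the cached instance."""
    global _ASSET
    if _ASSET is None:
        _ASSET = HeavyAsset(path)  # constructed where it is used, never pickled
    return _ASSET

def handler(text: str) -> str:
    # In the Space this function would carry @spaces.GPU; it accepts only
    # picklable arguments and resolves the heavy object internally.
    return _get_asset().run(text)

if __name__ == "__main__":
    print(handler("hello"))        # first call builds the asset
    print(handler("hello again"))  # later calls reuse the cached instance

The same contract explains why on_submit and on_regenerate now read args.prompt_wav, args.cache_dir and args.model_path from the enclosing closure instead of routing them through Gradio inputs=: only the chat lists and plain strings ever reach predict(). The picklability requirement stated in predict()'s docstring can also be checked directly; the illustrative snippet below shows that plain state serializes, while a non-picklable resource (an open file handle here, standing in for a loaded model kept out of the arguments) does not:

import pickle

# State of the kind passed into predict(): strings, lists, dicts -> picklable.
state = {
    "history": [{"role": "system", "content": "..."}],
    "prompt_wav": "assets/default_female.wav",
    "model_path": "Step-Audio-2-mini",
}
pickle.dumps(state)  # succeeds

# An open file handle stands in for a resource that cannot cross the
# process boundary; keeping such objects out of the event arguments is
# exactly what the lazy per-process cache above is for.
handle = open(__file__)
try:
    pickle.dumps(handle)
except TypeError as exc:
    print("not picklable:", exc)
finally:
    handle.close()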