Steveeeeeeen (HF Staff) committed
Commit 0a7cf4d · verified · 1 Parent(s): 876f2fc

Update app.py

Files changed (1):
  app.py +38 -23
app.py CHANGED

@@ -1,5 +1,3 @@
-# app.py
-
 import os
 import shlex
 import subprocess
@@ -82,10 +80,12 @@ def _get_models(model_path: str):
 # Inference
 # -----------------------
 @spaces.GPU
-def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mini"):
+def predict(chatbot, history, prompt_wav_path, cache_dir, model_path="Step-Audio-2-mini"):
     """
     Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
     Heavy models are created via _get_models() inside this process.
+
+    `prompt_wav_path` is the CURRENT reference audio to condition on (can be user upload).
     """
     try:
         audio_model, token2wav = _get_models(model_path)
@@ -101,12 +101,13 @@ def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mi
         max_new_tokens=4096,
         temperature=0.7,
         repetition_penalty=1.05,
-        do_sample=True
+        do_sample=True,
     )
     print(f"predict text={text!r}")

-    # Convert tokens -> waveform bytes using token2wav
-    audio_bytes = token2wav(audio_tokens, prompt_wav)
+    # Convert tokens -> waveform bytes using token2wav with the *selected* prompt
+    prompt_path = prompt_wav_path if (prompt_wav_path and Path(prompt_wav_path).exists()) else None
+    audio_bytes = token2wav(audio_tokens, prompt_path)

     # Persist to temp .wav for the UI
     audio_path = save_tmp_audio(audio_bytes, cache_dir)
@@ -118,7 +119,7 @@ def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mi

     except Exception:
         print(traceback.format_exc())
-        gr.Warning("Some error happend, please try again.")
+        gr.Warning("Some error happened, please try again.")

     return chatbot, history

@@ -152,6 +153,9 @@ def _launch_demo(args):
     # Initialize history with current system prompt value
     history = gr.State([{"role": "system", "content": system_prompt.value}])

+    # NEW: keep track of the *current* prompt wav path (defaults to bundled voice)
+    current_prompt_wav = gr.State(args.prompt_wav)
+
     mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
     text = gr.Textbox(placeholder="Enter message ...", label="💬 Text")

@@ -160,37 +164,47 @@ def _launch_demo(args):
     regen_btn = gr.Button("🤔 Regenerate (重试)")
     submit_btn = gr.Button("🚀 Submit")

-    def on_submit(chatbot_val, history_val, mic_val, text_val):
+    def on_submit(chatbot_val, history_val, mic_val, text_val, current_prompt):
         chatbot2, history2, error = add_message(chatbot_val, history_val, mic_val, text_val)
         if error:
             gr.Warning(error)
-            return chatbot2, history2, None, None
-        # Run GPU inference with only picklable args
+            # keep state intact
+            return chatbot2, history2, None, None, current_prompt
+
+        # Choose prompt: prefer latest user mic if present, else stick to remembered prompt
+        prompt_path = mic_val if (mic_val and Path(mic_val).exists()) else current_prompt
+
         chatbot2, history2 = predict(
             chatbot2, history2,
-            args.prompt_wav, args.cache_dir,
-            model_path=args.model_path
+            prompt_path,
+            args.cache_dir,
+            model_path=args.model_path,
         )
-        return chatbot2, history2, None, None
+
+        # Clear inputs; remember the prompt we actually used
+        new_prompt_state = prompt_path
+        return chatbot2, history2, None, None, new_prompt_state

     submit_btn.click(
         fn=on_submit,
-        inputs=[chatbot, history, mic, text],
-        outputs=[chatbot, history, mic, text],
+        inputs=[chatbot, history, mic, text, current_prompt_wav],
+        outputs=[chatbot, history, mic, text, current_prompt_wav],
         concurrency_limit=4,
         concurrency_id="gpu_queue",
     )

-    def on_clean(system_prompt_text):
-        return reset_state(system_prompt_text)
+    def on_clean(system_prompt_text, _default_prompt):
+        # Reset chat and also reset the remembered prompt back to default
+        new_chatbot, new_history = reset_state(system_prompt_text)
+        return new_chatbot, new_history, _default_prompt

     clean_btn.click(
         fn=on_clean,
-        inputs=[system_prompt],
-        outputs=[chatbot, history],
+        inputs=[system_prompt, current_prompt_wav],
+        outputs=[chatbot, history, current_prompt_wav],
     )

-    def on_regenerate(chatbot_val, history_val):
+    def on_regenerate(chatbot_val, history_val, current_prompt):
         # Drop last assistant turn(s) to regenerate
         while chatbot_val and chatbot_val[-1]["role"] == "assistant":
             chatbot_val.pop()
@@ -199,13 +213,14 @@ def _launch_demo(args):
         history_val.pop()
         return predict(
             chatbot_val, history_val,
-            args.prompt_wav, args.cache_dir,
-            model_path=args.model_path
+            current_prompt,  # use the remembered prompt for regen
+            args.cache_dir,
+            model_path=args.model_path,
         )

     regen_btn.click(
         fn=on_regenerate,
-        inputs=[chatbot, history],
+        inputs=[chatbot, history, current_prompt_wav],
         outputs=[chatbot, history],
         concurrency_id="gpu_queue",
     )
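The change leans on the ZeroGPU contract spelled out in the docstring: everything passed into a `@spaces.GPU` function must be picklable, so the heavy models are built and cached inside the GPU worker via `_get_models()`. A minimal sketch of that pattern follows; `load_llm` and `load_vocoder` are placeholder loaders standing in for the real Step-Audio-2 constructors, not names from this repo.

import spaces  # Hugging Face Spaces package providing the GPU decorator

_CACHE = {}  # module-level cache; lives inside the GPU worker process

def load_llm(path):       # placeholder so the sketch runs standalone;
    return f"llm:{path}"  # the real app constructs the audio LLM here

def load_vocoder(path):   # placeholder for the token2wav model
    return f"vocoder:{path}"

def _get_models(model_path):
    # Build the heavy models once per worker process, then reuse them.
    if model_path not in _CACHE:
        _CACHE[model_path] = (load_llm(model_path), load_vocoder(model_path))
    return _CACHE[model_path]

@spaces.GPU
def predict(chatbot, history, prompt_wav_path, cache_dir,
            model_path="Step-Audio-2-mini"):
    # Only picklable values (strings, lists, dicts) crossed the process
    # boundary to get here; the models themselves never leave this worker.
    audio_model, token2wav = _get_models(model_path)
    return chatbot, history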
 
 
 
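The other idea the commit introduces is a `gr.State` that remembers the reference-audio path across events: every handler takes the state as an input and returns the (possibly updated) value as an output, and a fresh mic recording wins over the remembered path only if it actually exists on disk. A self-contained sketch of that wiring, with illustrative names rather than the app's own:

from pathlib import Path
import gradio as gr

DEFAULT_PROMPT = "assets/default_prompt.wav"  # assumed bundled voice

def choose_prompt(mic_path, remembered):
    # Prefer a fresh mic recording when it exists on disk;
    # otherwise keep whatever prompt was used last time.
    if mic_path and Path(mic_path).exists():
        return mic_path
    return remembered

with gr.Blocks() as demo:
    current_prompt = gr.State(DEFAULT_PROMPT)
    mic = gr.Audio(type="filepath", label="Speak")
    used = gr.Textbox(label="Prompt in use")

    def on_submit(mic_path, remembered):
        prompt = choose_prompt(mic_path, remembered)
        # Return the prompt twice: once for display, once to update the state.
        return prompt, prompt

    gr.Button("Submit").click(on_submit, [mic, current_prompt], [used, current_prompt])

demo.launch()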
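Finally, `on_regenerate` hinges on trimming trailing assistant turns so generation restarts from the last user message. A standalone version of just that step, with a quick usage check:

def drop_trailing_assistant(messages):
    # Pop assistant turns off the end so the last entry is the user's.
    trimmed = list(messages)
    while trimmed and trimmed[-1]["role"] == "assistant":
        trimmed.pop()
    return trimmed

history = [
    {"role": "system", "content": "You are a voice assistant."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
assert drop_trailing_assistant(history)[-1]["role"] == "user"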