# Step Audio 2 Demo: a Hugging Face Space running on ZeroGPU.
import os
import subprocess
import tempfile
import traceback
from pathlib import Path
# --- Install / fetch runtime deps & assets ---
subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True)

# Download token2wav assets (skipped when already present; check=True fails fast)
_T2W_BASE = "https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav"
for fname in ("campplus.onnx", "flow.pt", "flow.yaml", "hift.pt"):
    if not Path("token2wav", fname).exists():
        subprocess.run(["wget", f"{_T2W_BASE}/{fname}", "-P", "token2wav"], check=True)
# Hugging Face token (optional). When set as a Space secret, HF_TOKEN is already
# present in the environment and huggingface_hub picks it up automatically; the
# assets above are public, so no token is required for them.

# Imported after the dependency install above so the freshly installed packages
# are the ones that get loaded.
import spaces
import gradio as gr
# -----------------------
# Utility helpers
# -----------------------
def save_tmp_audio(audio_bytes: bytes, cache_dir: str) -> str:
    """Save raw WAV bytes to a temporary file and return its path."""
    os.makedirs(cache_dir, exist_ok=True)
    with tempfile.NamedTemporaryFile(dir=cache_dir, delete=False, suffix=".wav") as temp_audio:
        temp_audio.write(audio_bytes)
    return temp_audio.name
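# Illustrative only: save_tmp_audio(wav_bytes, "/tmp/stepaudio2") returns a path
# such as "/tmp/stepaudio2/tmpab12cd34.wav" that gr.Chatbot can serve directly.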
def add_message(chatbot, history, mic, text):
    """Append the user's text or recorded audio to the chat and history.

    Text takes precedence when both are provided.
    """
    if not mic and not text:
        return chatbot, history, "Input is empty"
    if text:
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "human", "content": text})
    elif mic and Path(mic).exists():
        chatbot.append({"role": "user", "content": {"path": mic}})
        history.append({"role": "human", "content": [{"type": "audio", "audio": mic}]})
    print(f"{history=}")
    return chatbot, history, None
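# Illustrative history after one text turn (roles follow the model's chat
# schema, "human" rather than Gradio's "user"):
#     [{"role": "system", "content": "..."},
#      {"role": "human", "content": "hello"}]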
def reset_state(system_prompt: str):
    """Reset the chat to a single system message."""
    return [], [{"role": "system", "content": system_prompt}]
# -----------------------
# Lazy model loading inside the GPU worker
# -----------------------
_MODEL = None
_TOK2WAV = None
def _get_models(model_path: str):
    """
    Lazily load heavy, non-picklable models INSIDE the worker process
    and cache them in module globals for reuse.
    """
    global _MODEL, _TOK2WAV
    if _MODEL is None or _TOK2WAV is None:
        # Import here so the objects are constructed in the worker
        from stepaudio2 import StepAudio2
        from token2wav import Token2wav
        _MODEL = StepAudio2(model_path)
        _TOK2WAV = Token2wav("token2wav")
    return _MODEL, _TOK2WAV
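# Why lazy-load: with ZeroGPU, @spaces.GPU functions can run in a separate
# worker process, and CUDA-backed models are not picklable, so they must be
# built inside that process. The module-global cache means only the first call
# pays the load cost:
#
#     model, t2w = _get_models("Step-Audio-2-mini")  # loads weights (slow)
#     model, t2w = _get_models("Step-Audio-2-mini")  # returns cached objects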
# -----------------------
# Inference
# -----------------------
@spaces.GPU
def predict(chatbot, history, prompt_wav_path, cache_dir, model_path="Step-Audio-2-mini"):
    """
    Run generation on the GPU worker. All args must be picklable (strings, lists, dicts).
    Heavy models are created via _get_models() inside this process.
    `prompt_wav_path` is the CURRENT reference audio to condition on (may be a user upload).
    """
    try:
        audio_model, token2wav = _get_models(model_path)
        history.append({
            "role": "assistant",
            "content": [{"type": "text", "text": "<tts_start>"}],
            "eot": False,
        })
        tokens, text, audio_tokens = audio_model(
            history,
            max_new_tokens=4096,
            temperature=0.7,
            repetition_penalty=1.05,
            do_sample=True,
        )
        print(f"predict text={text!r}")
        # Convert tokens -> waveform bytes using token2wav with the *selected* prompt
        prompt_path = prompt_wav_path if (prompt_wav_path and Path(prompt_wav_path).exists()) else None
        audio_bytes = token2wav(audio_tokens, prompt_path)
        # Persist to a temp .wav for the UI
        audio_path = save_tmp_audio(audio_bytes, cache_dir)
        # Append the assistant audio message
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        history[-1]["content"].append({"type": "token", "token": tokens})
        history[-1]["eot"] = True
    except Exception:
        print(traceback.format_exc())
        # Drop the half-open assistant turn so a retry starts clean.
        if history and history[-1]["role"] == "assistant" and not history[-1].get("eot", True):
            history.pop()
        gr.Warning("Something went wrong, please try again.")
    return chatbot, history
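# Shape of a completed assistant turn in `history` after a successful call
# (illustrative):
#     {"role": "assistant",
#      "content": [{"type": "text", "text": "<tts_start>"},
#                  {"type": "token", "token": [...]}],
#      "eot": True}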
# -----------------------
# UI
# -----------------------
def _launch_demo(args):
    with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("""<center><font size=8>Step Audio 2 Demo</font></center>""")
        with gr.Row():
            # Default persona (in Chinese): "Your name is Xiaoyue, a speech LLM
            # trained by StepFun. You are emotionally perceptive and observant,
            # good at analyzing what users say and replying with empathy, always
            # mindful of how the user feels. Today is Friday, August 29, 2025.
            # Please talk to the user in the default female voice."
            system_prompt = gr.Textbox(
                label="System Prompt",
                value=(
                    "你的名字叫做小跃,是由阶跃星辰公司训练出来的语音大模型。\n"
                    "你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,"
                    "说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n"
                    "今天是2025年8月29日,星期五\n"
                    "请用默认女声与用户交流。"
                ),
                lines=2,
            )
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            min_height=800,
            type="messages",
        )
        # Initialize history with the current system prompt value
        history = gr.State([{"role": "system", "content": system_prompt.value}])
        # Track the *current* prompt wav path (defaults to the bundled voice)
        current_prompt_wav = gr.State(args.prompt_wav)
        mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
        text = gr.Textbox(placeholder="Enter message ...", label="💬 Text")
        with gr.Row():
            clean_btn = gr.Button("🧹 Clear History (清除历史)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            submit_btn = gr.Button("🚀 Submit")
        def on_submit(chatbot_val, history_val, mic_val, text_val, current_prompt):
            chatbot2, history2, error = add_message(chatbot_val, history_val, mic_val, text_val)
            if error:
                gr.Warning(error)
                # Keep state intact
                return chatbot2, history2, None, None, current_prompt
            # Choose the prompt: prefer the latest user mic recording if present,
            # otherwise stick with the remembered prompt.
            prompt_path = mic_val if (mic_val and Path(mic_val).exists()) else current_prompt
            chatbot2, history2 = predict(
                chatbot2, history2,
                prompt_path,
                args.cache_dir,
                model_path=args.model_path,
            )
            # Clear the inputs; remember the prompt we actually used
            return chatbot2, history2, None, None, prompt_path
        submit_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text, current_prompt_wav],
            outputs=[chatbot, history, mic, text, current_prompt_wav],
            concurrency_limit=4,
            concurrency_id="gpu_queue",
        )
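        # Note: Regenerate below reuses concurrency_id="gpu_queue". Gradio events
        # that share a concurrency_id also share the concurrency_limit (4 here),
        # so submit and regenerate cannot oversubscribe the GPU workers.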
        def on_clean(system_prompt_text, _current_prompt):
            # Reset the chat and return the remembered prompt to the bundled default
            new_chatbot, new_history = reset_state(system_prompt_text)
            return new_chatbot, new_history, args.prompt_wav

        clean_btn.click(
            fn=on_clean,
            inputs=[system_prompt, current_prompt_wav],
            outputs=[chatbot, history, current_prompt_wav],
        )
        def on_regenerate(chatbot_val, history_val, current_prompt):
            # Drop trailing assistant turn(s) so they can be regenerated
            while chatbot_val and chatbot_val[-1]["role"] == "assistant":
                chatbot_val.pop()
            while history_val and history_val[-1]["role"] == "assistant":
                print(f"discard {history_val[-1]}")
                history_val.pop()
            return predict(
                chatbot_val, history_val,
                current_prompt,  # use the remembered prompt for regeneration
                args.cache_dir,
                model_path=args.model_path,
            )

        regen_btn.click(
            fn=on_regenerate,
            inputs=[chatbot, history, current_prompt_wav],
            outputs=[chatbot, history],
            concurrency_id="gpu_queue",
        )
    demo.queue().launch(
        server_port=args.server_port,
        server_name=args.server_name,
    )
# -----------------------
# Entrypoint
# -----------------------
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
    parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
    parser.add_argument("--prompt-wav", type=str, default="assets/default_female.wav", help="Prompt wav for the assistant voice.")
    parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
    args = parser.parse_args()

    # Route Gradio's temp files into the same cache directory
    os.environ["GRADIO_TEMP_DIR"] = args.cache_dir
    Path(args.cache_dir).mkdir(parents=True, exist_ok=True)
    _launch_demo(args)
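# Example local run (assuming this file is saved as app.py, as is conventional
# for Spaces, and that the model weights and assets/default_female.wav exist):
#     python app.py --model-path Step-Audio-2-mini --server-port 7860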