	Update app.py
app.py CHANGED
@@ -1,5 +1,3 @@
-# app.py
-
 import os
 import shlex
 import subprocess
@@ -82,10 +80,12 @@ def _get_models(model_path: str):
 # Inference
 # -----------------------
 @spaces.GPU
-def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mini"):
+def predict(chatbot, history, prompt_wav_path, cache_dir, model_path="Step-Audio-2-mini"):
     """
     Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
     Heavy models are created via _get_models() inside this process.
+
+    `prompt_wav_path` is the CURRENT reference audio to condition on (can be user upload).
     """
     try:
         audio_model, token2wav = _get_models(model_path)
@@ -101,12 +101,13 @@ def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mi
             max_new_tokens=4096,
             temperature=0.7,
             repetition_penalty=1.05,
-            do_sample=True
+            do_sample=True,
         )
         print(f"predict text={text!r}")
 
-        # Convert tokens -> waveform bytes using token2wav
-        audio_bytes = token2wav(audio_tokens, prompt_wav)
+        # Convert tokens -> waveform bytes using token2wav with the *selected* prompt
+        prompt_path = prompt_wav_path if (prompt_wav_path and Path(prompt_wav_path).exists()) else None
+        audio_bytes = token2wav(audio_tokens, prompt_path)
 
         # Persist to temp .wav for the UI
         audio_path = save_tmp_audio(audio_bytes, cache_dir)
@@ -118,7 +119,7 @@ def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mi
 
     except Exception:
         print(traceback.format_exc())
-        gr.Warning("Some error
+        gr.Warning("Some error happened, please try again.")
 
     return chatbot, history
 
@@ -152,6 +153,9 @@ def _launch_demo(args):
         # Initialize history with current system prompt value
         history = gr.State([{"role": "system", "content": system_prompt.value}])
 
+        # NEW: keep track of the *current* prompt wav path (defaults to bundled voice)
+        current_prompt_wav = gr.State(args.prompt_wav)
+
         mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
         text = gr.Textbox(placeholder="Enter message ...", label="💬 Text")
 
@@ -160,37 +164,47 @@ def _launch_demo(args):
             regen_btn = gr.Button("🌤️ Regenerate (重说)")
             submit_btn = gr.Button("🚀 Submit")
 
-        def on_submit(chatbot_val, history_val, mic_val, text_val):
+        def on_submit(chatbot_val, history_val, mic_val, text_val, current_prompt):
             chatbot2, history2, error = add_message(chatbot_val, history_val, mic_val, text_val)
             if error:
                 gr.Warning(error)
-                return chatbot2, history2, None, None
-
+                # keep state intact
+                return chatbot2, history2, None, None, current_prompt
+
+            # Choose prompt: prefer latest user mic if present, else stick to remembered prompt
+            prompt_path = mic_val if (mic_val and Path(mic_val).exists()) else current_prompt
+
             chatbot2, history2 = predict(
                 chatbot2, history2,
-                args.prompt_wav,
-                args.cache_dir
+                prompt_path,
+                args.cache_dir,
+                model_path=args.model_path,
             )
-            return chatbot2, history2, None, None
+
+            # Clear inputs; remember the prompt we actually used
+            new_prompt_state = prompt_path
+            return chatbot2, history2, None, None, new_prompt_state
 
         submit_btn.click(
             fn=on_submit,
-            inputs=[chatbot, history, mic, text],
-            outputs=[chatbot, history, mic, text],
+            inputs=[chatbot, history, mic, text, current_prompt_wav],
+            outputs=[chatbot, history, mic, text, current_prompt_wav],
             concurrency_limit=4,
             concurrency_id="gpu_queue",
         )
 
-        def on_clean(system_prompt_text):
-            return reset_state(system_prompt_text)
+        def on_clean(system_prompt_text, _default_prompt):
+            # Reset chat and also reset the remembered prompt back to default
+            new_chatbot, new_history = reset_state(system_prompt_text)
+            return new_chatbot, new_history, _default_prompt
 
         clean_btn.click(
             fn=on_clean,
-            inputs=[system_prompt],
-            outputs=[chatbot, history],
+            inputs=[system_prompt, current_prompt_wav],
+            outputs=[chatbot, history, current_prompt_wav],
        )
 
-        def on_regenerate(chatbot_val, history_val):
+        def on_regenerate(chatbot_val, history_val, current_prompt):
             # Drop last assistant turn(s) to regenerate
             while chatbot_val and chatbot_val[-1]["role"] == "assistant":
                 chatbot_val.pop()
@@ -199,13 +213,14 @@ def _launch_demo(args):
                 history_val.pop()
             return predict(
                 chatbot_val, history_val,
-                args.prompt_wav,
-                args.cache_dir
+                current_prompt,           # use the remembered prompt for regen
+                args.cache_dir,
+                model_path=args.model_path,
             )
 
         regen_btn.click(
             fn=on_regenerate,
-            inputs=[chatbot, history],
+            inputs=[chatbot, history, current_prompt_wav],
             outputs=[chatbot, history],
             concurrency_id="gpu_queue",
         )
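A note on the @spaces.GPU pattern above: on ZeroGPU Spaces the decorated predict() runs in a separate GPU worker process, which is why its arguments must stay picklable and why the heavy models are built inside the call via _get_models() rather than passed in. Below is a minimal sketch of that lazy, per-process cache; load_audio_model() and load_token2wav() are hypothetical stand-ins, since this diff does not show the real constructors.

import spaces

_MODEL_CACHE = {}

def _get_models(model_path: str):
    # Build the heavy models once per process, then reuse them; under
    # ZeroGPU each GPU worker keeps its own copy of this cache.
    if model_path not in _MODEL_CACHE:
        _MODEL_CACHE[model_path] = (
            load_audio_model(model_path),   # hypothetical constructor
            load_token2wav(model_path),     # hypothetical constructor
        )
    return _MODEL_CACHE[model_path]

@spaces.GPU  # runs the body in a short-lived GPU worker
def predict(chatbot, history, prompt_wav_path, cache_dir, model_path="Step-Audio-2-mini"):
    audio_model, token2wav = _get_models(model_path)  # created here, not pickled in
    ...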
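The current_prompt_wav wiring is standard gr.State round-tripping: the state is listed in inputs= so the handler receives its value, and the handler's corresponding return value in outputs= becomes the new state. A self-contained sketch of just that mechanic (names here are illustrative, not from the Space):

import gradio as gr

with gr.Blocks() as demo:
    # Per-session value that survives between events, like
    # current_prompt_wav in the diff.
    remembered = gr.State("default.wav")
    audio_in = gr.Audio(type="filepath", label="upload")
    status = gr.Textbox(label="status")

    def use_or_remember(new_path, current):
        chosen = new_path or current      # prefer a fresh upload
        return f"using {chosen}", chosen  # second return value updates the State

    audio_in.change(use_or_remember,
                    inputs=[audio_in, remembered],
                    outputs=[status, remembered])

demo.launch()

Note the asymmetry in regen_btn.click above: current_prompt_wav appears in inputs but not outputs, so regeneration reads the remembered prompt without changing it.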
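Both GPU-bound handlers keep concurrency_id="gpu_queue": Gradio pools events that share a concurrency_id under one worker limit, so Submit and Regenerate together stay within the concurrency_limit=4 set on the submit event. A minimal illustration of the shared queue (a toy example, not the Space's code):

import time
import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox()
    a_btn = gr.Button("A")
    b_btn = gr.Button("B")

    def slow_echo(x):
        time.sleep(2)  # stand-in for a GPU-bound call
        return x

    # Events sharing a concurrency_id draw from one worker pool:
    # at most 4 slow_echo calls run at once across BOTH buttons.
    a_btn.click(slow_echo, box, box,
                concurrency_limit=4, concurrency_id="shared")
    b_btn.click(slow_echo, box, box, concurrency_id="shared")

demo.queue().launch()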
