import io
import os
import re
import subprocess
import sys
import threading
import time
import wave

import gradio as gr
import numpy as np
from huggingface_hub import snapshot_download

from tools.fish_e2e import FishE2EAgent, FishE2EEventType
from tools.schema import ServeMessage, ServeTextPart, ServeVQPart

# Download weights
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(
    repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4"
)
snapshot_download(
    repo_id="fishaudio/fish-agent-v0.1-3b",
    local_dir="./checkpoints/fish-agent-v0.1-3b",
)

SYSTEM_PROMPT = (
    "You are a voice assistant created by Fish Audio, offering end-to-end voice "
    "interaction for a seamless user experience. You are required to first "
    "transcribe the user's speech, then answer it in the following format: "
    '"Question: [USER_SPEECH]\n\nResponse: [YOUR_RESPONSE]\n". '
    "You are required to use the following voice in this conversation."
)


class ChatState:
    def __init__(self):
        self.conversation = []
        self.added_systext = False
        self.added_sysaudio = False

    def get_history(self):
        results = []
        for msg in self.conversation:
            results.append({"role": msg.role, "content": self.repr_message(msg)})

        # Process assistant messages to extract questions and update user messages
        for i, msg in enumerate(results):
            if msg["role"] == "assistant":
                match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
                if match and i > 0 and results[i - 1]["role"] == "user":
                    # Update previous user message with extracted question
                    results[i - 1]["content"] += "\n" + match.group(1)
                    # Remove the Question/Answer format from assistant message
                    msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]

        return results

    def repr_message(self, msg: ServeMessage):
        response = ""
        for part in msg.parts:
            if isinstance(part, ServeTextPart):
                response += part.text
            elif isinstance(part, ServeVQPart):
                response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
        return response


def clear_fn():
    return [], ChatState(), None, None, None


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    agent = FishE2EAgent()  # Create a new agent instance for each request

    # Convert audio input to a numpy array
    if isinstance(audio_input, tuple):
        sr, audio_data = audio_input
    elif text_input:
        sr = 44100
        audio_data = None
    else:
        raise gr.Error("Invalid audio format")

    if isinstance(sys_audio_input, tuple):
        sr, sys_audio_data = sys_audio_input
    else:
        sr = 44100
        sys_audio_data = None

    def append_to_chat_ctx(
        part: ServeTextPart | ServeVQPart, role: str = "assistant"
    ) -> None:
        # Start a new message when the role changes; otherwise extend the last one
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    if state.added_systext is False and sys_text_input:
        state.added_systext = True
        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
    if text_input:
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
        audio_data = None

    result_audio = b""
    async for event in agent.stream(
        sys_audio_data,
        audio_data,
        sr,
        1,
        chat_ctx={
            "messages": state.conversation,
            "added_sysaudio": state.added_sysaudio,
        },
    ):
        if event.type == FishE2EEventType.USER_CODES:
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
            result_audio += event.frame.data
            np_audio = np.frombuffer(result_audio, dtype=np.int16)
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))

            yield state.get_history(), (44100, np_audio), None, None
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            append_to_chat_ctx(ServeTextPart(text=event.text))
            if result_audio:
                np_audio = np.frombuffer(result_audio, dtype=np.int16)
                yield state.get_history(), (44100, np_audio), None, None
            else:
                yield state.get_history(), None, None, None

    # Flush the accumulated audio once the stream ends
    np_audio = np.frombuffer(result_audio, dtype=np.int16)
    yield state.get_history(), (44100, np_audio), None, None


async def process_text_input(
    sys_audio_input, sys_text_input, state: ChatState, text_input: str
):
    async for event in process_audio_input(
        sys_audio_input, sys_text_input, None, state, text_input
    ):
        yield event


def create_demo():
    with gr.Blocks() as demo:
        state = gr.State(ChatState())

        with gr.Row():
            # Left column (70%) for chatbot and notes
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=600,
                    type="messages",
                )

                notes = gr.Markdown(
                    """
# Fish Agent
1. This demo showcases Fish Agent 3B, Fish Audio's self-developed end-to-end language model.
2. You can find the code and weights in our official repository; all related content is released under the CC BY-NC-SA 4.0 license.
3. The demo is an early beta version, and inference speed has yet to be optimized.
# Features
1. The model integrates ASR and TTS natively, requiring no external models, making it truly end-to-end rather than a three-stage pipeline (ASR + LLM + TTS).
2. The model can use reference audio to control the speaking voice.
3. It can generate audio with strong emotion and prosody.
"""
                )

            # Right column (30%) for controls
            with gr.Column(scale=3):
                sys_audio_input = gr.Audio(
                    sources=["upload"],
                    type="numpy",
                    label="Give a timbre for your assistant",
                )
                sys_text_input = gr.Textbox(
                    label="What is your assistant's role?",
                    value=SYSTEM_PROMPT,
                    type="text",
                )
                audio_input = gr.Audio(
                    sources=["microphone"], type="numpy", label="Speak your message"
                )

                text_input = gr.Textbox(
                    label="Or type your message",
                    type="text",
                    value="Can you give a brief introduction of yourself?",
                )

                output_audio = gr.Audio(
                    label="Assistant's Voice",
                    type="numpy",
                )

                send_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")

        # Event handlers
        audio_input.stop_recording(
            process_audio_input,
            inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        send_button.click(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        text_input.submit(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        clear_button.click(
            clear_fn,
            inputs=[],
            outputs=[chatbot, state, audio_input, output_audio, text_input],
        )

    return demo


def run_api():
    subprocess.run([sys.executable, "-m", "tools.api"])


if __name__ == "__main__":
    # Create and start the API thread
    api_thread = threading.Thread(target=run_api, daemon=True)
    api_thread.start()

    # Give the API some time to start up
    time.sleep(90)

    # Create and launch the Gradio demo
    demo = create_demo()
    demo.launch(share=True)
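
# ----------------------------------------------------------------------------
# NOTE: `wav_chunk_header` above is defined but never called in this script.
# A minimal usage sketch (the helper name `pcm_to_streaming_wav` below is
# hypothetical, not part of this demo): prepend the header to raw 16-bit PCM
# bytes to serve them as a chunked WAV stream. The `wave` module writes
# zero-length size fields for a file with no frames, which many decoders
# tolerate when the audio arrives as a stream.
#
# def pcm_to_streaming_wav(pcm_bytes: bytes, sample_rate: int = 44100) -> bytes:
#     return wav_chunk_header(sample_rate=sample_rate) + pcm_bytes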