import re
import gradio as gr
import numpy as np
import os
import io
import wave
import threading
import subprocess
import sys
import time

from huggingface_hub import snapshot_download
from tools.fish_e2e import FishE2EAgent, FishE2EEventType
from tools.schema import ServeMessage, ServeTextPart, ServeVQPart

# Download model weights at import time (these statements run before any UI is built).
# NOTE(review): snapshot_download runs on every start-up; presumably it reuses files
# already present in local_dir — confirm against huggingface_hub caching behaviour.
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4")
snapshot_download(repo_id="fishaudio/fish-agent-v0.1-3b", local_dir="./checkpoints/fish-agent-v0.1-3b")
# Default text for the "What is your assistant's role?" textbox. It asks the model to
# reply as "Question: ...\n\nResponse: ...", the format ChatState.get_history parses to
# move the transcript into the user's bubble.
# NOTE(review): the prompt contains a full-width "。" after the format spec — confirm intended.
SYSTEM_PROMPT = 'You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user\'s speech, then answer it in the following format: "Question: [USER_SPEECH]\n\nResponse: [YOUR_RESPONSE]\n"。You are required to use the following voice in this conversation.'

class ChatState:
    """Per-session conversation state shared between the UI and the agent.

    Attributes:
        conversation: Ordered list of messages (objects with ``role`` and
            ``parts``) exchanged with the agent; mutated in place by callers.
        added_systext: Set by the caller once the system text prompt has been
            appended to ``conversation``.
        added_sysaudio: Forwarded to the agent as ``chat_ctx["added_sysaudio"]``.
    """

    # Prefix the system prompt asks the model to emit. The space after
    # "Response:" is optional so that a message still being streamed (e.g. one
    # that currently ends exactly at "Response:") is handled consistently
    # instead of matching here and then failing to split.
    _QA_PATTERN = re.compile(r"Question: (.*?)\n\nResponse: ?")

    def __init__(self):
        self.conversation = []
        self.added_systext = False
        self.added_sysaudio = False

    def get_history(self):
        """Render ``conversation`` as Gradio "messages"-style chat history.

        When an assistant message starts with the "Question: ...\\n\\nResponse: "
        format and directly follows a user message, the transcribed question is
        appended to that user message and only the response text is kept for
        the assistant.

        Returns:
            A new list of ``{"role": str, "content": str}`` dicts.
        """
        results = [
            {"role": msg.role, "content": self.repr_message(msg)}
            for msg in self.conversation
        ]

        for i, msg in enumerate(results):
            if msg["role"] != "assistant":
                continue
            match = self._QA_PATTERN.search(msg["content"])
            if match and i > 0 and results[i - 1]["role"] == "user":
                # Show the transcript under the user's turn ...
                results[i - 1]["content"] += "\n" + match.group(1)
                # ... and keep only the text after "Response:" for the assistant.
                # Slicing at the match end cannot fail, unlike splitting on a
                # separator the regex did not require.
                msg["content"] = msg["content"][match.end():]
        return results

    def repr_message(self, msg: "ServeMessage"):
        """Flatten a message's parts into display text.

        Text parts are concatenated verbatim; VQ (audio-code) parts become an
        ``<audio N.NNs>`` placeholder.
        """
        response = ""
        for part in msg.parts:
            if isinstance(part, ServeTextPart):
                response += part.text
            elif isinstance(part, ServeVQPart):
                # NOTE(review): treats len(codes[0]) as a frame count at ~21
                # frames per second to estimate the clip length — confirm rate.
                response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
        return response


def clear_fn():
    """Reset the whole chat UI for the Clear button.

    Values are ordered to match that button's outputs:
    (chatbot, state, audio_input, output_audio, text_input) — an empty
    history, a brand-new ChatState, and cleared audio/text widgets.
    """
    fresh_state = ChatState()
    empty_history = []
    return empty_history, fresh_state, None, None, None

def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Build a standalone WAV (RIFF) header that contains no sample data.

    Args:
        sample_rate: Frames per second recorded in the header.
        bit_depth: Bits per sample; converted to a byte width for ``wave``.
        channels: Number of interleaved channels.

    Returns:
        The bytes ``wave`` writes when a writer is configured and then closed
        without any frames being written.
    """
    with io.BytesIO() as sink:
        with wave.open(sink, "wb") as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bit_depth // 8)
            writer.setframerate(sample_rate)
        # The header is emitted into ``sink`` when ``writer`` closes above.
        return sink.getvalue()

async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
    """Run one conversation turn through a fresh FishE2EAgent and stream UI updates.

    Args:
        sys_audio_input: ``(sample_rate, samples)`` tuple from the reference-voice
            widget, or None.
        sys_text_input: System prompt text; appended once per conversation.
        audio_input: ``(sample_rate, samples)`` tuple from the microphone, or None.
        state: Per-session ChatState; ``state.conversation`` is mutated in place.
        text_input: Typed user message. When non-empty it takes precedence and
            any recorded audio is not sent.

    Yields:
        ``(chat_history, assistant_audio, None, None)``. ``assistant_audio`` is
        ``(44100, int16 array)`` of everything received so far, or None if no
        speech has arrived yet. The trailing Nones clear the two input widgets.

    Raises:
        gr.Error: if there is neither audio nor text, or the audio value is not
            a tuple.
    """
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    agent = FishE2EAgent()  # Create new agent instance for each request

    # Gradio's numpy audio value is a (sample_rate, samples) tuple.
    if isinstance(audio_input, tuple):
        user_sr, audio_data = audio_input
    elif text_input:
        user_sr, audio_data = None, None
    else:
        raise gr.Error("Invalid audio format")

    # Keep the reference voice's rate separate so it cannot clobber the user's.
    if isinstance(sys_audio_input, tuple):
        sys_sr, sys_audio_data = sys_audio_input
    else:
        sys_sr, sys_audio_data = None, None

    def append_to_chat_ctx(
        part: ServeTextPart | ServeVQPart, role: str = "assistant"
    ) -> None:
        """Extend the last message if it has *role*, otherwise start a new one."""
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    result_audio = b""

    def current_audio():
        """Return all assistant audio received so far as a gr.Audio value, or None."""
        if not result_audio:
            return None
        # NOTE(review): assumes the agent streams 16-bit mono PCM at 44.1 kHz.
        return 44100, np.frombuffer(result_audio, dtype=np.int16)

    if state.added_systext is False and sys_text_input:
        state.added_systext = True
        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
    if text_input:
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
        audio_data = None  # typed text takes precedence over a recording

    # agent.stream() accepts a single sample rate for both clips. Use the rate
    # of the clip actually sent as user audio so the user's speech is never
    # encoded at the wrong rate. Otherwise use the reference voice's rate, and
    # fall back to 44100.
    # NOTE(review): when both clips are sent with different rates, the
    # reference clip is still labelled with the user's rate; resampling one to
    # the other would be needed to make both exact.
    if audio_data is not None:
        sr = user_sr
    elif sys_audio_data is not None:
        sr = sys_sr
    else:
        sr = 44100

    # Keep a handle on the context so an update the agent makes to
    # "added_sysaudio" is written back to the session. Nothing else in this
    # module ever sets state.added_sysaudio.
    chat_ctx = {
        "messages": state.conversation,
        "added_sysaudio": state.added_sysaudio,
    }

    async for event in agent.stream(
        sys_audio_data,
        audio_data,
        sr,
        1,
        chat_ctx=chat_ctx,
    ):
        state.added_sysaudio = chat_ctx.get("added_sysaudio", state.added_sysaudio)
        if event.type == FishE2EEventType.USER_CODES:
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
            result_audio += event.frame.data
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
            yield state.get_history(), current_audio(), None, None
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            append_to_chat_ctx(ServeTextPart(text=event.text))
            yield state.get_history(), current_audio(), None, None

    state.added_sysaudio = chat_ctx.get("added_sysaudio", state.added_sysaudio)
    # Final update. current_audio() gives None rather than an empty array when
    # no speech was produced.
    yield state.get_history(), current_audio(), None, None


async def process_text_input(
    sys_audio_input, sys_text_input, state: ChatState, text_input: str
):
    """Handle a typed message (Send button or Enter) as a turn with no recording.

    Delegates to process_audio_input with ``audio_input=None`` and re-yields
    every UI update it produces, unchanged.
    """
    turn = process_audio_input(
        sys_audio_input, sys_text_input, None, state, text_input
    )
    async for ui_update in turn:
        yield ui_update


def create_demo():
    """Build the Gradio Blocks UI for the Fish Agent chat demo.

    Layout: the chat window and notes on the left (scale 7), and the
    input/output controls on the right (scale 3). Typed messages go through
    process_text_input. Finished microphone recordings go through
    process_audio_input as voice-only turns.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        state = gr.State(ChatState())

        with gr.Row():
            # Left column (70%) for chatbot and notes
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=600,
                    type="messages",
                )

                notes = gr.Markdown(
                    """
                # Fish Agent
                1. This demo is the Fish Audio self-developed end-to-end language model Fish Agent 3B version.
                2. You can find the code and weights in our official repository, but all related content is released under the CC BY-NC-SA 4.0 license.
                3. The demo is an early beta version, and inference speed is yet to be optimized.
                # Features
                1. This model automatically integrates ASR and TTS components, requiring no external models, making it truly end-to-end rather than a three-stage process (ASR+LLM+TTS).
                2. The model can use reference audio to control speaking voice.
                3. It can generate audio with strong emotions and prosody.
                """
                )

            # Right column (30%) for controls
            with gr.Column(scale=3):
                sys_audio_input = gr.Audio(
                    sources=["upload"],
                    type="numpy",
                    label="Give a timbre for your assistant",
                )
                sys_text_input = gr.Textbox(
                    label="What is your assistant's role?",
                    value=SYSTEM_PROMPT,
                    type="text",
                )
                audio_input = gr.Audio(
                    sources=["microphone"], type="numpy", label="Speak your message"
                )

                text_input = gr.Textbox(
                    label="Or type your message",
                    type="text",
                    value="Can you give a brief introduction of yourself?",
                )

                output_audio = gr.Audio(
                    label="Assistant's Voice",
                    type="numpy",
                )

                send_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")

        # Shared wiring for typed turns: refresh chat + assistant audio, clear inputs.
        text_turn_inputs = [sys_audio_input, sys_text_input, state, text_input]
        text_turn_outputs = [chatbot, output_audio, audio_input, text_input]

        async def process_recorded_audio(sys_audio, sys_text, recorded_audio, chat_state):
            """Run a voice-only turn for a finished microphone recording.

            process_audio_input lets non-empty typed text override the recording,
            and the textbox starts pre-filled with a default question. Forwarding
            it would therefore silently drop the user's speech. The textbox is
            deliberately neither forwarded nor cleared here.
            """
            async for history, assistant_audio, cleared_audio, _ in process_audio_input(
                sys_audio, sys_text, recorded_audio, chat_state, ""
            ):
                yield history, assistant_audio, cleared_audio

        # Event handlers
        audio_input.stop_recording(
            process_recorded_audio,
            inputs=[sys_audio_input, sys_text_input, audio_input, state],
            outputs=[chatbot, output_audio, audio_input],
            show_progress=True,
        )

        send_button.click(
            process_text_input,
            inputs=text_turn_inputs,
            outputs=text_turn_outputs,
            show_progress=True,
        )

        text_input.submit(
            process_text_input,
            inputs=text_turn_inputs,
            outputs=text_turn_outputs,
            show_progress=True,
        )

        clear_button.click(
            clear_fn,
            inputs=[],
            outputs=[chatbot, state, audio_input, output_audio, text_input],
        )

    return demo

def run_api():
    """Run the ``tools.api`` module in a child process using this interpreter.

    Blocks until the child exits. Its exit status is not checked.
    """
    command = [sys.executable, "-m", "tools.api"]
    subprocess.run(command)

if __name__ == "__main__":
    # Create and start the API thread. It is a daemon thread, so it does not
    # keep the interpreter alive on its own once the main thread finishes.
    api_thread = threading.Thread(target=run_api, daemon=True)
    api_thread.start()

    # Give the API up to 90 seconds to start. join() only returns early if
    # run_api() returned, i.e. the `tools.api` process has already exited. The
    # demo evidently needs that backend (hence this wait), so stop here rather
    # than launching a UI that cannot work.
    api_thread.join(timeout=90)
    if not api_thread.is_alive():
        sys.exit("tools.api exited before the demo could start; see its output above.")

    # Create and launch the Gradio demo
    demo = create_demo()
    demo.launch(share=True)