import asyncio
import base64
import json
import logging
import os
import pathlib
from typing import AsyncGenerator, Literal, Optional, List, Union

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import (
    AsyncStreamHandler,
    Stream,
    get_twilio_turn_credentials,
    wait_for_item,
)
from google import genai
from google.genai.types import (
    LiveConnectConfig,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)
from gradio.utils import get_space
from pydantic import BaseModel

# Directory containing this file; the "/" route reads "index.html" from here.
current_dir = pathlib.Path(__file__).parent

# Load variables from a local .env file into the process environment. This must
# run before the module-level os.getenv("GEMINI_API_KEY") lookups further down.
load_dotenv()

# Preset system prompts, keyed by the name shown in the "Preset Prompt"
# dropdown and accepted as `prompt_key` by the /input_hook endpoint. The dict is
# also JSON-encoded into index.html, so key order is the order users see.
# The handler prefers a user-supplied custom prompt over these when one is given.
SYSTEM_PROMPTS = {
    "Default": "You are a helpful, harmless, and honest AI assistant. Provide clear, concise, and accurate information.",
    "Behavior Expert": "You are an expert in behavior analysis, Positive Behavior Interventions and Supports, and Restorative approaches and practices. Your primary role is to help educators understand and address student behavior issues. Provide evidence-based strategies, empathetic insights, and practical solutions.",
    "UDL Expert": "You are an expert in Universal Design for Learning (UDL). Your role is to help educators design units, lessons, and activities to be universally accessible. Provide detailed guidance on addressing student variability and potential learning barriers.",
    "Learning Support Expert": "You are a learning support expert. Your role is to help educators adapt and support instruction based on student needs such as dyscalculia, dyslexia, executive functioning issues, low English proficiency, or other learning needs and differences. You also help devise and implement IEPs, as well as accommodations.",
}


def encode_audio(data: np.ndarray) -> str:
    """Return the raw bytes of *data* as base64 text.

    The array's in-memory bytes are sent unchanged, so the caller's dtype
    determines the sample format that ends up on the wire.
    """
    raw_bytes = data.tobytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("UTF-8")


class GeminiHandler(AsyncStreamHandler):
    """Bridge a fastrtc audio stream to a Gemini Live API session.

    Microphone frames arriving via ``receive`` are base64-encoded and queued;
    ``stream`` feeds that queue to Gemini. Audio chunks coming back from
    Gemini are decoded in ``start_up`` and queued for ``emit`` to play back.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        # Base64-encoded input audio chunks waiting to be sent to Gemini.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # (sample_rate, int16 array) tuples waiting to be emitted to the client.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown() to end the outgoing audio generator in stream().
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with the same output settings.

        Queues and the quit event are created anew, so copies share no state.
        """
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def prepare_system_instruction(
        self,
        prompt_key: Optional[str] = None,
        custom_prompt: Optional[str] = None,
    ) -> Optional[str]:
        """
        Prepare system instruction based on preset or custom prompt.

        Args:
            prompt_key: Key into ``SYSTEM_PROMPTS``; empty or unknown keys
                are ignored.
            custom_prompt: Custom user-defined system prompt. It takes
                precedence over the preset when it contains any
                non-whitespace text, and is returned with surrounding
                whitespace stripped.

        Returns:
            System instruction as a string, or None when neither source applies.
        """
        # A whitespace-only custom prompt (e.g. a stray space typed into the
        # textbox) is treated as empty so it does not silently discard the
        # selected preset.
        stripped_custom = (custom_prompt or "").strip()
        if stripped_custom:
            return stripped_custom

        if prompt_key and prompt_key in SYSTEM_PROMPTS:
            return SYSTEM_PROMPTS[prompt_key]

        return None

    async def start_up(self):
        """Open a Gemini Live session and pump its audio into ``output_queue``.

        Outside phone mode this waits for the additional inputs and unpacks
        ``latest_args[1:]`` as (api_key, voice_name, prompt_key, custom_prompt);
        ``latest_args[0]`` is presumably the connection id (set by fastrtc).
        Phone mode uses the "Puck" voice, no system instruction, and the
        ``GEMINI_API_KEY`` environment variable.

        Errors from the connection or stream are logged with their traceback
        and otherwise swallowed, so the handler ends quietly instead of raising.
        """
        if not self.phone_mode:
            await self.wait_for_args()
            api_key, voice_name, prompt_key, custom_prompt = self.latest_args[1:]
            system_instruction = self.prepare_system_instruction(
                prompt_key, custom_prompt
            )
        else:
            api_key, voice_name = None, "Puck"
            system_instruction = None

        client = genai.Client(
            # An empty/None key from the caller falls back to the environment.
            api_key=api_key or os.getenv("GEMINI_API_KEY"),
            http_options={"api_version": "v1alpha"},
        )

        # Audio-only responses spoken with the selected prebuilt voice.
        config_kwargs = {
            "response_modalities": ["AUDIO"],
            "speech_config": SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
        }

        # Only send a system instruction when one was resolved.
        if system_instruction:
            config_kwargs["system_instruction"] = {
                "parts": [{"text": system_instruction}]
            }

        config = LiveConnectConfig(**config_kwargs)

        try:
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session:
                async for audio in session.start_stream(
                    stream=self.stream(), mime_type="audio/pcm"
                ):
                    if audio.data:
                        # Response bytes are interpreted as int16 PCM samples
                        # at output_sample_rate.
                        array = np.frombuffer(audio.data, dtype=np.int16)
                        self.output_queue.put_nowait((self.output_sample_rate, array))
        except Exception:
            # Best-effort: keep the server running, but log the full traceback
            # so connection/authentication failures are diagnosable.
            logging.getLogger(__name__).exception("Error in Gemini connection")

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield queued input chunks until ``shutdown`` sets the quit event.

        The 0.1 s timeout lets the loop re-check the quit flag even when no
        audio is arriving.
        """
        while not self.quit.is_set():
            try:
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
                yield audio
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue one incoming (sample_rate, samples) frame for Gemini.

        The sample rate is ignored; the frame is squeezed to drop singleton
        dimensions and base64-encoded from its raw bytes.
        """
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next queued output frame via fastrtc's ``wait_for_item``.

        NOTE(review): presumably returns None when nothing arrives in time —
        confirm against fastrtc.
        """
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal ``stream`` to stop, which ends the Gemini input stream."""
        self.quit.set()

# Hosted-deployment settings (TURN relay, concurrency cap, session time limit)
# are only enabled when gradio.utils.get_space() reports a Space deployment.
_ON_SPACE = get_space()

# Gemini prebuilt voices offered in the "Voice" dropdown.
_VOICE_CHOICES = ["Puck", "Charon", "Kore", "Fenrir", "Aoede"]

stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    rtc_configuration=get_twilio_turn_credentials() if _ON_SPACE else None,
    concurrency_limit=5 if _ON_SPACE else None,
    time_limit=90 if _ON_SPACE else None,
    # Order matters: GeminiHandler.start_up unpacks these positionally as
    # (api_key, voice_name, prompt_key, custom_prompt).
    additional_inputs=[
        gr.Textbox(
            label="API Key",
            type="password",
            # Prefill from the environment only outside a Space.
            value="" if _ON_SPACE else os.getenv("GEMINI_API_KEY"),
        ),
        gr.Dropdown(label="Voice", choices=_VOICE_CHOICES, value="Puck"),
        gr.Dropdown(
            label="Preset Prompt",
            choices=[key for key in SYSTEM_PROMPTS],
            value="Default",
        ),
        gr.Textbox(
            label="Custom Prompt",
            placeholder="Enter a custom system prompt (overrides preset if not empty)",
            value="",
        ),
    ],
)


class InputData(BaseModel):
    """JSON body accepted by the ``/input_hook`` endpoint.

    The fields are forwarded to ``stream.set_input`` for the WebRTC
    connection identified by ``webrtc_id``.
    """

    webrtc_id: str  # Id of the WebRTC connection these inputs belong to
    voice_name: str  # Gemini prebuilt voice name, e.g. "Puck"
    api_key: str  # Gemini API key; an empty value falls back to GEMINI_API_KEY
    prompt_key: str = ""  # Key into SYSTEM_PROMPTS; "" selects no preset
    custom_prompt: str = ""  # Free-text system prompt; preferred over the preset when given


# FastAPI application serving the custom front end ("/") and "/input_hook".
app = FastAPI()

# Attach the fastrtc stream's own endpoints to the same app (presumably the
# WebRTC signalling routes the front end talks to — see fastrtc docs).
stream.mount(app)


@app.post("/input_hook")
async def _(body: InputData):
    """Store the per-connection settings sent by the custom front end.

    Values are passed to ``stream.set_input`` after the connection id, in the
    order GeminiHandler.start_up unpacks them:
    api_key, voice_name, prompt_key, custom_prompt.
    """
    connection_args = (
        body.api_key,
        body.voice_name,
        body.prompt_key,
        body.custom_prompt,
    )
    stream.set_input(body.webrtc_id, *connection_args)
    return {"status": "ok"}


@app.get("/")
async def index():
    """Serve ``index.html`` with runtime configuration injected.

    ``__RTC_CONFIGURATION__`` is replaced with the JSON-encoded TURN
    configuration (``null`` when not running on a Space) and
    ``__SYSTEM_PROMPTS__`` with the JSON-encoded preset prompts.
    """
    # get_twilio_turn_credentials() and read_text() are synchronous (the TURN
    # fetch presumably makes a network request). Running them in a worker
    # thread keeps a page load from stalling the event loop, which is also
    # driving live audio sessions.
    rtc_config = (
        await asyncio.to_thread(get_twilio_turn_credentials) if get_space() else None
    )
    # Decode explicitly as UTF-8 instead of the platform's locale encoding.
    html_content = await asyncio.to_thread(
        (current_dir / "index.html").read_text, encoding="utf-8"
    )
    html_content = html_content.replace(
        "__RTC_CONFIGURATION__", json.dumps(rtc_config)
    )
    html_content = html_content.replace(
        "__SYSTEM_PROMPTS__", json.dumps(SYSTEM_PROMPTS)
    )
    return HTMLResponse(content=html_content)


if __name__ == "__main__":
    # The MODE environment variable picks the front end; every variant uses
    # port 7860. `os` is already imported at module level.
    mode = os.getenv("MODE")
    if mode == "UI":
        # Gradio's built-in UI for the stream.
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        # Phone access via fastrtc's fastphone().
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        # Default: serve the FastAPI app (custom index.html + /input_hook).
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)