import base64
import os
import queue
import tempfile
import time
import traceback

import av
import librosa
import numpy as np
import requests
import streamlit as st
from pydub import AudioSegment
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration

from utils.vad import get_speech_timestamps, collect_chunks, VadOptions

# Chat endpoint; can be overridden with the API_URL environment variable.
API_URL = os.environ.get("API_URL", "http://127.0.0.1:60808/chat")

# Chat history kept in Streamlit session state: a list of
# {"role": ..., "content": ...} dicts, created only if not already present.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

def run_vad(audio, sr):
    """Strip non-speech from a PCM buffer using the project's VAD.

    The input is scaled from the int16 range to float32 in [-1, 1), resampled
    to 16 kHz for the VAD if ``sr`` differs, reduced to the detected speech
    chunks, then resampled back to ``sr`` and re-quantised to int16.

    Args:
        audio: 1-D array of PCM samples in the int16 range.
        sr: Sample rate of ``audio`` in Hz.

    Returns:
        Tuple ``(duration_after_vad, vad_audio_bytes, elapsed_s)``:
        ``duration_after_vad`` is the seconds of speech kept (0.0 if none);
        ``vad_audio_bytes`` is that speech as native-endian int16 PCM at
        ``sr``; ``elapsed_s`` is wall-clock time spent, rounded to 4 decimals.
        On any error the duration is -1 and the bytes are the *input* audio
        as int16 PCM; the traceback is printed.
    """
    _st = time.time()
    # Keep a handle on the untouched input: the error path must report its
    # length and return *its* bytes, not a float32/resampled intermediate.
    raw = np.asarray(audio)
    try:
        sampling_rate = 16000  # rate the VAD runs at
        samples = raw.astype(np.float32) / 32768.0
        if sr != sampling_rate:
            samples = librosa.resample(samples, orig_sr=sr, target_sr=sampling_rate)

        speech_chunks = get_speech_timestamps(samples, VadOptions())
        speech = collect_chunks(samples, speech_chunks)
        duration_after_vad = speech.shape[0] / sampling_rate

        if speech.shape[0] == 0:
            # No speech: nothing to resample back or encode.
            return duration_after_vad, b"", round(time.time() - _st, 4)

        if sr != sampling_rate:
            # resample to original sampling rate
            speech = librosa.resample(speech, orig_sr=sampling_rate, target_sr=sr)
        # Resampling can overshoot [-1, 1); clip so the int16 cast saturates
        # instead of wrapping around to the opposite sign.
        pcm = np.clip(np.round(speech * 32768.0), -32768, 32767).astype(np.int16)
        return duration_after_vad, pcm.tobytes(), round(time.time() - _st, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(raw)/(sr):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, raw.astype(np.int16).tobytes(), round(time.time() - _st, 4)

def save_tmp_audio(audio_bytes, sample_rate=16000):
    """Write raw mono 16-bit PCM to a new temporary ``.wav`` file.

    The file is created with ``delete=False`` and is not removed here on
    success; the caller owns it.

    Args:
        audio_bytes: Raw mono PCM, 2 bytes per sample.
        sample_rate: Sample rate of ``audio_bytes`` in Hz. Defaults to 16000,
            the rate previously hard-coded here.

    Returns:
        Path of the written WAV file.

    Raises:
        Whatever pydub raises for unusable input or a failed export; on an
        export failure the partially written temp file is removed first.
    """
    # Build the segment before creating the temp file, so a construction
    # error does not leave an empty file behind.
    segment = AudioSegment(
        data=audio_bytes,
        sample_width=2,
        frame_rate=sample_rate,
        channels=1,
    )
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        try:
            # Write through the handle already held instead of having pydub
            # reopen the path while this handle is still open.
            segment.export(tmpfile, format="wav")
        except Exception:
            tmpfile.close()
            os.remove(tmpfile.name)
            raise
    return tmpfile.name

# Sample rate of all audio this page buffers, sends and saves.
_SAMPLE_RATE = 16000
# Run VAD once this many samples (~1 s at _SAMPLE_RATE) have been buffered.
_CHUNK_SAMPLES = _SAMPLE_RATE
# requests (connect, read) timeout for the chat API, so a stalled server
# cannot hang the Streamlit script forever. NOTE(review): the read timeout is
# a guess; tune it to the server's real response latency.
_REQUEST_TIMEOUT = (10, 300)


def _frame_to_mono_16k(audio_frame):
    """Return one WebRTC audio frame as a 1-D int16 array of 16 kHz mono samples."""
    # The frame's real rate/channel count is read from the frame instead of
    # being assumed (browser audio is presumably 48 kHz stereo; confirm).
    # NOTE(review): assumes a packed (interleaved) integer sample format such
    # as s16; planar or float frames would need different handling.
    segment = AudioSegment(
        data=audio_frame.to_ndarray().tobytes(),
        sample_width=audio_frame.format.bytes,
        frame_rate=audio_frame.sample_rate,
        channels=len(audio_frame.layout.channels),
    )
    segment = segment.set_channels(1).set_frame_rate(_SAMPLE_RATE).set_sample_width(2)
    return np.array(segment.get_array_of_samples(), dtype=np.int16)


def _request_assistant_audio(user_audio_bytes):
    """POST raw user PCM to API_URL and return the response body bytes.

    Raises:
        requests.RequestException: on connection errors, timeouts, or an
            HTTP error status.
    """
    # NOTE(review): this sends raw int16 PCM; ``base64`` is imported at file
    # top but unused. Confirm which payload format the server at API_URL expects.
    response = requests.post(API_URL, data=user_audio_bytes, timeout=_REQUEST_TIMEOUT)
    # Fail loudly instead of saving an HTTP error page as "audio".
    response.raise_for_status()
    return response.content


def _handle_utterance(samples):
    """Run VAD on buffered 16 kHz samples and, if speech remains, chat with the API.

    Plays the trimmed user audio and the assistant's reply, and appends both
    turns to ``st.session_state.messages``.

    Returns:
        True if speech was found and sent; False if VAD failed or found none.

    Raises:
        requests.RequestException: if the chat API call fails.
    """
    duration_after_vad, vad_audio_bytes, _vad_time = run_vad(
        np.asarray(samples, dtype=np.int16), _SAMPLE_RATE
    )
    # run_vad returns -1 on failure and 0 when no speech was kept.
    if duration_after_vad <= 0:
        return False

    st.session_state.messages.append({"role": "user", "content": "User audio"})
    st.audio(save_tmp_audio(vad_audio_bytes), format="audio/wav")

    assistant_audio_bytes = _request_assistant_audio(vad_audio_bytes)
    # NOTE(review): assumes the reply is raw 16 kHz mono int16 PCM, which is
    # what save_tmp_audio encodes by default. Confirm against the server.
    st.audio(save_tmp_audio(assistant_audio_bytes), format="audio/wav")
    st.session_state.messages.append(
        {"role": "assistant", "content": "Assistant response"}
    )
    return True


def main():
    """Render the Mini-Omni voice chat page.

    Streams microphone audio from the browser over WebRTC and buffers it as
    16 kHz mono int16 samples in ``st.session_state.audio_buffer``. Each ~1 s
    chunk that contains speech is sent to the chat API. A "Process Audio"
    button flushes whatever is left in the buffer. The chat log is rendered
    last.
    """
    st.title("Chat Mini-Omni Demo")
    status = st.empty()  # placeholder for error / info messages

    if "audio_buffer" not in st.session_state:
        st.session_state.audio_buffer = []

    webrtc_ctx = webrtc_streamer(
        key="speech-to-text",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        rtc_configuration=RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        ),
        media_stream_constraints={"video": False, "audio": True},
    )

    if webrtc_ctx.audio_receiver:
        # NOTE(review): each ~1 s chunk is sent as its own request, so longer
        # utterances are split across several turns; consider buffering
        # until VAD reports the end of speech.
        while True:
            try:
                audio_frame = webrtc_ctx.audio_receiver.get_frame(timeout=1)
                buffer = st.session_state.audio_buffer
                buffer.extend(_frame_to_mono_16k(audio_frame).tolist())
                if len(buffer) >= _CHUNK_SAMPLES:
                    # Reset first so a failed request does not resend this chunk.
                    st.session_state.audio_buffer = []
                    _handle_utterance(buffer)
            except queue.Empty:
                # No audio within the timeout: stop polling so the rest of
                # the page renders.
                break
            except Exception as e:
                print(f"Error in audio processing: {e}")
                status.error(f"Error in audio processing: {e}")
                break

    if st.button("Process Audio") and st.session_state.audio_buffer:
        try:
            sent = _handle_utterance(st.session_state.audio_buffer)
        except requests.RequestException as e:
            # Keep the buffer so the user can retry.
            status.error(f"Chat API request failed: {e}")
        else:
            st.session_state.audio_buffer = []
            if not sent:
                status.info("No speech detected in the buffered audio.")

    for message in st.session_state.messages:
        speaker = "User" if message["role"] == "user" else "Assistant"
        st.write(f"{speaker}: {message['content']}")

# Script entry point: build the page when this file is executed directly.
if __name__ == "__main__":
    main()