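"""Streamlit demo for Mini-Omni voice chat.

Captures microphone audio over WebRTC, trims silence with voice activity
detection (VAD), posts the PCM16 bytes to the chat API, and plays back the
assistant's audio reply. Launch with: streamlit run <path to this file>
"""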
import os
import queue
import tempfile
import time
import traceback

import librosa
import numpy as np
import requests
import streamlit as st
from pydub import AudioSegment
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration

from utils.vad import get_speech_timestamps, collect_chunks, VadOptions

# Backend chat endpoint; override via the API_URL environment variable.
API_URL = os.getenv("API_URL", "http://127.0.0.1:60808/chat")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

def run_vad(audio, sr):
    """Trim non-speech from int16 PCM audio with voice activity detection.

    Returns (duration_after_vad_seconds, pcm16_bytes, elapsed_seconds);
    duration is -1 and the original audio is passed through if VAD fails.
    """
    _st = time.time()
    raw_audio = audio  # keep the untouched int16 buffer for the failure path
    try:
        audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_options = VadOptions()
        speech_chunks = get_speech_timestamps(audio, vad_options)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # resample back to the original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(raw_audio) / sr:.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        # Fall back to the original int16 bytes so callers always receive PCM16.
        return -1, raw_audio.tobytes(), round(time.time() - _st, 4)

def save_tmp_audio(audio_bytes):
    """Write 16 kHz mono PCM16 bytes to a temporary WAV file and return its path.

    The file is created with delete=False so st.audio can read it after the
    handle closes; the temp files are not cleaned up by this demo.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        file_name = tmpfile.name
        audio = AudioSegment(
            data=audio_bytes,
            sample_width=2,  # 16-bit samples
            frame_rate=16000,
            channels=1,
        )
        audio.export(file_name, format="wav")
        return file_name

def main():
    st.title("Chat Mini-Omni Demo")
    status = st.empty()  # reserved slot for status messages (currently unused)

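    # Raw int16 samples accumulated from incoming WebRTC audio frames.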
    if "audio_buffer" not in st.session_state:
        st.session_state.audio_buffer = []

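    # Microphone capture only (SENDONLY); a public Google STUN server handles ICE.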
    webrtc_ctx = webrtc_streamer(
        key="speech-to-text",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        rtc_configuration=RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        ),
        media_stream_constraints={"video": False, "audio": True},
    )

    if webrtc_ctx.audio_receiver:
        while True:
            try:
                audio_frame = webrtc_ctx.audio_receiver.get_frame(timeout=1)
            except queue.Empty:
                # No frame arrived within the timeout: the stream is idle or stopped.
                break

            try:
                sound_chunk = np.frombuffer(audio_frame.to_ndarray(), dtype="int16")
                st.session_state.audio_buffer.extend(sound_chunk)

                # Run VAD once at least one second of 16 kHz audio has accumulated.
                if len(st.session_state.audio_buffer) >= 16000:
                    duration_after_vad, vad_audio_bytes, vad_time = run_vad(
                        np.array(st.session_state.audio_buffer), 16000
                    )
                    st.session_state.audio_buffer = []
                    # Skip chunks in which VAD found no speech (or failed with -1).
                    if duration_after_vad > 0:
                        st.session_state.messages.append(
                            {"role": "user", "content": "User audio"}
                        )
                        file_name = save_tmp_audio(vad_audio_bytes)
                        st.audio(file_name, format="audio/wav")

                        # Send the trimmed PCM16 audio to the backend and play the reply.
                        response = requests.post(API_URL, data=vad_audio_bytes)
                        assistant_audio_bytes = response.content
                        assistant_file_name = save_tmp_audio(assistant_audio_bytes)
                        st.audio(assistant_file_name, format="audio/wav")
                        st.session_state.messages.append(
                            {"role": "assistant", "content": "Assistant response"}
                        )
            except Exception as e:
                print(f"Error in audio processing: {e}")
                break

    if st.button("Process Audio"):
        if st.session_state.audio_buffer:
            duration_after_vad, vad_audio_bytes, vad_time = run_vad(
                np.array(st.session_state.audio_buffer), 16000
            )
            st.session_state.audio_buffer = []
            # Only post to the backend if VAD kept some speech.
            if duration_after_vad > 0:
                st.session_state.messages.append({"role": "user", "content": "User audio"})
                file_name = save_tmp_audio(vad_audio_bytes)
                st.audio(file_name, format="audio/wav")

                response = requests.post(API_URL, data=vad_audio_bytes)
                assistant_audio_bytes = response.content
                assistant_file_name = save_tmp_audio(assistant_audio_bytes)
                st.audio(assistant_file_name, format="audio/wav")
                st.session_state.messages.append(
                    {"role": "assistant", "content": "Assistant response"}
                )

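    # Render the running chat transcript (placeholder text for each turn).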
    if st.session_state.messages:
        for message in st.session_state.messages:
            if message["role"] == "user":
                st.write(f"User: {message['content']}")
            else:
                st.write(f"Assistant: {message['content']}")

if __name__ == "__main__":
    main()