omni-docker / webui /omni_streamlit.py
victor's picture
victor HF Staff
feat: Update Dockerfile and requirements.txt to resolve PyAudio build issues
9616027
import streamlit as st
import numpy as np
import requests
import base64
import tempfile
import os
import time
import traceback
import librosa
from pydub import AudioSegment
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
import av
from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
# Chat backend endpoint; the API_URL environment variable overrides the
# local default.
API_URL = os.environ.get("API_URL", "http://127.0.0.1:60808/chat")

# Create the chat history list the first time this session runs the script.
if "messages" not in st.session_state:
    st.session_state.messages = []
def run_vad(audio, sr):
    """Remove non-speech from PCM audio using the project's VAD helpers.

    The samples are scaled to float32 in [-1, 1), resampled to 16 kHz when
    ``sr`` differs, passed through ``get_speech_timestamps`` /
    ``collect_chunks``, and resampled back to ``sr`` before being encoded
    as int16 PCM.

    Args:
        audio: Array of PCM samples in the int16 value range.
        sr: Sampling rate of ``audio`` in Hz.

    Returns:
        A ``(duration_after_vad, audio_bytes, elapsed)`` tuple.
        On success: seconds of audio kept by the VAD, the kept audio as
        int16 PCM bytes at ``sr``, and the processing time in seconds
        (rounded to 4 decimals).
        On any failure: ``-1``, the *unmodified input* as int16 PCM bytes,
        and the processing time. The traceback is printed.
    """
    _st = time.time()
    # Keep a handle on the untouched input: the error path must hand back
    # int16 PCM, not whatever intermediate float array existed when the
    # failure happened.
    source = np.asarray(audio)
    try:
        sampling_rate = 16000
        samples = source.astype(np.float32) / 32768.0
        if sr != sampling_rate:
            samples = librosa.resample(samples, orig_sr=sr, target_sr=sampling_rate)

        speech_chunks = get_speech_timestamps(samples, VadOptions())
        samples = collect_chunks(samples, speech_chunks)
        duration_after_vad = samples.shape[0] / sampling_rate

        if sr != sampling_rate:
            # Resample back to the caller's original sampling rate.
            samples = librosa.resample(samples, orig_sr=sampling_rate, target_sr=sr)

        # Clip before the int16 cast: values at or beyond full scale would
        # otherwise wrap around instead of saturating.
        pcm = np.clip(np.round(samples * 32768.0), -32768, 32767).astype(np.int16)
        return duration_after_vad, pcm.tobytes(), round(time.time() - _st, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(source)/(sr):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        fallback = source.astype(np.int16, copy=False)
        return -1, fallback.tobytes(), round(time.time() - _st, 4)
def save_tmp_audio(audio_bytes, sample_width=2, frame_rate=16000, channels=1):
    """Write raw PCM bytes to a new temporary ``.wav`` file.

    Args:
        audio_bytes: Raw PCM sample data.
        sample_width: Bytes per sample (default 2, i.e. 16-bit).
        frame_rate: Sampling rate in Hz (default 16000).
        channels: Number of interleaved channels (default 1).

    Returns:
        Path of the WAV file. The file is created with ``delete=False``,
        so it persists until the caller removes it.

    Raises:
        Whatever ``AudioSegment`` / ``AudioSegment.export`` raise; when the
        export fails, the temporary file is removed before re-raising.
    """
    # Build the segment first so invalid PCM data fails before a temp file
    # is created.
    segment = AudioSegment(
        data=audio_bytes,
        sample_width=sample_width,
        frame_rate=frame_rate,
        channels=channels,
    )

    # Only the unique path is needed; close our handle before exporting,
    # because re-opening a NamedTemporaryFile by name while it is still
    # open is not portable across platforms.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        file_name = tmpfile.name

    try:
        # export() returns the file object it wrote to; close it so the
        # handle is not left for the garbage collector.
        segment.export(file_name, format="wav").close()
    except Exception:
        os.remove(file_name)
        raise
    return file_name
# Seconds to wait for the chat backend before giving up on a request.
_REQUEST_TIMEOUT_S = 120


def _exchange_audio(vad_audio_bytes):
    """Play the user's audio, post it to the chat API, and play the reply.

    Appends a "user" and an "assistant" entry to the session chat history.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            API answers with an HTTP error status.
    """
    st.session_state.messages.append({"role": "user", "content": "User audio"})
    st.audio(save_tmp_audio(vad_audio_bytes), format="audio/wav")

    response = requests.post(
        API_URL, data=vad_audio_bytes, timeout=_REQUEST_TIMEOUT_S
    )
    # Do not treat an HTTP error page as audio data.
    response.raise_for_status()

    st.audio(save_tmp_audio(response.content), format="audio/wav")
    st.session_state.messages.append(
        {"role": "assistant", "content": "Assistant response"}
    )


def main():
    """Render the demo page: capture microphone audio and chat with the API.

    Audio frames received over WebRTC are accumulated in
    ``st.session_state.audio_buffer``. Once at least 16000 samples are
    buffered they are run through ``run_vad`` and, when speech remains,
    exchanged with the chat API. The "Process Audio" button does the same
    for whatever is currently buffered. The chat history is written out at
    the end.
    """
    st.title("Chat Mini-Omni Demo")
    status = st.empty()  # placeholder used for warning/error messages

    if "audio_buffer" not in st.session_state:
        st.session_state.audio_buffer = []

    webrtc_ctx = webrtc_streamer(
        key="speech-to-text",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        rtc_configuration=RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        ),
        media_stream_constraints={"video": False, "audio": True},
    )

    if webrtc_ctx.audio_receiver:
        while True:
            try:
                audio_frame = webrtc_ctx.audio_receiver.get_frame(timeout=1)
                # NOTE(review): frames are buffered as-is and later treated
                # as 16 kHz mono int16; neither the frame's sample rate nor
                # its channel layout is checked here - TODO confirm against
                # what the browser actually delivers.
                sound_chunk = np.frombuffer(audio_frame.to_ndarray(), dtype="int16")
                st.session_state.audio_buffer.extend(sound_chunk)

                if len(st.session_state.audio_buffer) >= 16000:
                    duration_after_vad, vad_audio_bytes, _ = run_vad(
                        np.array(st.session_state.audio_buffer, dtype=np.int16),
                        16000,
                    )
                    st.session_state.audio_buffer = []
                    if duration_after_vad > 0:
                        _exchange_audio(vad_audio_bytes)
            except Exception as e:
                # Any failure (including a get_frame timeout) ends the
                # receive loop so the rest of the page can render.
                print(f"Error in audio processing: {e}")
                break

    if st.button("Process Audio") and st.session_state.audio_buffer:
        duration_after_vad, vad_audio_bytes, _ = run_vad(
            np.array(st.session_state.audio_buffer, dtype=np.int16), 16000
        )
        st.session_state.audio_buffer = []
        # Same guard as the streaming path: only contact the API when the
        # VAD succeeded and kept some speech.
        if duration_after_vad > 0:
            try:
                _exchange_audio(vad_audio_bytes)
            except requests.RequestException as exc:
                status.error(f"Chat API request failed: {exc}")
        else:
            status.warning("No speech detected in the recorded audio.")

    for message in st.session_state.messages:
        speaker = "User" if message["role"] == "user" else "Assistant"
        st.write(f"{speaker}: {message['content']}")


if __name__ == "__main__":
    main()