import base64
import io
import os
import tempfile
import time
import traceback
from dataclasses import dataclass
from queue import Queue
from threading import Thread

import gradio as gr
import librosa
import numpy as np
import requests
from gradio_webrtc import StreamHandler, WebRTC
from huggingface_hub import snapshot_download
from pydub import AudioSegment
from twilio.rest import Client

from server import serve
from utils.vad import VadOptions, collect_chunks, get_speech_timestamps

repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

IP = "0.0.0.0"
PORT = 60808

thread = Thread(target=serve, daemon=True)
thread.start()

API_URL = "http://0.0.0.0:60808/chat"
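
# The mini-omni inference server (serve, started above in a daemon thread) runs in the
# same process as this Gradio app; each chat request is sent to it over HTTP at API_URL.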

account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

if account_sid and auth_token:
    client = Client(account_sid, auth_token)

    token = client.tokens.create()

    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
else:
    rtc_configuration = None
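
# When Twilio credentials are set, ICE servers from a Twilio token are used and traffic is
# forced through a TURN relay ("iceTransportPolicy": "relay"); this is typically needed for
# WebRTC connections from hosted environments such as Hugging Face Spaces. Without the
# credentials the WebRTC component falls back to its default configuration.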

# recording parameters
IN_CHANNELS = 1
IN_RATE = 24000
IN_CHUNK = 1024
IN_SAMPLE_WIDTH = 2
VAD_STRIDE = 0.5

# playing parameters
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 20 * 4096


def run_vad(ori_audio, sr):
    """Apply VAD at 16 kHz and return (voiced duration in s, voiced-only int16 audio bytes
    at the input rate, elapsed time in s)."""
    _st = time.time()
    try:
        audio = ori_audio
        audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = {}
        vad_parameters = VadOptions(**vad_parameters)
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(ori_audio) / (sr * 2):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)


def warm_up():
    frames = np.zeros((1, 1600))  # 0.1 s of silence at 16 kHz
    _, frames, tcost = run_vad(frames, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


# warm_up()


@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
    responding: bool = False
    stopped: bool = False
    buffer: np.ndarray | None = None
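
# Pause-detection heuristic (determine_pause below): once at least 0.6 s of audio has been
# buffered, the user is considered to have started talking when the VAD finds more than
# 0.2 s of speech, and to have paused (end of turn) when the most recent buffer contains
# less than 0.1 s of speech.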


def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Take in the stream, determine if a pause happened"""
    duration = len(audio) / sampling_rate

    dur_vad, _, _ = run_vad(audio, sampling_rate)

    if duration >= 0.60:
        if dur_vad > 0.2 and not state.started_talking:
            print("started talking")
            state.started_talking = True
        if state.started_talking:
            if state.stream is None:
                state.stream = audio
            else:
                state.stream = np.concatenate((state.stream, audio))
            state.buffer = None
        if dur_vad < 0.1 and state.started_talking:
            segment = AudioSegment(
                state.stream.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio.dtype.itemsize,
                channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
            )

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                segment.export(f.name, format="wav")
                print("input file written", f.name)
            return True
    return False
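
# speaking() sends the recorded turn to the mini-omni /chat endpoint as base64-encoded audio
# and streams the reply back; the code treats the streamed bytes as raw 16-bit mono PCM at
# OUT_RATE and yields one numpy frame per received chunk.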


def speaking(audio_bytes: bytes):
    base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
    files = {"audio": base64_encoded}
    byte_buffer = b""
    with requests.post(API_URL, json=files, stream=True) as response:
        try:
            for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                if chunk:
                    byte_buffer += chunk
                    # Wrap the raw PCM chunk in an AudioSegment (pad to a whole number
                    # of 16-bit samples if needed)
                    audio_segment = AudioSegment(
                        chunk + b"\x00" if len(chunk) % 2 != 0 else chunk,
                        frame_rate=OUT_RATE,
                        sample_width=OUT_SAMPLE_WIDTH,
                        channels=OUT_CHANNELS,
                    )

                    # Convert the audio segment to a numpy array and yield it as one frame
                    audio_np = np.array(audio_segment.get_array_of_samples())
                    yield audio_np.reshape(1, -1)
            all_output_audio = AudioSegment(
                byte_buffer,
                frame_rate=OUT_RATE,
                sample_width=OUT_SAMPLE_WIDTH,
                channels=1,
            )
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                all_output_audio.export(f.name, format="wav")
                print("output file written", f.name)
        except Exception as e:
            raise gr.Error(f"Error during audio streaming: {e}")


def process_audio(audio: tuple, state: AppState) -> None:
    frame_rate, array = audio
    array = np.squeeze(array)
    if not state.sampling_rate:
        state.sampling_rate = frame_rate
    if state.buffer is None:
        state.buffer = array
    else:
        state.buffer = np.concatenate((state.buffer, array))

    pause_detected = determine_pause(state.buffer, state.sampling_rate, state)
    state.pause_detected = pause_detected


def response(state: AppState):
    if not state.pause_detected and not state.started_talking:
        return None

    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
    )
    segment.export(audio_buffer, format="wav")

    for numpy_array in speaking(audio_buffer.getvalue()):
        yield (OUT_RATE, numpy_array, "mono")
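
# OmniHandler wires the pieces into gradio_webrtc's StreamHandler interface: receive() gets
# each incoming microphone frame and accumulates it until a pause is detected, at which point
# it signals emit() through chunk_queue; emit() then drains the response() generator one
# output frame at a time and resets the handler state once the reply is finished.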


class OmniHandler(StreamHandler):
    def __init__(self) -> None:
        super().__init__(
            expected_layout="mono", output_sample_rate=OUT_RATE, output_frame_size=480
        )
        self.chunk_queue = Queue()
        self.state = AppState()
        self.generator = None
        self.duration = 0

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        if self.state.responding:
            return
        process_audio(frame, self.state)
        if self.state.pause_detected:
            self.chunk_queue.put(True)

    def reset(self):
        self.generator = None
        self.state = AppState()
        self.duration = 0

    def emit(self):
        if not self.generator:
            self.chunk_queue.get()
            self.state.responding = True
            self.generator = response(self.state)
        try:
            return next(self.generator)
        except StopIteration:
            self.reset()


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Omni Chat (Powered by WebRTC ⚡️)
        </h1>
        """
    )
    with gr.Column():
        with gr.Group():
            audio = WebRTC(
                label="Stream",
                rtc_configuration=rtc_configuration,
                mode="send-receive",
                modality="audio",
            )

        audio.stream(fn=OmniHandler(), inputs=[audio], outputs=[audio], time_limit=60)


demo.launch(ssr_mode=False)