Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os, time, re, json, base64, asyncio, threading, uuid, io | |
| import numpy as np | |
| import soundfile as sf | |
| from pydub import AudioSegment | |
| from openai import OpenAI | |
| from websockets import connect, Data, ClientConnection | |
| from dotenv import load_dotenv | |
| # ============ Load Secrets ============ | |
# Read secrets from a local .env file (if any) into the process environment.
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ASSISTANT_ID = os.getenv("ASSISTANT_ID")

# Build the SDK client only when a key is present: constructing it without one
# raises at import time, which would take the whole UI down before the
# "missing secrets" message in the chat handler could ever be shown.
client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

# Handshake headers and endpoint for the realtime transcription WebSocket.
HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
WS_URI = "wss://api.openai.com/v1/realtime?intent=transcription"

# Live transcription clients, keyed by the per-session id handed to the browser.
connections = {}
| # ============ WebSocket Client ============ | |
class WebSocketClient:
    """Stream microphone audio to the realtime transcription WebSocket.

    Each instance owns one connection and one asyncio event loop, which runs
    on whatever thread calls ``run``. Other threads hand audio over with
    ``enqueue_audio_chunk`` and read the accumulated text from ``transcript``.
    """

    def __init__(self, uri, headers, client_id):
        """Store connection parameters; nothing is opened until ``run``."""
        self.uri, self.headers, self.client_id = uri, headers, client_id
        self.websocket = None
        # Event loop created by run(); stays None until the worker starts.
        self.loop = None
        # Bounded so that audio is dropped rather than buffered without limit
        # when the sender falls behind.
        self.queue = asyncio.Queue(maxsize=10)
        self.transcript = ""

    async def connect(self):
        """Open the socket, send the session settings, then pump both ways.

        The settings payload is read verbatim from
        ``openai_transcription_settings.json`` in the working directory.
        Returns when either the receiver or the sender stops; the other task
        is cancelled and the socket is closed before returning. An exception
        raised by the finished task is re-raised.
        """
        self.websocket = await connect(self.uri, additional_headers=self.headers)
        try:
            with open("openai_transcription_settings.json", "r", encoding="utf-8") as f:
                await self.websocket.send(f.read())

            receiver = asyncio.ensure_future(self.receive_messages())
            sender = asyncio.ensure_future(self.send_audio_chunks())
            # The sender loops forever, so waiting for *both* tasks would hang
            # after the server closes the socket. Stop as soon as one ends.
            done, pending = await asyncio.wait(
                {receiver, sender}, return_when=asyncio.FIRST_COMPLETED
            )
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            for task in done:
                task.result()  # surface any exception from the finished task
        finally:
            await self.websocket.close()

    def run(self):
        """Create a private event loop and run ``connect`` on it (blocking)."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Published so that enqueue_audio_chunk() can target this exact loop.
        self.loop = loop
        try:
            loop.run_until_complete(self.connect())
        finally:
            loop.close()

    async def send_audio_chunks(self):
        """Forward queued ``(sample_rate, samples)`` chunks to the socket.

        Each chunk is down-mixed to mono, peak-normalised, converted to
        16-bit PCM, resampled to 24 kHz and sent base64-encoded in an
        ``input_audio_buffer.append`` event.
        """
        while True:
            sr, arr = await self.queue.get()
            # Work in float so integer input cannot overflow in abs().
            arr = np.asarray(arr, dtype=np.float64)
            if arr.size == 0:
                # np.max() raises on empty input, which would end this loop.
                continue
            if arr.ndim > 1:
                arr = arr.mean(axis=1)
            peak = np.max(np.abs(arr))
            if peak > 0:
                arr = arr / peak
            int16 = (arr * 32767).astype(np.int16)

            buf = io.BytesIO()
            sf.write(buf, int16, sr, format='WAV', subtype='PCM_16')
            audio = AudioSegment.from_file(buf, format="wav").set_frame_rate(24000)
            out = io.BytesIO()
            audio.export(out, format="wav")
            out.seek(0)
            # NOTE(review): the payload is a complete WAV file, header
            # included. Confirm the session's input format expects that
            # rather than headerless PCM.
            await self.websocket.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(out.read()).decode()
            }))

    async def receive_messages(self):
        """Append every transcription delta from the server to ``transcript``."""
        async for msg in self.websocket:
            data = json.loads(msg)
            if data.get("type") == "conversation.item.input_audio_transcription.delta":
                self.transcript += data.get("delta", "")

    def _offer(self, item):
        """Put ``item`` on the queue, dropping it when the queue is full."""
        try:
            self.queue.put_nowait(item)
        except asyncio.QueueFull:
            pass  # deliberate: newest audio is dropped when the sender lags

    def enqueue_audio_chunk(self, sr, arr):
        """Hand one audio chunk to the worker loop from any thread.

        The chunk is silently dropped when the worker loop has not started,
        has already shut down, or its queue is full.
        """
        loop = self.loop
        if loop is None or loop.is_closed():
            return
        try:
            # asyncio.Queue is not thread-safe, so the put must be executed on
            # the loop that consumes the queue, not on the caller's thread.
            loop.call_soon_threadsafe(self._offer, (sr, arr))
        except RuntimeError:
            pass  # loop was closed between the check above and this call
def create_ws():
    """Start a transcription client on a daemon thread and return its id.

    The new client is registered in ``connections`` under a fresh UUID string,
    which is the value returned to the caller.
    """
    session_id = str(uuid.uuid4())
    ws_client = WebSocketClient(WS_URI, HEADERS, session_id)
    worker = threading.Thread(target=ws_client.run, daemon=True)
    worker.start()
    connections[session_id] = ws_client
    return session_id
def send_audio(chunk, cid):
    """Queue one streamed microphone chunk and return the transcript so far.

    Args:
        chunk: ``(sample_rate, samples)`` pair, or ``None`` when the stream
            event carries no audio.
        cid: Session id previously returned by ``create_ws``.

    Returns:
        ``"Connecting..."`` while ``cid`` is not registered in ``connections``,
        otherwise the session's current transcript text.
    """
    if cid not in connections:
        return "Connecting..."
    ws_client = connections[cid]
    # A stream tick without audio must not crash on tuple unpacking; just
    # report whatever has been transcribed so far.
    if chunk is not None:
        sr, arr = chunk
        ws_client.enqueue_audio_chunk(sr, arr)
    return ws_client.transcript
def clear_transcript(cid):
    """Reset the stored transcript for ``cid`` and return an empty string.

    Unknown ids are ignored; the return value is always ``""``.
    """
    ws_client = connections.get(cid)
    if ws_client is not None:
        ws_client.transcript = ""
    return ""
| # ============ Chat Assistant ============ | |
def handle_chat(user_input, history, thread_id, image_url):
    """Send one user message to the assistant and collect its reply.

    Args:
        user_input: Text typed by the user.
        history: List of ``(user, assistant)`` tuples; appended to in place.
        thread_id: Existing thread id, or ``None`` to create a new thread.
        image_url: Most recent document image URL; replaced when the reply
            contains a matching ``.png`` link.

    Returns:
        ``(textbox_value, history, thread_id, image_url)``. The textbox value
        is ``""`` on success and an error string when something went wrong.
    """
    if not OPENAI_API_KEY or not ASSISTANT_ID:
        return "β Missing secrets!", history, thread_id, image_url
    # Nothing to send: leave the conversation untouched.
    if not user_input or not user_input.strip():
        return "", history, thread_id, image_url

    # Statuses from which a run can never reach "completed" on its own.
    # "requires_action" is included because no tool outputs are submitted here.
    dead_end_statuses = {"failed", "cancelled", "expired", "incomplete", "requires_action"}
    max_wait_seconds = 120

    try:
        if thread_id is None:
            thread_id = client.beta.threads.create().id
        client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
        run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)

        deadline = time.monotonic() + max_wait_seconds
        while True:
            status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
            if status.status == "completed":
                break
            if status.status in dead_end_statuses:
                raise RuntimeError(f"Assistant run ended with status '{status.status}'")
            if time.monotonic() >= deadline:
                raise TimeoutError(f"Assistant run did not finish within {max_wait_seconds}s")
            time.sleep(1)

        # Ask for newest-first explicitly so the first assistant message seen
        # is the reply to *this* turn, not the first reply in the thread.
        msgs = client.beta.threads.messages.list(thread_id=thread_id, order="desc")
        for msg in msgs.data:
            if msg.role != "assistant":
                continue
            # A reply may contain non-text parts; use the first part with text.
            content = next(
                (part.text.value for part in msg.content
                 if getattr(part, "text", None) is not None),
                "",
            )
            history.append((user_input, content))
            match = re.search(
                r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
                content
            )
            if match:
                image_url = match.group(0)
            break
        return "", history, thread_id, image_url
    except Exception as e:
        # UI boundary: show the failure in the textbox instead of crashing.
        return f"β {e}", history, thread_id, image_url
| # ============ Gradio UI ============ | |
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened. Confirm that the voice accordion belongs in
# the right-hand column.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# π Document AI Assistant")

    # STATES (per browser session)
    chat_state = gr.State([])
    thread_state = gr.State()
    image_state = gr.State()
    client_id = gr.State()
    voice_enabled = gr.State(False)

    with gr.Row(equal_height=True):
        # Column scales are integers; 5:7 keeps the original 1:1.4 ratio.
        with gr.Column(scale=5):
            image_display = gr.Image(label="πΌοΈ Document", type="filepath", show_download_button=False)
        with gr.Column(scale=7):
            chat = gr.Chatbot(label="π¬ Chat", height=460)
            with gr.Row():
                user_prompt = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=6)
                mic_toggle_btn = gr.Button("ποΈ", scale=1)
                send_btn = gr.Button("Send", variant="primary", scale=2)
            # Starts hidden to agree with voice_enabled=False, so the first
            # click on the mic button actually reveals the section.
            with gr.Accordion("π€ Voice Transcription", open=False, visible=False) as voice_section:
                with gr.Row():
                    voice_input = gr.Audio(label="Mic", streaming=True)
                    voice_transcript = gr.Textbox(label="Transcript", lines=2, interactive=False)
                clear_btn = gr.Button("π§Ή Clear Transcript")

    # FUNCTIONAL CONNECTIONS
    def toggle_voice(curr):
        """Flip the voice flag and show or hide the voice section to match."""
        return not curr, gr.update(visible=not curr)

    mic_toggle_btn.click(fn=toggle_voice, inputs=voice_enabled, outputs=[voice_enabled, voice_section])
    send_btn.click(fn=handle_chat,
                   inputs=[user_prompt, chat_state, thread_state, image_state],
                   outputs=[user_prompt, chat, thread_state, image_state])
    # Mirror the stored image URL into the visible image component.
    image_state.change(fn=lambda x: x, inputs=image_state, outputs=image_display)
    voice_input.stream(fn=send_audio, inputs=[voice_input, client_id], outputs=voice_transcript, stream_every=0.5)
    clear_btn.click(fn=clear_transcript, inputs=[client_id], outputs=voice_transcript)
    # Each page load gets its own transcription connection.
    app.load(fn=create_ws, outputs=[client_id])

app.launch()