# Hugging Face Spaces page artifact (status text captured with the source):
# Spaces: Sleeping / Sleeping
import asyncio | |
import base64 | |
import json | |
import os | |
import pathlib | |
from typing import AsyncGenerator, Literal, Optional, List, Union | |
import gradio as gr | |
import numpy as np | |
from dotenv import load_dotenv | |
from fastapi import FastAPI | |
from fastapi.responses import HTMLResponse | |
from fastrtc import ( | |
AsyncStreamHandler, | |
Stream, | |
get_twilio_turn_credentials, | |
wait_for_item, | |
) | |
from google import genai | |
from google.genai.types import ( | |
LiveConnectConfig, | |
PrebuiltVoiceConfig, | |
SpeechConfig, | |
VoiceConfig, | |
) | |
from gradio.utils import get_space | |
from pydantic import BaseModel | |
# Directory containing this file; used below to locate the bundled index.html.
current_dir = pathlib.Path(__file__).parent

# Load environment variables (e.g. GEMINI_API_KEY) from a local .env file.
load_dotenv()
# Define preset system prompts with more detailed and actionable instructions.
# Keys are the labels shown in the "Preset Prompt" UI dropdown; values are the
# system instructions sent verbatim to the Gemini live session.
SYSTEM_PROMPTS = {
    "Default": "You are a helpful, harmless, and honest AI assistant. Provide clear, concise, and accurate information.",
    "Behavior Expert": "You are an expert in behavior analysis, Positive Behavior Interventions and Supports, and Restorative approaches and practices. Your primary role is to help educators understand and address student behavior issues. Provide evidence-based strategies, empathetic insights, and practical solutions.",
    "UDL Expert": "You are an expert in Universal Design for Learning (UDL). Your role is to help educators design units, lessons, and activities to be universally accessible. Provide detailed guidance on addressing student variability and potential learning barriers.",
    # Fixed typo: "devise ane implement" -> "devise and implement".
    "Learning Support Expert": "You are a learning support expert. Your role is to help educators adapt and support instruction based on student needs such as dyscalculia, dyslexia, executive functioning issues, low English proficiency, or other learning needs and differences. You also help devise and implement IEPs, as well as accommodations."
}
def encode_audio(data: np.ndarray) -> str:
    """Serialize raw audio samples into a base64 ASCII string for transport."""
    raw = data.tobytes()
    return base64.b64encode(raw).decode("UTF-8")
class GeminiHandler(AsyncStreamHandler):
    """Handler for the Gemini API.

    Bridges fastrtc's WebRTC audio stream to a Gemini live session:
    microphone frames are base64-encoded and queued for Gemini, and audio
    returned by Gemini is decoded to int16 PCM and queued for playback.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,  # incoming mic audio is resampled to 16 kHz
        )
        # Base64-encoded mic chunks waiting to be sent to Gemini (see receive()).
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # (sample_rate, int16 ndarray) tuples ready for playback (see emit()).
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown(); terminates the stream() generator loop.
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler instance (fastrtc clones one per connection)."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def prepare_system_instruction(
        self,
        prompt_key: Optional[str] = None,
        custom_prompt: Optional[str] = None
    ) -> Optional[str]:
        """
        Prepare system instruction based on preset or custom prompt.

        A non-empty custom prompt always wins over a preset selection.

        Args:
            prompt_key: Key for preset system prompt (looked up in SYSTEM_PROMPTS)
            custom_prompt: Custom user-defined system prompt

        Returns:
            System instruction as a string or None
        """
        # Check for custom prompt first
        if custom_prompt:
            return custom_prompt
        # Then check for preset prompt
        if prompt_key and prompt_key in SYSTEM_PROMPTS:
            return SYSTEM_PROMPTS[prompt_key]
        return None

    async def start_up(self):
        """Open the Gemini live session and pump its audio into output_queue."""
        if not self.phone_mode:
            # Web UI: wait for the frontend to post its settings; latest_args[0]
            # is the WebRTC value, the rest follow the additional_inputs order
            # (api_key, voice_name, prompt_key, custom_prompt).
            await self.wait_for_args()
            api_key, voice_name, prompt_key, custom_prompt = self.latest_args[1:]
            # Prepare system instruction
            system_instruction = self.prepare_system_instruction(prompt_key, custom_prompt)
        else:
            # Phone mode has no UI inputs: env API key, default voice, no prompt.
            api_key, voice_name = None, "Puck"
            system_instruction = None
        client = genai.Client(
            api_key=api_key or os.getenv("GEMINI_API_KEY"),
            http_options={"api_version": "v1alpha"},  # live API requires v1alpha
        )
        # Create config for the connection
        config_kwargs = {
            "response_modalities": ["AUDIO"],
            "speech_config": SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            )
        }
        # Add system instruction if available
        if system_instruction:
            config_kwargs["system_instruction"] = {
                "parts": [{"text": system_instruction}]
            }
        # Create the configuration
        config = LiveConnectConfig(**config_kwargs)
        try:
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session:
                # NOTE(review): start_stream is part of the experimental live
                # API surface — confirm it matches the installed google-genai
                # version.  It consumes self.stream() and yields audio parts.
                async for audio in session.start_stream(
                    stream=self.stream(), mime_type="audio/pcm"
                ):
                    if audio.data:
                        # Gemini returns raw 16-bit PCM bytes.
                        array = np.frombuffer(audio.data, dtype=np.int16)
                        self.output_queue.put_nowait((self.output_sample_rate, array))
        except Exception as e:
            # Top-level boundary for the session task: log and let fastrtc
            # handle the closed stream.
            print(f"Error in Gemini connection: {e}")

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield queued input audio until shutdown() sets the quit event.

        NOTE(review): items are base64 strings from encode_audio(), not bytes —
        the annotation is looser than the runtime type; verify downstream use.
        """
        while not self.quit.is_set():
            try:
                # Short timeout so the quit flag is re-checked ~10x/sec.
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
                yield audio
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue one (sample_rate, samples) mic frame for the Gemini session."""
        _, array = frame
        array = array.squeeze()  # drop the channel axis (mono layout)
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next playback chunk, or None when the queue is closed."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal stream() to stop feeding audio to the Gemini session."""
        self.quit.set()
# Module-level fastrtc stream wiring GeminiHandler to bidirectional WebRTC audio.
# The additional_inputs order is load-bearing: GeminiHandler.start_up unpacks
# latest_args[1:] as (api_key, voice_name, prompt_key, custom_prompt).
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    # TURN relay is only needed when hosted on Spaces; locally direct P2P works.
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,  # cap sessions on Spaces
    time_limit=90 if get_space() else None,  # seconds per session on Spaces
    additional_inputs=[
        gr.Textbox(
            label="API Key",
            type="password",
            # Pre-fill from the environment only when running locally.
            value=os.getenv("GEMINI_API_KEY") if not get_space() else "",
        ),
        gr.Dropdown(
            label="Voice",
            choices=[
                "Puck",
                "Charon",
                "Kore",
                "Fenrir",
                "Aoede",
            ],
            value="Puck",
        ),
        gr.Dropdown(
            label="Preset Prompt",
            choices=list(SYSTEM_PROMPTS.keys()),
            value="Default",
        ),
        gr.Textbox(
            label="Custom Prompt",
            placeholder="Enter a custom system prompt (overrides preset if not empty)",
            value="",
        ),
    ],
)
class InputData(BaseModel):
    """Payload posted by the frontend to configure a WebRTC session."""

    webrtc_id: str
    voice_name: str
    api_key: str
    # Optional prompt selection; empty string means "not provided".
    prompt_key: str = ""
    custom_prompt: str = ""
# FastAPI app serving the custom frontend; fastrtc mounts its WebRTC
# signalling routes onto it.
app = FastAPI()
stream.mount(app)
@app.post("/input_hook")
async def _(body: InputData):
    """Receive session settings from the frontend and forward them to fastrtc.

    Without the route decorator this function was defined but never registered,
    so the endpoint was unreachable.  The positional order after webrtc_id must
    match what GeminiHandler.start_up unpacks from latest_args[1:]:
    (api_key, voice_name, prompt_key, custom_prompt).

    Returns:
        A small status payload acknowledging receipt.
    """
    stream.set_input(
        body.webrtc_id,
        body.api_key,
        body.voice_name,
        body.prompt_key,
        body.custom_prompt,
    )
    return {"status": "ok"}
@app.get("/")
async def index():
    """Serve the frontend page with runtime config injected.

    Without the route decorator this function was defined but never registered,
    so the page was unreachable.  Replaces the __RTC_CONFIGURATION__ and
    __SYSTEM_PROMPTS__ placeholders in index.html with JSON values.

    Returns:
        HTMLResponse with the rendered index.html.
    """
    # TURN credentials only on Spaces, mirroring the Stream() configuration.
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = (current_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    html_content = html_content.replace("__SYSTEM_PROMPTS__", json.dumps(SYSTEM_PROMPTS))
    return HTMLResponse(content=html_content)
if __name__ == "__main__":
    # `os` is already imported at module level; the redundant local
    # `import os` was removed.
    if (mode := os.getenv("MODE")) == "UI":
        # Gradio-only UI (no custom frontend).
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        # Telephone interface provided by fastrtc.
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        # Default: serve the FastAPI app (custom frontend + WebRTC routes).
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)