jeremierostan's picture
Update app.py
80abd37 verified
raw
history blame
8.14 kB
import asyncio
import base64
import json
import os
import pathlib
from typing import AsyncGenerator, Literal, Optional, List, Union
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import (
AsyncStreamHandler,
Stream,
get_twilio_turn_credentials,
wait_for_item,
)
from google import genai
from google.genai.types import (
LiveConnectConfig,
PrebuiltVoiceConfig,
SpeechConfig,
VoiceConfig,
)
from gradio.utils import get_space
from pydantic import BaseModel
# Directory containing this module; index.html is read from here when serving "/".
current_dir = pathlib.Path(__file__).parent
# Load environment variables from a .env file, if any, before os.getenv is
# consulted below (GEMINI_API_KEY, MODE).
load_dotenv()
# Preset system prompts, keyed by the label offered in the "Preset Prompt"
# dropdown. The selected text is sent to Gemini as the system instruction and
# the whole mapping is embedded into index.html as JSON.
SYSTEM_PROMPTS = {
    "Default": "You are a helpful, harmless, and honest AI assistant. Provide clear, concise, and accurate information.",
    "Behavior Expert": "You are an expert in behavior analysis, Positive Behavior Interventions and Supports, and Restorative approaches and practices. Your primary role is to help educators understand and address student behavior issues. Provide evidence-based strategies, empathetic insights, and practical solutions.",
    "UDL Expert": "You are an expert in Universal Design for Learning (UDL). Your role is to help educators design units, lessons, and activities to be universally accessible. Provide detailed guidance on addressing student variability and potential learning barriers.",
    "Learning Support Expert": "You are a learning support expert. Your role is to help educators adapt and support instruction based on student needs such as dyscalculia, dyslexia, executive functioning issues, low English proficiency, or other learning needs and differences. You also help devise and implement IEPs, as well as accommodations.",
}
def encode_audio(data: np.ndarray) -> str:
    """Return the array's raw bytes as base64 text, ready to send to the server."""
    raw_bytes = data.tobytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("UTF-8")
class GeminiHandler(AsyncStreamHandler):
    """Stream handler that relays audio between a client and a Gemini live session.

    Frames received from the client are base64-encoded and queued for upload;
    audio chunks returned by Gemini are queued for playback via ``emit``.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        """Configure the base handler and create the per-instance queues.

        Args:
            expected_layout: Channel layout passed to the base handler.
            output_sample_rate: Sample rate attached to emitted audio.
            output_frame_size: Output frame size passed to the base handler.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        # Encoded client audio waiting to be sent to Gemini.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # (sample_rate, samples) tuples waiting to be returned by emit().
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown() to end the upload generator.
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with the same output settings and its own queues."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def prepare_system_instruction(
        self,
        prompt_key: Optional[str] = None,
        custom_prompt: Optional[str] = None,
    ) -> Optional[str]:
        """Choose the system instruction from a custom or preset prompt.

        A custom prompt takes precedence over the preset, but only when it
        contains non-whitespace text; a box holding just spaces or newlines is
        treated as empty so that it does not replace the preset with a blank
        instruction.

        Args:
            prompt_key: Key of a preset in ``SYSTEM_PROMPTS``.
            custom_prompt: User-supplied system prompt.

        Returns:
            The custom prompt with surrounding whitespace removed, else the
            preset text for ``prompt_key``, else ``None`` when neither applies.
        """
        custom_text = custom_prompt.strip() if custom_prompt else ""
        if custom_text:
            return custom_text
        if prompt_key and prompt_key in SYSTEM_PROMPTS:
            return SYSTEM_PROMPTS[prompt_key]
        return None

    async def start_up(self):
        """Open the Gemini live session and queue the audio it returns.

        Outside phone mode the handler waits for the client's arguments and
        unpacks API key, voice, preset key and custom prompt from
        ``latest_args[1:]``. In phone mode it uses the ``GEMINI_API_KEY``
        environment variable, the "Puck" voice and no system instruction.
        Exceptions raised while connected are printed, not propagated.
        """
        if not self.phone_mode:
            await self.wait_for_args()
            api_key, voice_name, prompt_key, custom_prompt = self.latest_args[1:]
            # Prepare system instruction
            system_instruction = self.prepare_system_instruction(
                prompt_key, custom_prompt
            )
        else:
            api_key, voice_name = None, "Puck"
            system_instruction = None

        # An empty or missing key falls back to the environment variable.
        client = genai.Client(
            api_key=api_key or os.getenv("GEMINI_API_KEY"),
            http_options={"api_version": "v1alpha"},
        )

        # Create config for the connection
        config_kwargs = {
            "response_modalities": ["AUDIO"],
            "speech_config": SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
        }
        # Add system instruction if available
        if system_instruction:
            config_kwargs["system_instruction"] = {
                "parts": [{"text": system_instruction}]
            }
        # Create the configuration
        config = LiveConnectConfig(**config_kwargs)

        try:
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session:
                async for audio in session.start_stream(
                    stream=self.stream(), mime_type="audio/pcm"
                ):
                    if audio.data:
                        # Response bytes are interpreted as 16-bit integer samples.
                        array = np.frombuffer(audio.data, dtype=np.int16)
                        self.output_queue.put_nowait((self.output_sample_rate, array))
        except Exception as e:
            print(f"Error in Gemini connection: {e}")

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield queued input audio until ``quit`` is set.

        The queue is polled with a 0.1 second timeout so the quit flag is
        re-checked even when no audio is arriving.
        """
        while not self.quit.is_set():
            try:
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
                yield audio
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Encode one incoming frame and queue it for upload.

        Args:
            frame: ``(sample_rate, samples)``; the sample rate is ignored here.
        """
        _, array = frame
        # Drop length-1 axes before serialising the samples.
        array = array.squeeze()
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the result of ``wait_for_item`` on the output queue."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal the upload generator to stop."""
        self.quit.set()
# Audio send/receive stream backed by GeminiHandler. TURN credentials, the
# concurrency cap and the time limit are only set when get_space() is truthy;
# otherwise each of them is None.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    rtc_configuration=None if not get_space() else get_twilio_turn_credentials(),
    concurrency_limit=None if not get_space() else 5,
    time_limit=None if not get_space() else 90,
    # Extra UI inputs, in the order the handler unpacks them:
    # API key, voice, preset prompt key, custom prompt.
    additional_inputs=[
        gr.Textbox(
            label="API Key",
            type="password",
            # Pre-fill from the environment only when get_space() is falsy.
            value="" if get_space() else os.getenv("GEMINI_API_KEY"),
        ),
        gr.Dropdown(
            label="Voice",
            choices=["Puck", "Charon", "Kore", "Fenrir", "Aoede"],
            value="Puck",
        ),
        gr.Dropdown(
            label="Preset Prompt",
            choices=list(SYSTEM_PROMPTS),
            value="Default",
        ),
        gr.Textbox(
            label="Custom Prompt",
            placeholder="Enter a custom system prompt (overrides preset if not empty)",
            value="",
        ),
    ],
)
class InputData(BaseModel):
    # Request body for POST /input_hook. The endpoint forwards these fields to
    # stream.set_input, with webrtc_id as the first argument.
    webrtc_id: str
    voice_name: str
    api_key: str
    # Prompt selection is optional; both fields default to the empty string.
    prompt_key: str = ""
    custom_prompt: str = ""
# FastAPI application that hosts the routes defined below; the stream is
# mounted onto it (presumably adding the stream's own endpoints - confirm
# against the fastrtc documentation).
app = FastAPI()
stream.mount(app)
@app.post("/input_hook")
async def _(body: InputData):
    """Pass the posted settings to the stream for the given WebRTC connection."""
    handler_args = (
        body.api_key,
        body.voice_name,
        body.prompt_key,
        body.custom_prompt,
    )
    stream.set_input(body.webrtc_id, *handler_args)
    return {"status": "ok"}
@app.get("/")
async def index():
    """Serve index.html with the RTC configuration and preset prompts filled in.

    The ``__RTC_CONFIGURATION__`` and ``__SYSTEM_PROMPTS__`` placeholders in
    the template are replaced with JSON. The RTC configuration is ``None``
    (JSON ``null``) unless ``get_space()`` is truthy.

    Returns:
        An ``HTMLResponse`` containing the substituted page.
    """
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    # Decode explicitly as UTF-8 so the page does not depend on the host's
    # locale encoding.
    html_content = (current_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    html_content = html_content.replace("__SYSTEM_PROMPTS__", json.dumps(SYSTEM_PROMPTS))
    return HTMLResponse(content=html_content)
if __name__ == "__main__":
    # The MODE environment variable selects the front end: "UI" launches the
    # stream's built-in UI, "PHONE" calls stream.fastphone, and anything else
    # (including unset) serves the FastAPI app with uvicorn. All use port 7860.
    # `os` is already imported at the top of the file, so it is not re-imported.
    mode = os.getenv("MODE")
    if mode == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        # Imported here so uvicorn is only required for this launch path.
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)