|
import streamlit as st |
|
from openai import OpenAI |
|
import sounddevice as sd |
|
import scipy.io.wavfile |
|
import io |
|
import base64 |
|
import os |
|
import time |
|
|
|
|
|
# --- Page setup -------------------------------------------------------------
st.set_page_config(page_title="Voice Bot", layout="wide")

# --- Recording configuration ------------------------------------------------
SAMPLE_RATE = 44100  # capture sample rate in Hz
RECORD_DURATION = 5  # fixed recording length in seconds
TEMP_AUDIO_FILE = "temp_audio.wav"  # scratch file used to hand audio to the Whisper API

# --- OpenAI client ----------------------------------------------------------
# NOTE(review): assumes .streamlit/secrets.toml defines `openai` as a plain
# string API key (not a [openai] section/table) — confirm against deployment.
api_key = st.secrets['openai']
client = OpenAI(api_key=api_key)

# --- Session-state defaults -------------------------------------------------
# Initialized once per browser session; Streamlit reruns preserve these.
if 'recorded_audio' not in st.session_state:
    # BytesIO holding the most recent WAV recording
    st.session_state.recorded_audio = None
if 'user_text' not in st.session_state:
    # most recent Whisper transcription
    st.session_state.user_text = None
if 'ai_reply' not in st.session_state:
    # most recent GPT reply text
    st.session_state.ai_reply = None
|
|
|
def load_context():
    """Load the bot's knowledge context from ``context.txt``.

    Returns:
        str: The file contents, or an empty string if the file is missing
        (a Streamlit error banner is also shown in that case).
    """
    try:
        # Explicit encoding: relying on the platform default can mangle
        # non-ASCII context text (e.g. cp1252 on Windows).
        with open("context.txt", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        st.error("Context file not found!")
        return ""
|
|
|
def record_audio():
    """Capture RECORD_DURATION seconds of mono audio and return it as an
    in-memory WAV file (BytesIO, positioned at the start)."""
    progress = st.progress(0)

    # Start the non-blocking capture for the full duration.
    frame_count = int(RECORD_DURATION * SAMPLE_RATE)
    samples = sd.rec(frame_count, samplerate=SAMPLE_RATE, channels=1)

    # Animate the progress bar in 0.1 s ticks while the capture runs
    # in the background.
    total_ticks = RECORD_DURATION * 10
    for tick in range(1, total_ticks + 1):
        progress.progress(tick / total_ticks)
        time.sleep(0.1)

    sd.wait()          # block until the capture has fully finished
    progress.empty()   # remove the progress bar from the UI

    # Serialize the samples into an in-memory WAV container.
    wav_buffer = io.BytesIO()
    scipy.io.wavfile.write(wav_buffer, SAMPLE_RATE, samples)
    wav_buffer.seek(0)
    return wav_buffer
|
|
|
def transcribe_audio(audio_buffer):
    """Transcribe a WAV audio buffer to text with OpenAI's Whisper API.

    Args:
        audio_buffer: BytesIO containing a complete WAV file.

    Returns:
        str: The transcribed text.
    """
    # Whisper infers the format from the filename, so spill the buffer to a
    # scratch file with a .wav extension first.
    with open(TEMP_AUDIO_FILE, "wb") as f:
        f.write(audio_buffer.getvalue())

    try:
        with open(TEMP_AUDIO_FILE, "rb") as audio_file:
            transcript = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file
            )
        return transcript.text
    finally:
        # Fix: the scratch file previously lingered after every call.
        # Best-effort cleanup so repeated runs don't leave stale audio behind.
        if os.path.exists(TEMP_AUDIO_FILE):
            os.remove(TEMP_AUDIO_FILE)
|
|
|
def get_ai_response(user_text, context):
    """Get AI response using GPT-4.

    The system prompt pins the bot's persona ("Prakhar") and restricts
    answers to the supplied *context*; out-of-scope questions get a fixed
    refusal line dictated by the prompt.
    """
    system_prompt = f"""
    You are Prakhar.
    You must respond **only using the following context**:

    {context}

    If the user's question cannot be answered using this context, respond with:
    "I'm not sure about that based on what I know."
    """

    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_text}
        ]
    )
    # Single completion requested, so the first choice is the only one.
    return response.choices[0].message.content
|
|
|
def text_to_speech(text):
    """Synthesize *text* with OpenAI TTS ("onyx" voice) and return the audio
    bytes as a base64-encoded ASCII string."""
    tts_response = client.audio.speech.create(
        model="tts-1",
        voice="onyx",
        input=text,
    )
    # Base64 so the audio can be embedded directly in a data: URI.
    encoded = base64.b64encode(tts_response.content)
    return encoded.decode()
|
|
|
def handle_record_button():
    """Button callback: record a clip and stash it in session state."""
    # Flag the app as "processing" so the next rerun of main() pushes the
    # new clip through the STT -> LLM -> TTS pipeline.
    st.session_state.processing = True

    status = st.empty()
    status.info("Recording...")
    clip = record_audio()
    status.empty()

    st.session_state.recorded_audio = clip
|
|
|
def main():
    """Render the two-column voice-bot UI and drive the STT→LLM→TTS pipeline."""
    st.title("Voice Bot")

    # Lazy one-time initialization (survives Streamlit reruns).
    if 'context' not in st.session_state:
        st.session_state.context = load_context()
    if 'processing' not in st.session_state:
        st.session_state.processing = False

    with st.container():

        audio, script = st.columns(2, border=True)

        with audio:
            st.subheader("Audio Input")
            # NOTE(review): label "ποΈ" looks like mojibake (likely meant a
            # microphone emoji) — confirm the file's encoding upstream.
            st.button("ποΈ Record Voice", on_click=handle_record_button)

            # Placeholder so the spinner appears in a stable spot.
            process_placeholder = st.empty()

            # The button callback set `processing` and stored the recording;
            # this rerun now runs the full pipeline on it.
            if st.session_state.processing:
                with process_placeholder.container():
                    with st.spinner("Processing..."):
                        st.session_state.user_text = transcribe_audio(st.session_state.recorded_audio)
                        st.session_state.ai_reply = get_ai_response(st.session_state.user_text, st.session_state.context)
                        audio_b64 = text_to_speech(st.session_state.ai_reply)
                        st.session_state.ai_audio = audio_b64
                        # Clear the flag so subsequent reruns don't reprocess.
                        st.session_state.processing = False

            # Playback: the user's recording, then the synthesized reply
            # (embedded as a base64 data URI).
            if st.session_state.recorded_audio is not None:
                st.audio(st.session_state.recorded_audio, format="audio/wav")
            if hasattr(st.session_state, 'ai_audio'):
                st.audio(f"data:audio/mp3;base64,{st.session_state.ai_audio}", format="audio/mp3")

        with script:
            st.subheader("Conversation")
            if st.session_state.user_text is not None:
                st.markdown("**You said:**")
                st.markdown(f"{st.session_state.user_text}")
                st.markdown("**AI Response:**")
                st.markdown(f"{st.session_state.ai_reply}")

    st.divider()

    # Read-only view of the loaded context; edits happen in context.txt,
    # not through this widget.
    with st.container(border=True):
        st.text_area("Context", value=st.session_state.context, height=270, disabled=False)
        st.markdown("You can update the context in the `context.txt` file.")
|
|
|
if __name__ == "__main__":
    # Streamlit re-executes this module on every rerun; the guard keeps
    # main() from firing if the module is ever imported elsewhere.
    main()
|
|