# ✅ STEP 5: Write the app.py file
import os
import io
import json
import time
import base64

import streamlit as st
from groq import Groq
from gtts import gTTS
from audiorecorder import audiorecorder
from faster_whisper import WhisperModel
from pathlib import Path
from pydub import AudioSegment

# Load the GROQ API key from the environment
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("Please set the GROQ_API_KEY environment variable")
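
# Note: the sidebar below also exposes a GROQ_API_KEY input; this hard check
# only ensures the app refuses to start when no key is set in the environment.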

# ------------------------
# Initialize session history & flags
# ------------------------
if "history" not in st.session_state: | |
st.session_state.history = [] | |
if "last_answer" not in st.session_state: | |
st.session_state.last_answer = None | |
if "last_audio" not in st.session_state: | |
st.session_state.last_audio = None | |
if "just_generated" not in st.session_state: | |
st.session_state.just_generated = False | |
if "suppress_audio" not in st.session_state: | |
st.session_state.suppress_audio = False | |
if "playing" not in st.session_state: | |
st.session_state.playing = None | |
if "play_autoplay" not in st.session_state: | |
st.session_state.play_autoplay = False | |
if "selected_audio_path" not in st.session_state: | |
st.session_state.selected_audio_path = None | |
history = st.session_state.history # shorthand | |
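
# How these flags are used below: last_answer / last_audio hold the most
# recent reply and its TTS file; just_generated and suppress_audio gate the
# Speech Output section; playing tracks the open player in "Latest
# Conversation"; selected_audio_path does the same for the history dropdown.
# (play_autoplay is initialized here but never read again in this file.)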

# ------------------------
# Persistent Chat Memory
# ------------------------
def ensure_tag_dirs(tag):
    base = Path("chat_data") / tag
    (base / "uploads").mkdir(parents=True, exist_ok=True)
    (base / "downloads").mkdir(parents=True, exist_ok=True)
    return base

def get_chat_file(tag):
    return ensure_tag_dirs(tag) / f"{tag}.json"

def load_chat(tag):
    path = get_chat_file(tag)
    if path.exists():
        try:
            return json.loads(path.read_text())
        except Exception:
            return []
    return []
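
# save_chat writes to a temp file and then atomically replaces the target,
# so an interrupted write cannot leave a half-written history JSON behind.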
def save_chat(tag, history):
    path = get_chat_file(tag)
    tmp = path.with_suffix(".tmp")
    tmp.write_text(json.dumps(history, ensure_ascii=False, indent=2))
    tmp.replace(path)

# ------------------------
# Audio player helper (manual play only)
# ------------------------
def autoplay_audio(file_path):
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    md = f"""
    <audio controls>
        <source src="data:audio/mp3;base64,{b64}" type="audio/mp3">
    </audio>
    """
    st.markdown(md, unsafe_allow_html=True)
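
# The helper embeds the MP3 as a base64 data URI inside an HTML <audio>
# element, so Streamlit never has to serve the file separately. Despite the
# legacy name, it renders manual controls only (no autoplay attribute).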

# ------------------------
# Sidebar settings
# ------------------------
st.sidebar.title("🔧 Settings")

# =====================================
# Sidebar: Select Conversation Mode
# =====================================
st.sidebar.markdown("### ⚙️ Conversation Mode")
mode = st.sidebar.selectbox(
    "Select Conversation Mode",
    ["voice_chat", "text_chat"],
    key="chat_mode"
)

# Ensure chat_data folder for this mode
chat_tag = mode
if not st.session_state.history:
    st.session_state.history.extend(load_chat(chat_tag))

st.markdown(
    "<h2 style='text-align: center; color: #1E90FF;'>🎤 Speech ↔ Text ↔ Speech Chatbot</h2>",
    unsafe_allow_html=True
)
groq_key = st.sidebar.text_input("GROQ_API_KEY", os.getenv("GROQ_API_KEY", ""), type="password")
groq_model = st.sidebar.selectbox(
    "Groq Model",
    ["llama-3.1-8b-instant", "llama-3.3-70b-versatile", "meta-llama/llama-guard-4-12b"]
)
whisper_size = st.sidebar.selectbox("Whisper model size", ["tiny", "base", "small", "medium", "large-v3"], index=1)
whisper_dtype = st.sidebar.selectbox("Compute type", ["int8", "float16", "float32"], index=0)
tts_lang = st.sidebar.text_input("gTTS Language Code", "en")
speech_rate = st.sidebar.slider("Speech Rate (%)", 50, 150, 100, 5)
pitch_shift = st.sidebar.slider("Pitch Shift (semitones)", -5, 5, 0, 1)
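
# Note: speech_rate and pitch_shift are applied through pydub post-processing
# in voice_chat mode only; text_chat saves the raw gTTS output unchanged.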

# ------------------------
# Feedback Summary in Sidebar
# ------------------------
likes = sum(1 for m in history if m.get("feedback") == "like")
dislikes = sum(1 for m in history if m.get("feedback") == "dislike")
st.sidebar.markdown("### 📝 Feedback Summary")
st.sidebar.write(f"👍 Likes: {likes}")
st.sidebar.write(f"👎 Dislikes: {dislikes}")
if st.sidebar.button("🗑️ Clear Feedback Summary"):
    for m in history:
        if "feedback" in m:
            m["feedback"] = None
    save_chat(chat_tag, history)
    st.session_state.playing = None
    st.rerun()

# ------------------------
# Conversation history dropdown
# ------------------------
st.sidebar.markdown("### 💬 Conversation History")
questions = [m["content"] for m in history if m["role"] == "user"]
selected_q = st.sidebar.selectbox("Select a past question", [""] + questions, index=0)

# ------------------------
# Delete entire history
# ------------------------
if st.sidebar.button("🗑️ Delete History"):
    st.session_state.confirm_delete = True

if st.session_state.get("confirm_delete", False):
    st.sidebar.warning("⚠️ Are you sure you want to delete the entire history?")
    col1, col2 = st.sidebar.columns([1, 1])
    with col1:
        if st.button("✅ Yes, delete"):
            st.session_state.history.clear()
            save_chat(chat_tag, st.session_state.history)
            st.session_state.playing = None
            st.session_state.last_answer = None
            st.session_state.last_audio = None
            st.session_state.selected_audio_path = None
            st.session_state.confirm_delete = False
            st.rerun()
    with col2:
        if st.button("❌ Cancel"):
            st.session_state.confirm_delete = False
            st.rerun()

# ======================================================
# MODE HANDLING
# ======================================================
if mode == "voice_chat":
    st.markdown("<h3 style='text-align: center; color: #1E90FF;'>🎤 Voice Chat Mode</h3>", unsafe_allow_html=True)

    # Recording + Upload
    audio = audiorecorder("🎙️ Start Recording", "⏹️ Stop Recording")
    uploaded = st.file_uploader("📂 Or upload an audio file", type=["wav", "mp3", "m4a"])

    wav_path = None
    if len(audio) > 0:
        wav_path = ensure_tag_dirs(chat_tag) / "uploads" / "recorded.wav"
        audio.export(wav_path, format="wav")
        st.audio(str(wav_path))
    elif uploaded:
        wav_path = ensure_tag_dirs(chat_tag) / "uploads" / uploaded.name
        with open(wav_path, "wb") as f:
            f.write(uploaded.read())
        st.audio(str(wav_path))

    # Process Audio
    if wav_path and st.button("🧠 Transcribe → Ask Groq → Speak"):
        model = WhisperModel(whisper_size, device="cpu", compute_type=whisper_dtype)
        segments, _ = model.transcribe(str(wav_path))
        transcript = " ".join(seg.text for seg in segments)
        st.success(f"🗣️ You said: {transcript}")

        client = Groq(api_key=groq_key)
        resp = client.chat.completions.create(
            model=groq_model,
            messages=[
                {"role": "system", "content": "You are a helpful AI voice assistant."},
                {"role": "user", "content": transcript},
            ],
        )
        answer = resp.choices[0].message.content

        tts = gTTS(answer, lang=tts_lang)
        buf = io.BytesIO()
        tts.write_to_fp(buf)
        buf.seek(0)
        audio_seg = AudioSegment.from_file(buf, format="mp3")
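
        # pydub note: _spawn() rewrites the frame-rate metadata and
        # set_frame_rate() resamples back, so the rate and pitch tweaks below
        # each change playback speed and pitch together, not independently.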
        if speech_rate != 100:
            audio_seg = audio_seg._spawn(audio_seg.raw_data, overrides={
                "frame_rate": int(audio_seg.frame_rate * (speech_rate / 100.0))
            }).set_frame_rate(audio_seg.frame_rate)
        if pitch_shift != 0:
            new_sample_rate = int(audio_seg.frame_rate * (2.0 ** (pitch_shift / 12.0)))
            audio_seg = audio_seg._spawn(audio_seg.raw_data, overrides={"frame_rate": new_sample_rate})
            audio_seg = audio_seg.set_frame_rate(44100)

        out_path = ensure_tag_dirs(chat_tag) / "downloads" / f"voice_{int(time.time())}.mp3"
        audio_seg.export(out_path, format="mp3")

        history.append({"role": "user", "content": transcript})
        history.append({"role": "assistant", "content": answer, "audio": str(out_path)})
        save_chat(chat_tag, history)

        st.session_state.last_answer = answer
        st.session_state.last_audio = str(out_path)
        st.session_state.just_generated = True
        st.session_state.suppress_audio = False
        st.session_state.playing = None
        st.session_state.selected_audio_path = None

elif mode == "text_chat":
    st.markdown("<h3 style='text-align: center; color: #1E90FF;'>⌨️ Text Chat Mode</h3>", unsafe_allow_html=True)

    with st.form("text_chat_form", clear_on_submit=True):
        user_input = st.text_input("Type your question here:", key="text_chat_input")
        submitted = st.form_submit_button("Send")

    if submitted and user_input:
        history.append({"role": "user", "content": user_input})

        client = Groq(api_key=groq_key)
        resp = client.chat.completions.create(
            model=groq_model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant."},
                {"role": "user", "content": user_input},
            ],
        )
        assistant_reply = resp.choices[0].message.content
        assistant_msg = {"role": "assistant", "content": assistant_reply}

        # Save TTS audio
        tts = gTTS(assistant_reply, lang=tts_lang)
        out_path = ensure_tag_dirs(chat_tag) / "downloads" / f"text_{int(time.time())}.mp3"
        tts.save(str(out_path))
        assistant_msg["audio"] = str(out_path)

        history.append(assistant_msg)
        save_chat(chat_tag, history)

        st.session_state.last_answer = assistant_reply
        st.session_state.last_audio = str(out_path)
        st.session_state.just_generated = True
        st.session_state.suppress_audio = False
        st.session_state.playing = None
        st.session_state.selected_audio_path = None
        st.rerun()
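
# Both modes converge below: the remaining sections read last_audio and
# history, regardless of whether the reply came from voice or text chat.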

# ======================================================
# 3️⃣ Speech Output
# ======================================================
st.subheader("🔊 Speech Output")
if st.session_state.last_audio:
    if st.session_state.suppress_audio:
        autoplay_audio(st.session_state.last_audio)
        st.session_state.suppress_audio = False
    elif st.session_state.just_generated:
        autoplay_audio(st.session_state.last_audio)
        st.session_state.just_generated = False
    else:
        autoplay_audio(st.session_state.last_audio)
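
# Every branch above renders the same manual player; the branching only
# decides which one-shot flag is cleared on this rerun.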

# ======================================================
# 4️⃣ Selected Conversation
# ======================================================
if selected_q:
    st.markdown("<h3 style='text-align: center; color: #2E86C1;'>📜 Selected Conversation</h3>", unsafe_allow_html=True)
    st.markdown(f"**You:** {selected_q}")
    for idx, msg in enumerate(history):
        if msg["role"] == "user" and msg["content"] == selected_q:
            if idx + 1 < len(history) and history[idx + 1]["role"] == "assistant":
                answer_msg = history[idx + 1]
                st.markdown(f"**Assistant:** {answer_msg['content']}")

                if st.button("🗑️ Delete This Conversation", key=f"delete_selected_{idx}"):
                    del history[idx:idx + 2]
                    save_chat(chat_tag, history)
                    st.session_state.playing = None
                    st.session_state.selected_audio_path = None
                    st.rerun()

                if "audio" in answer_msg:
                    if st.button("▶️ Play Audio (Selected)", key=f"play_selected_{idx}"):
                        st.session_state.playing = None
                        st.session_state.suppress_audio = True
                        st.session_state.selected_audio_path = answer_msg["audio"]
                        st.rerun()
                    if st.session_state.selected_audio_path == answer_msg["audio"]:
                        autoplay_audio(answer_msg["audio"])
            break

# ======================================================
# 5️⃣ Latest Conversation
# ======================================================
st.markdown("<h3 style='text-align: center; color: #117A65;'>💬 Latest Conversation</h3>", unsafe_allow_html=True)
if len(history) >= 2 and history[-2]["role"] == "user" and history[-1]["role"] == "assistant":
    user_msg = history[-2]
    assistant_msg = history[-1]
    st.markdown(f"**You:** {user_msg['content']}")
    st.markdown(f"**Assistant:** {assistant_msg['content']}")

    if "audio" in assistant_msg:
        col1, col2 = st.columns([1, 1])
        with col1:
            if st.button("▶️ Play Audio", key="play_latest"):
                st.session_state.playing = assistant_msg["audio"]
                st.session_state.selected_audio_path = None
                st.rerun()
            if st.session_state.playing == assistant_msg["audio"]:
                autoplay_audio(assistant_msg["audio"])
        with col2:
            if st.button("🗑️ Delete This Q+A", key="delete_latest"):
                del history[-2:]
                save_chat(chat_tag, history)
                st.session_state.playing = None
                st.rerun()

    if "feedback" not in assistant_msg:
        assistant_msg["feedback"] = None

    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("👍 Like", key="like_latest"):
            assistant_msg["feedback"] = "like"
            save_chat(chat_tag, history)
            st.rerun()
    with col2:
        if st.button("👎 Dislike", key="dislike_latest"):
            assistant_msg["feedback"] = "dislike"
            save_chat(chat_tag, history)
            st.rerun()

    if assistant_msg["feedback"] == "like":
        st.success("You liked this response 👍")
    elif assistant_msg["feedback"] == "dislike":
        st.error("You disliked this response 👎")

st.markdown("---")