# SpeechToText / app.py
# Page header captured with this file: nakhzai · "Update app.py" · 7f1fa89 (verified)
# ✅ STEP 5: Write the app.py file
import os
import io
import json
import time
import base64
import streamlit as st
from groq import Groq
from gtts import gTTS
from audiorecorder import audiorecorder
from faster_whisper import WhisperModel
from pathlib import Path
from pydub import AudioSegment
# Read the Groq API key from the environment. The script refuses to start
# without it: a missing or empty value raises ValueError right here.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if GROQ_API_KEY is None or GROQ_API_KEY == "":
    raise ValueError("please set the GROQ_API_KEY environment variable")
# ------------------------
# Initialize session history & flags
# ------------------------
# Every key the rest of the script reads from st.session_state, paired with
# the value it starts with. A key is only written when it is absent, so any
# value already stored in the session is left untouched.
_SESSION_DEFAULTS = {
    "history": [],
    "last_answer": None,
    "last_audio": None,
    "just_generated": False,
    "suppress_audio": False,
    "playing": None,
    "play_autoplay": False,
    "selected_audio_path": None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Shorthand: the same list object that lives in the session state, so
# in-place mutations through either name are visible through both.
history = st.session_state.history
# ------------------------
# Persistent Chat Memory
# ------------------------
def ensure_tag_dirs(tag):
    """Create (if needed) and return the data directory for a chat tag.

    The directory is ``chat_data/<tag>`` relative to the current working
    directory; its ``uploads`` and ``downloads`` sub-directories are created
    as well.

    Args:
        tag: Name of the conversation bucket; used as a directory name.

    Returns:
        ``pathlib.Path`` of the tag's base directory.
    """
    base = Path("chat_data") / tag
    for sub in ("uploads", "downloads"):
        (base / sub).mkdir(parents=True, exist_ok=True)
    return base


def get_chat_file(tag):
    """Return the path of the JSON history file for ``tag``.

    ``ensure_tag_dirs`` is called first, so the containing directory exists
    by the time the path is returned.
    """
    return ensure_tag_dirs(tag) / f"{tag}.json"


def load_chat(tag):
    """Load the saved message list for ``tag``.

    Args:
        tag: Name of the conversation bucket.

    Returns:
        The stored list of messages, or an empty list when the history file
        is missing, unreadable, not valid UTF-8 JSON, or does not hold a list.
    """
    path = get_chat_file(tag)
    try:
        # Read with the same explicit encoding save_chat writes with; the
        # platform default may not be UTF-8.
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: covers both
        # json.JSONDecodeError and UnicodeDecodeError. Loading stays
        # best-effort: a damaged file yields an empty history.
        return []
    # Callers extend a list with the result and call .get() on its items,
    # so anything that is not a list is treated as "no history".
    return data if isinstance(data, list) else []


def save_chat(tag, history):
    """Persist ``history`` as JSON for ``tag``.

    The data is written to a sibling ``.tmp`` file first and then moved over
    the real file, so a failed write does not leave a truncated history file.

    Args:
        tag: Name of the conversation bucket.
        history: JSON-serialisable list of message dicts.
    """
    path = get_chat_file(tag)
    tmp = path.with_suffix(".tmp")
    payload = json.dumps(history, ensure_ascii=False, indent=2)
    # ensure_ascii=False emits raw non-ASCII characters (emoji, accents), so
    # the encoding must be explicit rather than locale-dependent.
    tmp.write_text(payload, encoding="utf-8")
    tmp.replace(path)
# ------------------------
# Audio player helper (manual play only; despite its name, nothing autoplays)
# ------------------------
def autoplay_audio(file_path):
    """Render an HTML5 audio player for an MP3 file.

    The file's bytes are embedded in the page as a base64 ``data:`` URI.
    Despite the function's name, the ``<audio>`` tag carries only
    ``controls`` (no ``autoplay`` attribute), so the user starts playback.

    If the file cannot be opened or read, a warning is shown in place of the
    player instead of raising, so the rest of the page still renders.

    Args:
        file_path: Path of the MP3 file to embed.
    """
    try:
        with open(file_path, "rb") as f:
            b64 = base64.b64encode(f.read()).decode()
    except OSError:
        st.warning(f"⚠️ Audio file is not available: {file_path}")
        return
    md = (
        "\n<audio controls>\n"
        f'<source src="data:audio/mp3;base64,{b64}" type="audio/mp3">\n'
        "</audio>\n"
    )
    st.markdown(md, unsafe_allow_html=True)
# ------------------------
# Sidebar settings
# ------------------------
st.sidebar.title("🔧 Settings")

# =====================================
# Sidebar: Select Conversation Mode
# =====================================
st.sidebar.markdown("### ⚙️ Conversation Mode")
mode = st.sidebar.selectbox(
    "Select Conversation Mode",
    ["voice_chat", "text_chat"],
    key="chat_mode",
)

# The mode name doubles as the storage tag: each mode has its own
# chat_data/<mode>/ folder and history file.
chat_tag = mode

# (Re)load the persisted history whenever the active tag differs from the one
# currently held in the session. Loading only when the list is empty would
# keep the previous mode's messages after a mode switch and later save them
# into the new mode's file.
if st.session_state.get("loaded_chat_tag") != chat_tag:
    # Mutate in place: `history` and st.session_state.history are the same list.
    history[:] = load_chat(chat_tag)
    st.session_state.loaded_chat_tag = chat_tag
    # These two point at audio belonging to the previously loaded history.
    st.session_state.playing = None
    st.session_state.selected_audio_path = None

st.markdown(
    "<h2 style='text-align: center; color: #1E90FF;'>🎤 Speech ↔ Text ↔ Speech Chatbot</h2>",
    unsafe_allow_html=True,
)

# API key field, pre-filled from the environment; this value (not the
# module-level GROQ_API_KEY) is what the Groq clients below are created with.
groq_key = st.sidebar.text_input("GROQ_API_KEY", os.getenv("GROQ_API_KEY", ""), type="password")
groq_model = st.sidebar.selectbox(
    "Groq Model",
    ["llama-3.1-8b-instant", "llama-3.3-70b-versatile", "meta-llama/llama-guard-4-12b"],
)

# Speech-to-text settings passed to WhisperModel.
whisper_size = st.sidebar.selectbox("Whisper model size", ["tiny", "base", "small", "medium", "large-v3"], index=1)
whisper_dtype = st.sidebar.selectbox("Compute type", ["int8", "float16", "float32"], index=0)

# Text-to-speech settings: gTTS language plus the post-processing sliders.
tts_lang = st.sidebar.text_input("gTTS Language Code", "en")
speech_rate = st.sidebar.slider("Speech Rate (%)", 50, 150, 100, 5)
pitch_shift = st.sidebar.slider("Pitch Shift (semitones)", -5, 5, 0, 1)
# ------------------------
# Feedback Summary in Sidebar
# ------------------------
# Tally the feedback stored on the messages in one pass over the history.
likes = 0
dislikes = 0
for entry in history:
    verdict = entry.get("feedback")
    if verdict == "like":
        likes += 1
    elif verdict == "dislike":
        dislikes += 1

st.sidebar.markdown("### 📝 Feedback Summary")
st.sidebar.write(f"👍 Likes: {likes}")
st.sidebar.write(f"👎 Dislikes: {dislikes}")

if st.sidebar.button("🗑️ Clear Feedback Summary"):
    # Reset (rather than remove) the feedback key on every message that has
    # one, persist the result, then rerun so the counters above refresh.
    for entry in history:
        if "feedback" in entry:
            entry["feedback"] = None
    save_chat(chat_tag, history)
    st.session_state.playing = None
    st.rerun()
# ------------------------
# Conversation history dropdown
# ------------------------
st.sidebar.markdown("### 💬 Conversation History")
# Every user message becomes an option; the leading empty string is the
# "nothing selected" entry and is the default.
questions = [entry["content"] for entry in history if entry["role"] == "user"]
dropdown_options = [""]
dropdown_options.extend(questions)
selected_q = st.sidebar.selectbox("Select a past question", dropdown_options, index=0)

# ------------------------
# Delete entire history
# ------------------------
# Two-step flow: the button only arms a confirmation flag; the actual wipe
# happens when the user confirms.
if st.sidebar.button("🗑️ Delete History"):
    st.session_state["confirm_delete"] = True

if st.session_state.get("confirm_delete", False):
    st.sidebar.warning("⚠️ Are you sure you want to delete the entire history?")
    confirm_col, cancel_col = st.sidebar.columns([1, 1])

    with confirm_col:
        if st.button("✅ Yes, delete"):
            # Empty the list in place, persist the empty history, and drop
            # every pointer to a reply or audio clip that no longer exists.
            st.session_state.history.clear()
            save_chat(chat_tag, st.session_state.history)
            st.session_state.playing = None
            st.session_state.last_answer = None
            st.session_state.last_audio = None
            st.session_state.selected_audio_path = None
            st.session_state.confirm_delete = False
            st.rerun()

    with cancel_col:
        if st.button("❌ Cancel"):
            st.session_state.confirm_delete = False
            st.rerun()
# ======================================================
# MODE HANDLING
# ======================================================
if mode == "voice_chat":
    st.markdown("<h3 style='text-align: center; color: #1E90FF;'>🎤 Voice Chat Mode</h3>", unsafe_allow_html=True)

    # Recording + Upload: a fresh recording takes precedence over an upload.
    audio = audiorecorder("🎙️ Start Recording", "⏹️ Stop Recording")
    uploaded = st.file_uploader("📂 Or upload an audio file", type=["wav", "mp3", "m4a"])
    wav_path = None
    if len(audio) > 0:
        wav_path = ensure_tag_dirs(chat_tag) / "uploads" / "recorded.wav"
        audio.export(wav_path, format="wav")
        st.audio(str(wav_path))
    elif uploaded:
        # The file name comes from the browser. Keep only its final path
        # component so it cannot point outside the uploads directory.
        safe_name = Path(uploaded.name).name or "uploaded_audio"
        wav_path = ensure_tag_dirs(chat_tag) / "uploads" / safe_name
        with open(wav_path, "wb") as f:
            f.write(uploaded.read())
        st.audio(str(wav_path))

    # Process Audio: speech -> text -> Groq reply -> speech.
    if wav_path and st.button("🧠 Transcribe → Ask Groq → Speak"):
        model = WhisperModel(whisper_size, device="cpu", compute_type=whisper_dtype)
        segments, _ = model.transcribe(str(wav_path))
        transcript = " ".join(seg.text for seg in segments)

        if not transcript.strip():
            # Nothing was recognised: do not send an empty prompt to the
            # model or record an empty exchange.
            st.warning("⚠️ No speech was detected in the audio. Please try again.")
        else:
            st.success(f"🗣️ You said: {transcript}")

            client = Groq(api_key=groq_key)
            resp = client.chat.completions.create(
                model=groq_model,
                messages=[
                    {"role": "system", "content": "You are a helpful AI voice assistant."},
                    {"role": "user", "content": transcript},
                ],
            )
            answer = resp.choices[0].message.content

            # Synthesize the reply into memory so it can be post-processed.
            tts = gTTS(answer, lang=tts_lang)
            buf = io.BytesIO()
            tts.write_to_fp(buf)
            buf.seek(0)
            audio_seg = AudioSegment.from_file(buf, format="mp3")

            if speech_rate != 100:
                # Reinterpret the raw samples at a scaled frame rate, then
                # convert back to the original rate.
                audio_seg = audio_seg._spawn(audio_seg.raw_data, overrides={
                    "frame_rate": int(audio_seg.frame_rate * (speech_rate / 100.0))
                }).set_frame_rate(audio_seg.frame_rate)
            if pitch_shift != 0:
                # One semitone is a factor of 2**(1/12) in frame rate.
                new_sample_rate = int(audio_seg.frame_rate * (2.0 ** (pitch_shift / 12.0)))
                audio_seg = audio_seg._spawn(audio_seg.raw_data, overrides={'frame_rate': new_sample_rate})
                audio_seg = audio_seg.set_frame_rate(44100)

            out_path = ensure_tag_dirs(chat_tag) / "downloads" / f"voice_{int(time.time())}.mp3"
            audio_seg.export(out_path, format="mp3")

            # Record the exchange only once the reply and its audio exist.
            history.append({"role": "user", "content": transcript})
            history.append({"role": "assistant", "content": answer, "audio": str(out_path)})
            save_chat(chat_tag, history)

            st.session_state.last_answer = answer
            st.session_state.last_audio = str(out_path)
            st.session_state.just_generated = True
            st.session_state.suppress_audio = False
            st.session_state.playing = None
            st.session_state.selected_audio_path = None

elif mode == "text_chat":
    st.markdown("<h3 style='text-align: center; color: #1E90FF;'>⌨️ Text Chat Mode</h3>", unsafe_allow_html=True)

    with st.form("text_chat_form", clear_on_submit=True):
        user_input = st.text_input("Type your question here:", key="text_chat_input")
        submitted = st.form_submit_button("Send")

    # Ignore empty and whitespace-only submissions.
    if submitted and (user_input or "").strip():
        client = Groq(api_key=groq_key)
        resp = client.chat.completions.create(
            model=groq_model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant."},
                {"role": "user", "content": user_input},
            ],
        )
        assistant_reply = resp.choices[0].message.content

        # Save TTS audio.
        # NOTE(review): the speech-rate / pitch sliders are only applied in
        # voice mode; here the gTTS output is saved unmodified.
        tts = gTTS(assistant_reply, lang=tts_lang)
        out_path = ensure_tag_dirs(chat_tag) / "downloads" / f"text_{int(time.time())}.mp3"
        tts.save(str(out_path))

        # Append question and answer together, after everything succeeded,
        # so a failed API/TTS call cannot leave an unanswered user message
        # in the history (same ordering as voice mode).
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": assistant_reply, "audio": str(out_path)})
        save_chat(chat_tag, history)

        st.session_state.last_answer = assistant_reply
        st.session_state.last_audio = str(out_path)
        st.session_state.just_generated = True
        st.session_state.suppress_audio = False
        st.session_state.playing = None
        st.session_state.selected_audio_path = None
        st.rerun()
# ======================================================
# 3️⃣ Speech Output
# ======================================================
st.subheader("🔊 Speech Output")
latest_clip = st.session_state.last_audio
if latest_clip:
    # The player is rendered the same way in every case; afterwards at most
    # one of the one-shot flags is reset (suppress_audio takes precedence).
    autoplay_audio(latest_clip)
    if st.session_state.suppress_audio:
        st.session_state.suppress_audio = False
    elif st.session_state.just_generated:
        st.session_state.just_generated = False
# ======================================================
# 4️⃣ Selected Conversation
# ======================================================
if selected_q:
    st.markdown("<h3 style='text-align: center; color: #2E86C1;'>📜 Selected Conversation</h3>", unsafe_allow_html=True)
    st.markdown(f"**You:** {selected_q}")

    # Show the first user message whose text equals the selection, together
    # with the assistant reply that directly follows it (if there is one).
    for idx, msg in enumerate(history):
        if msg["role"] != "user" or msg["content"] != selected_q:
            continue

        if idx + 1 < len(history) and history[idx + 1]["role"] == "assistant":
            answer_msg = history[idx + 1]
            st.markdown(f"**Assistant:** {answer_msg['content']}")

            if st.button("🗑️ Delete This Conversation", key=f"delete_selected_{idx}"):
                # Remove the question and its reply as a pair.
                del history[idx:idx + 2]
                save_chat(chat_tag, history)
                st.session_state.playing = None
                st.session_state.selected_audio_path = None
                st.rerun()

            if "audio" in answer_msg:
                if st.button("▶️ Play Audio (Selected)", key=f"play_selected_{idx}"):
                    st.session_state.playing = None
                    st.session_state.suppress_audio = True
                    st.session_state.selected_audio_path = answer_msg["audio"]
                    st.rerun()

                if st.session_state.selected_audio_path == answer_msg["audio"]:
                    # The path comes from the persisted history, so the file
                    # may be gone; check before handing it to the shared
                    # player helper instead of crashing the page.
                    if Path(answer_msg["audio"]).is_file():
                        autoplay_audio(answer_msg["audio"])
                    else:
                        st.warning("⚠️ The audio file for this reply is no longer available.")

        # Only the first matching question is shown.
        break
# ======================================================
# 5️⃣ Latest Conversation
# ======================================================
st.markdown("<h3 style='text-align: center; color: #117A65;'>💬 Latest Conversation</h3>", unsafe_allow_html=True)

# Shown only when the history ends with a user message followed by a reply.
if len(history) >= 2 and history[-2]["role"] == "user" and history[-1]["role"] == "assistant":
    user_msg = history[-2]
    assistant_msg = history[-1]
    st.markdown(f"**You:** {user_msg['content']}")
    st.markdown(f"**Assistant:** {assistant_msg['content']}")

    if "audio" in assistant_msg:
        col1, col2 = st.columns([1, 1])
        with col1:
            if st.button("▶️ Play Audio", key="play_latest"):
                st.session_state.playing = assistant_msg["audio"]
                st.session_state.selected_audio_path = None
                st.rerun()
            if st.session_state.playing == assistant_msg["audio"]:
                # The path comes from the persisted history, so the file may
                # be gone; check before handing it to the shared player
                # helper instead of crashing the page.
                if Path(assistant_msg["audio"]).is_file():
                    autoplay_audio(assistant_msg["audio"])
                else:
                    st.warning("⚠️ The audio file for this reply is no longer available.")
        with col2:
            if st.button("🗑️ Delete This Q+A", key="delete_latest"):
                deleted_audio = assistant_msg["audio"]
                del history[-2:]
                save_chat(chat_tag, history)
                st.session_state.playing = None
                # Drop the other pointers to the deleted reply as well, the
                # same way the delete-entire-history path does.
                if st.session_state.selected_audio_path == deleted_audio:
                    st.session_state.selected_audio_path = None
                if st.session_state.last_audio == deleted_audio:
                    st.session_state.last_audio = None
                    st.session_state.last_answer = None
                st.rerun()

    # Feedback: make sure the key exists, then let the user set it.
    assistant_msg.setdefault("feedback", None)
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("👍 Like", key="like_latest"):
            assistant_msg["feedback"] = "like"
            save_chat(chat_tag, history)
            st.rerun()
    with col2:
        if st.button("👎 Dislike", key="dislike_latest"):
            assistant_msg["feedback"] = "dislike"
            save_chat(chat_tag, history)
            st.rerun()

    if assistant_msg["feedback"] == "like":
        st.success("You liked this response 👍")
    elif assistant_msg["feedback"] == "dislike":
        st.error("You disliked this response 👎")

    st.markdown("---")