import streamlit as st
from audio_recorder_streamlit import audio_recorder
import time
import google.generativeai as genai
from datetime import datetime
import json

from prompts import REAL_TIME_ANALYSIS_PROMPT, MI_SYSTEM_PROMPT
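
# The Gemini calls below assume the client has already been configured with an API
# key elsewhere in the app. If it has not, a minimal sketch using Streamlit secrets
# (the "GOOGLE_API_KEY" secret name is an assumption, not part of the original code):
#
#     genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])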


def show_live_session():
    st.title("Live Therapy Session Recording & Analysis")

    # Initialize session state
    if "recording" not in st.session_state:
        st.session_state.recording = False
    if "session_transcript" not in st.session_state:
        st.session_state.session_transcript = []
    if "session_start_time" not in st.session_state:
        st.session_state.session_start_time = None

    # Layout: controls on the left, live analysis on the right
    col1, col2 = st.columns([2, 3])

    with col1:
        show_recording_controls()
        show_session_info()

    with col2:
        show_real_time_analysis()


def show_recording_controls():
    st.subheader("Recording Controls")

    # Start/Stop Recording button (label reflects the current state)
    if st.button("Start Recording" if not st.session_state.recording else "Stop Recording"):
        if not st.session_state.recording:
            start_session()
        else:
            end_session()

    # Recording indicator
    if st.session_state.recording:
        st.markdown("🔴 **Recording in progress...**")

    # Audio recorder widget; returns WAV bytes once a clip has been captured
    audio_bytes = audio_recorder()
    if audio_bytes:
        st.audio(audio_bytes, format="audio/wav")
        process_audio(audio_bytes)


def start_session():
    st.session_state.recording = True
    st.session_state.session_start_time = datetime.now()
    st.session_state.session_transcript = []


def end_session():
    st.session_state.recording = False
    save_session()


def show_session_info():
    if st.session_state.recording and st.session_state.session_start_time:
        duration = datetime.now() - st.session_state.session_start_time
        st.info(f"Session Duration: {str(duration).split('.')[0]}")


def show_real_time_analysis():
    st.subheader("Real-time Analysis")

    # Display transcript and analysis
    for entry in st.session_state.session_transcript:
        with st.expander(f"Entry at {entry['timestamp']}"):
            st.markdown(f"**Speaker:** {entry['speaker']}")
            st.markdown(entry['text'])
            if 'analysis' in entry:
                st.markdown("### Analysis")
                st.markdown(entry['analysis'])


def process_audio(audio_bytes):
    """Process a recorded audio segment: transcribe it, analyze it, and store the result."""
    try:
        # Here you would typically:
        # 1. Convert audio_bytes to text using a speech-to-text service
        #    (see the transcribe_audio sketch after this function for one possible approach)
        # 2. Analyze the text using Gemini
        # For now, we'll use a placeholder text
        transcript = "Example transcription"  # Replace with actual transcription

        # Add to session transcript
        entry = {
            "speaker": "Client",
            "text": transcript,
            "timestamp": datetime.now().strftime("%H:%M:%S")
        }

        # Generate analysis
        analysis = analyze_real_time(transcript)
        if analysis:
            entry["analysis"] = analysis

        st.session_state.session_transcript.append(entry)
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")


def analyze_real_time(transcript):
    try:
        # Configure Gemini model
        model = genai.GenerativeModel('gemini-pro')

        # Prepare context
        context = {
            "transcript": transcript,
            "session_history": str(st.session_state.session_transcript[-5:]),  # Last 5 entries
            "timestamp": datetime.now().strftime("%H:%M:%S")
        }

        # Generate analysis
        prompt = f"""
        Analyze the following therapy session segment using MI principles:

        Transcript: {context['transcript']}
        Recent Context: {context['session_history']}

        Please provide:
        1. Identification of MI techniques used or missed opportunities
        2. Analysis of change talk vs sustain talk
        3. Suggestions for next interventions
        4. Overall MI adherence assessment
        """

        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        st.error(f"Error generating analysis: {str(e)}")
        return None


def save_session():
    """Save session data to a JSON file."""
    if st.session_state.session_transcript:
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"session_{timestamp}.json"

            session_data = {
                "start_time": st.session_state.session_start_time.isoformat(),
                "end_time": datetime.now().isoformat(),
                "transcript": st.session_state.session_transcript
            }

            with open(filename, "w") as f:
                json.dump(session_data, f, indent=4)

            st.success(f"Session saved to {filename}")
        except Exception as e:
            st.error(f"Error saving session: {str(e)}")


# Add session controls
def show_session_controls():
    st.sidebar.subheader("Session Controls")

    # Session settings (stored in session state via widget keys)
    st.sidebar.text_input("Client ID (optional)", key="client_id")
    st.sidebar.text_input("Session Notes (optional)", key="session_notes")

    # Session markers: type a note, then click the button to timestamp it
    if st.session_state.recording:
        marker_text = st.sidebar.text_input("Marker note:", key="marker_note")
        if st.sidebar.button("Add Marker"):
            add_session_marker(marker_text)


def add_session_marker(marker_text):
    """Add a marker/note to the session transcript."""
    # The note is captured by the sidebar text input before the button is clicked;
    # creating a new text_input here would return an empty string on the same rerun.
    if marker_text:
        st.session_state.session_transcript.append({
            "speaker": "System",
            "text": f"MARKER: {marker_text}",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })


# Add visualization features
def show_session_visualizations():
    if st.session_state.session_transcript:
        st.subheader("Session Analytics")
        # Add visualizations here (e.g., using plotly), for example:
        # - Speaking time distribution (see the plot_speaker_distribution sketch below)
        # - Change talk vs sustain talk ratio
        # - MI adherence scores
        pass
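

# A minimal sketch of the first visualization suggested above (speaking time
# distribution), assuming plotly is available. It uses the count of transcript
# entries per speaker as a stand-in for actual speaking time, which this module
# does not measure; the function name is an assumption, not original code.
def plot_speaker_distribution():
    import plotly.express as px
    from collections import Counter

    counts = Counter(entry["speaker"] for entry in st.session_state.session_transcript)
    fig = px.pie(
        names=list(counts.keys()),
        values=list(counts.values()),
        title="Transcript entries per speaker",
    )
    st.plotly_chart(fig, use_container_width=True)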


def show_live_session_main():
    show_live_session()
    show_session_controls()
    show_session_visualizations()