from __future__ import annotations
import os
from pathlib import Path
import yaml
import gradio as gr
from typing import Optional
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from config.settings import settings
from forms.schemas import (
SOAPNote, DAPNote, BIRPNote, PIRPNote, GIRPNote, SIRPNote,
FAIRFDARPNote, DARENote, PIENote, SOAPIERNote, SOAPIENote,
POMRNote, NarrativeNote, CBENote, SBARNote
)
from utils.youtube import download_transcript
from utils.youtube import extract_youtube_video_id
from utils.text_processing import chunk_text
from utils.audio import transcribe_audio
from models.llm_provider import get_llm, get_model_identifier
from utils.cache import CacheManager
from config.auth import load_auth_credentials
# Dictionary mapping form types to their schemas
FORM_SCHEMAS = {
"SOAP": SOAPNote,
"DAP": DAPNote,
"BIRP": BIRPNote,
"PIRP": PIRPNote,
"GIRP": GIRPNote,
"SIRP": SIRPNote,
"FAIR/F-DARP": FAIRFDARPNote,
"DARE": DARENote,
"PIE": PIENote,
"SOAPIER": SOAPIERNote,
"SOAPIE": SOAPIENote,
"POMR": POMRNote,
"Narrative": NarrativeNote,
"CBE": CBENote,
"SBAR": SBARNote,
}
# Initialize cache manager
cache_manager = CacheManager()
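# The cache is used in two ways below: transcripts are keyed by YouTube video id,
# and extraction results are keyed by (transcript, form type, model identifier),
# so switching models or note formats does not reuse a stale result.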
def load_prompt(note_type: str) -> tuple[str, str]:
"""Load the prompt template from YAML for the specified note type."""
prompt_path = Path("langhub/prompts/therapy_extraction_prompt.yaml")
with open(prompt_path, "r") as f:
data = yaml.safe_load(f)
note_prompts = data.get("prompts", {}).get(note_type.lower())
if not note_prompts:
raise ValueError(f"No prompt template found for note type: {note_type}")
return note_prompts["system"], note_prompts["human"]
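# For reference: load_prompt expects the YAML file to nest each note type under a
# top-level "prompts" key, with the lowercase type name mapping to "system" and
# "human" strings. A minimal sketch of that assumed layout (the real file may
# differ in wording and carry additional fields):
#
#   prompts:
#     soap:
#       system: You are a clinical documentation assistant...
#       human: Write a SOAP note from the following session transcript.
#     dap:
#       system: ...
#       human: ...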
def process_input(
input_text: str,
form_type: str,
input_type: str = "text",
audio_file: str | None = None,
force_refresh: bool = False
) -> str:
"""Process input (text, YouTube URL, or audio) and generate notes."""
try:
# Get transcript based on input type
if input_type == "audio" and audio_file:
print("Processing audio file...")
transcript = transcribe_audio(audio_file)
elif "youtube.com" in input_text or "youtu.be" in input_text:
print(f"Downloading transcript from YouTube...")
video_id = extract_youtube_video_id(input_text)
# Check cache first
if not force_refresh:
cached_transcript = cache_manager.get_transcript(video_id)
if cached_transcript:
print("Using cached transcript...")
transcript = cached_transcript
else:
transcript = download_transcript(input_text)
cache_manager.store_transcript(video_id, transcript)
else:
transcript = download_transcript(input_text)
cache_manager.store_transcript(video_id, transcript)
else:
print("Using provided text directly...")
transcript = input_text
# Initialize LLM
llm = get_llm()
model_id = get_model_identifier(llm)
# Check extraction cache
if not force_refresh:
cached_result = cache_manager.get_extraction(
transcript,
form_type.lower(),
model_id
)
if cached_result:
print("Using cached extraction result...")
formatted_response = yaml.dump(
cached_result,
default_flow_style=False,
sort_keys=False
)
return f"## {form_type} Note:\n```yaml\n{formatted_response}\n```"
# Get schema for selected form type
schema = FORM_SCHEMAS.get(form_type)
if not schema:
return f"Error: Unsupported form type {form_type}"
# Create structured LLM
structured_llm = llm.with_structured_output(schema=schema)
# Load prompts
system_prompt, human_prompt = load_prompt(form_type.lower())
# Create prompt template from the loaded system/human prompts
# (note: structured_llm below is invoked on the raw transcript; the
# template itself is not piped into the call)
prompt = ChatPromptTemplate.from_messages([
("system", system_prompt),
("human", human_prompt)
])
# Process transcript
print(f"Generating {form_type} note...")
response = structured_llm.invoke(transcript)
# Store result in cache
result_dict = response.model_dump(exclude_unset=False, exclude_none=False)
cache_manager.store_extraction(
transcript,
form_type.lower(),
result_dict,
model_id
)
# Format the response
formatted_response = yaml.dump(
result_dict,
default_flow_style=False,
sort_keys=False
)
return f"## {form_type} Note:\n```yaml\n{formatted_response}\n```"
except Exception as e:
return f"Error: {str(e)}"
def create_ui() -> gr.Blocks:
"""Create the Gradio interface."""
# Load authorized users from config
auth = load_auth_credentials()
def check_auth(username: str, password: str) -> bool:
"""Check if username and password are valid."""
return username in auth and auth[username] == password
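# check_auth assumes load_auth_credentials() returns a plain mapping of
# username -> password, e.g. {"clinician": "s3cret"} (illustrative values only).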
with gr.Blocks(title="Therapy Note Generator") as demo:
# Login interface
with gr.Row():
with gr.Column():
username = gr.Textbox(label="Username")
password = gr.Textbox(label="Password", type="password")
login_btn = gr.Button("Login")
login_msg = gr.Markdown()
# Main interface (initially invisible)
with gr.Column(visible=False) as main_interface:
gr.Markdown("# Therapy Note Generator")
gr.Markdown("""
Enter a YouTube URL, paste a transcript directly, or upload an audio file.
Select the desired note format and click 'Generate' to create a structured note.
""")
with gr.Row():
with gr.Column():
# Input type selector
input_type = gr.Radio(
choices=["text", "youtube", "audio"],
value="text",
label="Input Type",
info="Choose how you want to provide the therapy session"
)
# Text input for transcript or YouTube URL
input_text = gr.Textbox(
label="Text Input",
placeholder="Enter transcript or YouTube URL here...",
lines=10,
visible=True
)
# Audio upload
audio_input = gr.Audio(
label="Audio Input",
type="filepath",
visible=False
)
# Note format selector
form_type = gr.Dropdown(
choices=list(FORM_SCHEMAS.keys()),
value="SOAP",
label="Note Format"
)
generate_btn = gr.Button("Generate Note", variant="primary")
with gr.Column():
# Transcript output
transcript_output = gr.Textbox(
label="Generated Transcript",
lines=10,
visible=False,
interactive=False
)
# Structured note output
note_output = gr.Markdown(label="Generated Note")
# Update visibility based on input type
def update_inputs(choice):
return {
input_text: gr.update(visible=choice in ["text", "youtube"]),
audio_input: gr.update(visible=choice == "audio"),
transcript_output: gr.update(visible=choice in ["youtube", "audio"])
}
input_type.change(
fn=update_inputs,
inputs=input_type,
outputs=[input_text, audio_input, transcript_output]
)
def process_and_show_transcript(
input_text: str,
form_type: str,
input_type: str = "text",
audio_file: str | None = None,
force_refresh: bool = False
) -> tuple[str, str]:
"""Process input and return both transcript and structured note."""
try:
# Get transcript based on input type
if input_type == "audio" and audio_file:
print("Processing audio file...")
transcript = transcribe_audio(audio_file)
elif "youtube.com" in input_text or "youtu.be" in input_text:
print(f"Downloading transcript from YouTube...")
video_id = extract_youtube_video_id(input_text)
# Check cache first
if not force_refresh:
cached_transcript = cache_manager.get_transcript(video_id)
if cached_transcript:
print("Using cached transcript...")
transcript = cached_transcript
else:
transcript = download_transcript(input_text)
cache_manager.store_transcript(video_id, transcript)
else:
transcript = download_transcript(input_text)
cache_manager.store_transcript(video_id, transcript)
else:
print("Using provided text directly...")
transcript = input_text
# Generate the structured note; process_input resolves the transcript again
# (YouTube transcripts are served from the cache, audio is re-transcribed)
note_output = process_input(input_text, form_type, input_type, audio_file, force_refresh)
return transcript, note_output
except Exception as e:
error_msg = f"Error: {str(e)}"
return error_msg, error_msg
# Handle generate button click
generate_btn.click(
fn=process_and_show_transcript,
inputs=[input_text, form_type, input_type, audio_input],
outputs=[transcript_output, note_output]
)
# Example inputs
try:
with open("data/sample_note.txt", "r") as f:
sample_text = f.read()
except FileNotFoundError:
sample_text = "Sample therapy session transcript..."
gr.Examples(
examples=[
# Text example
[sample_text, "SOAP", "text", None],
# YouTube examples
["https://www.youtube.com/watch?v=KuHLL2AE-SE", "DAP", "youtube", None],
["https://www.youtube.com/watch?v=jS1KE3_Pqlc", "SOAPIER", "youtube", None],
# Audio example
[None, "BIRP", "audio", "data/CBT Role-Play.mp3"]
],
inputs=[input_text, form_type, input_type, audio_input],
outputs=[transcript_output, note_output],
fn=process_and_show_transcript,
cache_examples=False,
label="Example Inputs",
examples_per_page=4
)
def login(username: str, password: str):
"""Handle login and return updates for UI components."""
if check_auth(username, password):
return [
gr.update(visible=True), # main_interface
gr.update(value="✅ Login successful!", visible=True), # login_msg
gr.update(visible=False), # username
gr.update(visible=False), # password
gr.update(visible=False), # login_btn
]
else:
return [
gr.update(visible=False), # main_interface
gr.update(value="❌ Invalid credentials", visible=True), # login_msg
gr.update(), # username - no change
gr.update(), # password - no change
gr.update(), # login_btn - no change
]
login_btn.click(
fn=login,
inputs=[username, password],
outputs=[main_interface, login_msg, username, password, login_btn]
)
return demo
if __name__ == "__main__":
# Clean up any existing Gradio cache
cache_manager.cleanup_gradio_cache()
demo = create_ui()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_error=True,
auth=None # We're using our own auth system instead of Gradio's
)