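"""Gradio app that turns a therapy-session input (pasted text, a YouTube URL, or an
audio recording) into a structured clinical note (SOAP, DAP, BIRP, and other formats),
with simple credential-based login and transcript/extraction caching."""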
from __future__ import annotations

from pathlib import Path

import yaml
import gradio as gr
from langchain_core.prompts import ChatPromptTemplate

from forms.schemas import (
    SOAPNote, DAPNote, BIRPNote, PIRPNote, GIRPNote, SIRPNote,
    FAIRFDARPNote, DARENote, PIENote, SOAPIERNote, SOAPIENote,
    POMRNote, NarrativeNote, CBENote, SBARNote
)
from utils.youtube import download_transcript, extract_youtube_video_id
from utils.audio import transcribe_audio
from models.llm_provider import get_llm, get_model_identifier
from utils.cache import CacheManager
from config.auth import load_auth_credentials
# Dictionary mapping form types to their schemas
FORM_SCHEMAS = {
    "SOAP": SOAPNote,
    "DAP": DAPNote,
    "BIRP": BIRPNote,
    "PIRP": PIRPNote,
    "GIRP": GIRPNote,
    "SIRP": SIRPNote,
    "FAIR/F-DARP": FAIRFDARPNote,
    "DARE": DARENote,
    "PIE": PIENote,
    "SOAPIER": SOAPIERNote,
    "SOAPIE": SOAPIENote,
    "POMR": POMRNote,
    "Narrative": NarrativeNote,
    "CBE": CBENote,
    "SBAR": SBARNote,
}

# Initialize cache manager
cache_manager = CacheManager()
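
# As used below, the cache manager persists two kinds of entries: YouTube
# transcripts keyed by video id, and extraction results keyed by
# (transcript, note type, model identifier).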

def load_prompt(note_type: str) -> tuple[str, str]:
    """Load the prompt template from YAML for the specified note type."""
    prompt_path = Path("langhub/prompts/therapy_extraction_prompt.yaml")
    with open(prompt_path, "r") as f:
        data = yaml.safe_load(f)
    note_prompts = data.get("prompts", {}).get(note_type.lower())
    if not note_prompts:
        raise ValueError(f"No prompt template found for note type: {note_type}")
    return note_prompts["system"], note_prompts["human"]
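
# The YAML file is assumed (from the lookups above) to look roughly like:
#
#   prompts:
#     soap:
#       system: "You are a clinical documentation assistant ..."
#       human: "Create a SOAP note from this session transcript: {transcript}"
#
# The "{transcript}" input variable name is an assumption, reused below when the
# prompt is chained with the structured LLM.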

def process_input(
    input_text: str,
    form_type: str,
    input_type: str = "text",
    audio_file: str | None = None,
    force_refresh: bool = False
) -> str:
    """Process input (text, YouTube URL, or audio) and generate a structured note."""
    try:
        # Get transcript based on input type
        if input_type == "audio" and audio_file:
            print("Processing audio file...")
            transcript = transcribe_audio(audio_file)
        elif "youtube.com" in input_text or "youtu.be" in input_text:
            print("Downloading transcript from YouTube...")
            video_id = extract_youtube_video_id(input_text)
            # Check the transcript cache first
            if not force_refresh:
                cached_transcript = cache_manager.get_transcript(video_id)
                if cached_transcript:
                    print("Using cached transcript...")
                    transcript = cached_transcript
                else:
                    transcript = download_transcript(input_text)
                    cache_manager.store_transcript(video_id, transcript)
            else:
                transcript = download_transcript(input_text)
                cache_manager.store_transcript(video_id, transcript)
        else:
            print("Using provided text directly...")
            transcript = input_text

        # Initialize LLM
        llm = get_llm()
        model_id = get_model_identifier(llm)

        # Check extraction cache
        if not force_refresh:
            cached_result = cache_manager.get_extraction(
                transcript,
                form_type.lower(),
                model_id
            )
            if cached_result:
                print("Using cached extraction result...")
                formatted_response = yaml.dump(
                    cached_result,
                    default_flow_style=False,
                    sort_keys=False
                )
                return f"## {form_type} Note:\n```yaml\n{formatted_response}\n```"

        # Get schema for selected form type
        schema = FORM_SCHEMAS.get(form_type)
        if not schema:
            return f"Error: Unsupported form type {form_type}"

        # Create structured LLM
        structured_llm = llm.with_structured_output(schema=schema)

        # Load prompts
        system_prompt, human_prompt = load_prompt(form_type.lower())

        # Create prompt template
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", human_prompt)
        ])

        # Run the transcript through the prompt template and the structured LLM.
        # NOTE: the chain assumes the YAML human prompt exposes a "transcript"
        # input variable; adjust the key here if the template uses another name.
        print(f"Generating {form_type} note...")
        chain = prompt | structured_llm
        response = chain.invoke({"transcript": transcript})

        # Store result in cache
        result_dict = response.model_dump(exclude_unset=False, exclude_none=False)
        cache_manager.store_extraction(
            transcript,
            form_type.lower(),
            result_dict,
            model_id
        )

        # Format the response
        formatted_response = yaml.dump(
            result_dict,
            default_flow_style=False,
            sort_keys=False
        )
        return f"## {form_type} Note:\n```yaml\n{formatted_response}\n```"
    except Exception as e:
        return f"Error: {str(e)}"
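
# Illustrative usage (not executed at import time; the URL is one of the
# examples wired into the UI below):
#   note_md = process_input("https://www.youtube.com/watch?v=KuHLL2AE-SE", "DAP")
#   note_md = process_input("Client reports improved sleep this week ...", "SOAP")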

def create_ui() -> gr.Blocks:
    """Create the Gradio interface."""
    # Load authorized users from config
    auth = load_auth_credentials()

    def check_auth(username: str, password: str) -> bool:
        """Check if username and password are valid."""
        return username in auth and auth[username] == password

    with gr.Blocks(title="Therapy Note Generator") as demo:
        # Login interface
        with gr.Row():
            with gr.Column():
                username = gr.Textbox(label="Username")
                password = gr.Textbox(label="Password", type="password")
                login_btn = gr.Button("Login")
                login_msg = gr.Markdown()

        # Main interface (initially hidden; shown after login)
        with gr.Column(visible=False) as main_interface:
            gr.Markdown("# Therapy Note Generator")
            gr.Markdown("""
            Enter a YouTube URL, paste a transcript directly, or upload an audio file.
            Select the desired note format and click 'Generate' to create a structured note.
            """)

            with gr.Row():
                with gr.Column():
                    # Input type selector
                    input_type = gr.Radio(
                        choices=["text", "youtube", "audio"],
                        value="text",
                        label="Input Type",
                        info="Choose how you want to provide the therapy session"
                    )

                    # Text input for transcript or YouTube URL
                    input_text = gr.Textbox(
                        label="Text Input",
                        placeholder="Enter transcript or YouTube URL here...",
                        lines=10,
                        visible=True
                    )

                    # Audio upload
                    audio_input = gr.Audio(
                        label="Audio Input",
                        type="filepath",
                        visible=False
                    )

                    # Note format selector
                    form_type = gr.Dropdown(
                        choices=list(FORM_SCHEMAS.keys()),
                        value="SOAP",
                        label="Note Format"
                    )

                    generate_btn = gr.Button("Generate Note", variant="primary")

                with gr.Column():
                    # Transcript output
                    transcript_output = gr.Textbox(
                        label="Generated Transcript",
                        lines=10,
                        visible=False,
                        interactive=False
                    )

                    # Structured note output
                    note_output = gr.Markdown(label="Generated Note")

            # Update component visibility based on the selected input type
            def update_inputs(choice):
                return {
                    input_text: gr.update(visible=choice in ["text", "youtube"]),
                    audio_input: gr.update(visible=choice == "audio"),
                    transcript_output: gr.update(visible=choice in ["youtube", "audio"])
                }

            input_type.change(
                fn=update_inputs,
                inputs=input_type,
                outputs=[input_text, audio_input, transcript_output]
            )

            def process_and_show_transcript(
                input_text: str,
                form_type: str,
                input_type: str = "text",
                audio_file: str | None = None,
                force_refresh: bool = False
            ) -> tuple[str, str]:
                """Process input and return both the transcript and the structured note."""
                try:
                    # Get transcript based on input type
                    if input_type == "audio" and audio_file:
                        print("Processing audio file...")
                        transcript = transcribe_audio(audio_file)
                    elif "youtube.com" in input_text or "youtu.be" in input_text:
                        print("Downloading transcript from YouTube...")
                        video_id = extract_youtube_video_id(input_text)
                        # Check the transcript cache first
                        if not force_refresh:
                            cached_transcript = cache_manager.get_transcript(video_id)
                            if cached_transcript:
                                print("Using cached transcript...")
                                transcript = cached_transcript
                            else:
                                transcript = download_transcript(input_text)
                                cache_manager.store_transcript(video_id, transcript)
                        else:
                            transcript = download_transcript(input_text)
                            cache_manager.store_transcript(video_id, transcript)
                    else:
                        print("Using provided text directly...")
                        transcript = input_text

                    # Generate the note from the transcript we already have, passed as
                    # plain text so the audio is not transcribed (and the YouTube video
                    # not re-fetched) a second time inside process_input.
                    note_output = process_input(transcript, form_type, "text", None, force_refresh)
                    return transcript, note_output
                except Exception as e:
                    error_msg = f"Error: {str(e)}"
                    return error_msg, error_msg

            # Handle generate button click
            generate_btn.click(
                fn=process_and_show_transcript,
                inputs=[input_text, form_type, input_type, audio_input],
                outputs=[transcript_output, note_output]
            )

            # Example inputs
            try:
                with open("data/sample_note.txt", "r") as f:
                    sample_text = f.read()
            except FileNotFoundError:
                sample_text = "Sample therapy session transcript..."

            gr.Examples(
                examples=[
                    # Text example
                    [sample_text, "SOAP", "text", None],
                    # YouTube examples
                    ["https://www.youtube.com/watch?v=KuHLL2AE-SE", "DAP", "youtube", None],
                    ["https://www.youtube.com/watch?v=jS1KE3_Pqlc", "SOAPIER", "youtube", None],
                    # Audio example
                    [None, "BIRP", "audio", "data/CBT Role-Play.mp3"]
                ],
                inputs=[input_text, form_type, input_type, audio_input],
                outputs=[transcript_output, note_output],
                fn=process_and_show_transcript,
                cache_examples=False,
                label="Example Inputs",
                examples_per_page=4
            )

        def login(username: str, password: str):
            """Handle login and return updates for UI components."""
            if check_auth(username, password):
                return [
                    gr.update(visible=True),  # main_interface
                    gr.update(value="✅ Login successful!", visible=True),  # login_msg
                    gr.update(visible=False),  # username
                    gr.update(visible=False),  # password
                    gr.update(visible=False),  # login_btn
                ]
            else:
                return [
                    gr.update(visible=False),  # main_interface
                    gr.update(value="❌ Invalid credentials", visible=True),  # login_msg
                    gr.update(),  # username - no change
                    gr.update(),  # password - no change
                    gr.update(),  # login_btn - no change
                ]

        login_btn.click(
            fn=login,
            inputs=[username, password],
            outputs=[main_interface, login_msg, username, password, login_btn]
        )

    return demo


if __name__ == "__main__":
    # Clean up any existing Gradio cache
    cache_manager.cleanup_gradio_cache()

    demo = create_ui()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
        auth=None  # We're using our own auth system instead of Gradio's
    )