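"""Gradio app that turns therapy-session input into structured clinical notes.

Accepts pasted text, a YouTube URL, or an uploaded audio file, obtains a
transcript (with caching for YouTube videos), and uses an LLM with structured
output to produce the selected note format (SOAP, DAP, BIRP, ...), behind a
simple username/password login.
"""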
from __future__ import annotations
from pathlib import Path

import yaml
import gradio as gr

from langchain_core.prompts import ChatPromptTemplate

from config.settings import settings
from forms.schemas import (
    SOAPNote, DAPNote, BIRPNote, PIRPNote, GIRPNote, SIRPNote,
    FAIRFDARPNote, DARENote, PIENote, SOAPIERNote, SOAPIENote,
    POMRNote, NarrativeNote, CBENote, SBARNote
)
from utils.youtube import download_transcript, extract_youtube_video_id
from utils.audio import transcribe_audio
from models.llm_provider import get_llm, get_model_identifier
from utils.cache import CacheManager
from config.auth import load_auth_credentials

# Dictionary mapping form types to their schemas
FORM_SCHEMAS = {
    "SOAP": SOAPNote,
    "DAP": DAPNote,
    "BIRP": BIRPNote,
    "PIRP": PIRPNote,
    "GIRP": GIRPNote,
    "SIRP": SIRPNote,
    "FAIR/F-DARP": FAIRFDARPNote,
    "DARE": DARENote,
    "PIE": PIENote,
    "SOAPIER": SOAPIERNote,
    "SOAPIE": SOAPIENote,
    "POMR": POMRNote,
    "Narrative": NarrativeNote,
    "CBE": CBENote,
    "SBAR": SBARNote,
}

# Initialize cache manager
cache_manager = CacheManager()

def load_prompt(note_type: str) -> tuple[str, str]:
    """Load the prompt template from YAML for the specified note type."""
    prompt_path = Path("langhub/prompts/therapy_extraction_prompt.yaml")
    with open(prompt_path, "r") as f:
        data = yaml.safe_load(f)
    
    note_prompts = data.get("prompts", {}).get(note_type.lower())
    if not note_prompts:
        raise ValueError(f"No prompt template found for note type: {note_type}")
    
    return note_prompts["system"], note_prompts["human"]
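
# Illustrative sketch of the expected layout of therapy_extraction_prompt.yaml,
# inferred from load_prompt above (keys "prompts", "system", "human"); the exact
# wording and template variables live in the YAML file itself:
#
# prompts:
#   soap:
#     system: "You are a clinical documentation assistant..."
#     human: "Create a SOAP note from this therapy session transcript: {transcript}"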

def process_input(
    input_text: str,
    form_type: str,
    input_type: str = "text",
    audio_file: str | None = None,
    force_refresh: bool = False
) -> str:
    """Process input (text, YouTube URL, or audio) and generate notes."""
    try:
        # Get transcript based on input type
        if input_type == "audio" and audio_file:
            print("Processing audio file...")
            transcript = transcribe_audio(audio_file)
        elif "youtube.com" in input_text or "youtu.be" in input_text:
            print(f"Downloading transcript from YouTube...")
            video_id = extract_youtube_video_id(input_text)
            
            # Check cache first
            if not force_refresh:
                cached_transcript = cache_manager.get_transcript(video_id)
                if cached_transcript:
                    print("Using cached transcript...")
                    transcript = cached_transcript
                else:
                    transcript = download_transcript(input_text)
                    cache_manager.store_transcript(video_id, transcript)
            else:
                transcript = download_transcript(input_text)
                cache_manager.store_transcript(video_id, transcript)
        else:
            print("Using provided text directly...")
            transcript = input_text

        # Initialize LLM
        llm = get_llm()
        model_id = get_model_identifier(llm)
        
        # Check extraction cache
        if not force_refresh:
            cached_result = cache_manager.get_extraction(
                transcript,
                form_type.lower(),
                model_id
            )
            if cached_result:
                print("Using cached extraction result...")
                formatted_response = yaml.dump(
                    cached_result,
                    default_flow_style=False,
                    sort_keys=False
                )
                return f"## {form_type} Note:\n```yaml\n{formatted_response}\n```"
        
        # Get schema for selected form type
        schema = FORM_SCHEMAS.get(form_type)
        if not schema:
            return f"Error: Unsupported form type {form_type}"
        
        # Create structured LLM
        structured_llm = llm.with_structured_output(schema=schema)
        
        # Load prompts
        system_prompt, human_prompt = load_prompt(form_type.lower())
        
        # Create the prompt template and chain it with the structured LLM so the
        # loaded system/human prompts are actually applied
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", human_prompt)
        ])
        chain = prompt | structured_llm

        # Process the transcript. The human prompt is expected to expose input
        # variable(s) for the session text; fill whichever name(s) the YAML
        # template declares rather than hard-coding one here.
        print(f"Generating {form_type} note...")
        response = chain.invoke({name: transcript for name in prompt.input_variables})
        
        # Store result in cache
        result_dict = response.model_dump(exclude_unset=False, exclude_none=False)
        cache_manager.store_extraction(
            transcript,
            form_type.lower(),
            result_dict,
            model_id
        )
        
        # Format the response
        formatted_response = yaml.dump(
            result_dict,
            default_flow_style=False,
            sort_keys=False
        )
        
        return f"## {form_type} Note:\n```yaml\n{formatted_response}\n```"
        
    except Exception as e:
        return f"Error: {str(e)}"

def create_ui() -> gr.Blocks:
    """Create the Gradio interface."""
    
    # Load authorized users from config
    auth = load_auth_credentials()
    
    def check_auth(username: str, password: str) -> bool:
        """Check if username and password are valid."""
        return username in auth and auth[username] == password
    
    with gr.Blocks(title="Therapy Note Generator") as demo:
        # Login interface
        with gr.Row():
            with gr.Column():
                username = gr.Textbox(label="Username")
                password = gr.Textbox(label="Password", type="password")
                login_btn = gr.Button("Login")
                login_msg = gr.Markdown()
        
        # Main interface (initially invisible)
        with gr.Column(visible=False) as main_interface:
            gr.Markdown("# Therapy Note Generator")
            gr.Markdown("""
            Enter a YouTube URL, paste a transcript directly, or upload an audio file.
            Select the desired note format and click 'Generate' to create a structured note.
            """)
            
            with gr.Row():
                with gr.Column():
                    # Input type selector
                    input_type = gr.Radio(
                        choices=["text", "youtube", "audio"],
                        value="text",
                        label="Input Type",
                        info="Choose how you want to provide the therapy session"
                    )
                    
                    # Text input for transcript or YouTube URL
                    input_text = gr.Textbox(
                        label="Text Input",
                        placeholder="Enter transcript or YouTube URL here...",
                        lines=10,
                        visible=True
                    )
                    
                    # Audio upload
                    audio_input = gr.Audio(
                        label="Audio Input",
                        type="filepath",
                        visible=False
                    )
                    
                    # Note format selector
                    form_type = gr.Dropdown(
                        choices=list(FORM_SCHEMAS.keys()),
                        value="SOAP",
                        label="Note Format"
                    )
                    
                    generate_btn = gr.Button("Generate Note", variant="primary")
                
                with gr.Column():
                    # Transcript output
                    transcript_output = gr.Textbox(
                        label="Generated Transcript",
                        lines=10,
                        visible=False,
                        interactive=False
                    )
                    # Structured note output
                    note_output = gr.Markdown(label="Generated Note")
            
            # Update visibility based on input type
            def update_inputs(choice):
                return {
                    input_text: gr.update(visible=choice in ["text", "youtube"]),
                    audio_input: gr.update(visible=choice == "audio"),
                    transcript_output: gr.update(visible=choice in ["youtube", "audio"])
                }
            
            input_type.change(
                fn=update_inputs,
                inputs=input_type,
                outputs=[input_text, audio_input, transcript_output]
            )
            
            def process_and_show_transcript(
                input_text: str,
                form_type: str,
                input_type: str = "text",
                audio_file: str | None = None,
                force_refresh: bool = False
            ) -> tuple[str, str]:
                """Process input and return both transcript and structured note."""
                try:
                    # Get transcript based on input type
                    if input_type == "audio" and audio_file:
                        print("Processing audio file...")
                        transcript = transcribe_audio(audio_file)
                    elif "youtube.com" in input_text or "youtu.be" in input_text:
                        print(f"Downloading transcript from YouTube...")
                        video_id = extract_youtube_video_id(input_text)
                        
                        # Check cache first
                        if not force_refresh:
                            cached_transcript = cache_manager.get_transcript(video_id)
                            if cached_transcript:
                                print("Using cached transcript...")
                                transcript = cached_transcript
                            else:
                                transcript = download_transcript(input_text)
                                cache_manager.store_transcript(video_id, transcript)
                        else:
                            transcript = download_transcript(input_text)
                            cache_manager.store_transcript(video_id, transcript)
                    else:
                        print("Using provided text directly...")
                        transcript = input_text
                        
                    # Generate the note from the transcript we already have, rather
                    # than re-fetching (and, for audio, re-transcribing) it inside
                    # process_input
                    note_output = process_input(transcript, form_type, "text", None, force_refresh)
                    
                    return transcript, note_output
                    
                except Exception as e:
                    error_msg = f"Error: {str(e)}"
                    return error_msg, error_msg
            
            # Handle generate button click
            generate_btn.click(
                fn=process_and_show_transcript,
                inputs=[input_text, form_type, input_type, audio_input],
                outputs=[transcript_output, note_output]
            )
            
            # Example inputs
            try:
                with open("data/sample_note.txt", "r") as f:
                    sample_text = f.read()
            except FileNotFoundError:
                sample_text = "Sample therapy session transcript..."
            
            gr.Examples(
                examples=[
                    # Text example
                    [sample_text, "SOAP", "text", None],
                    # YouTube examples
                    ["https://www.youtube.com/watch?v=KuHLL2AE-SE", "DAP", "youtube", None],
                    ["https://www.youtube.com/watch?v=jS1KE3_Pqlc", "SOAPIER", "youtube", None],
                    # Audio example
                    [None, "BIRP", "audio", "data/CBT Role-Play.mp3"]
                ],
                inputs=[input_text, form_type, input_type, audio_input],
                outputs=[transcript_output, note_output],
                fn=process_and_show_transcript,
                cache_examples=False,
                label="Example Inputs",
                examples_per_page=4
            )
        
        def login(username: str, password: str):
            """Handle login and return updates for UI components."""
            if check_auth(username, password):
                return [
                    gr.update(visible=True),  # main_interface
                    gr.update(value="✅ Login successful!", visible=True),  # login_msg
                    gr.update(visible=False),  # username
                    gr.update(visible=False),  # password
                    gr.update(visible=False),  # login_btn
                ]
            else:
                return [
                    gr.update(visible=False),  # main_interface
                    gr.update(value="❌ Invalid credentials", visible=True),  # login_msg
                    gr.update(),  # username - no change
                    gr.update(),  # password - no change
                    gr.update(),  # login_btn - no change
                ]
        
        login_btn.click(
            fn=login,
            inputs=[username, password],
            outputs=[main_interface, login_msg, username, password, login_btn]
        )
    
    return demo

if __name__ == "__main__":
    # Clean up any existing Gradio cache
    cache_manager.cleanup_gradio_cache()
    
    demo = create_ui()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
        auth=None  # We're using our own auth system instead of Gradio's
    )