import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline

# Load the extractive question-answering model and tokenizer.
# deepset/roberta-base-squad2 is a RoBERTa fine-tuned on SQuAD 2.0; both the
# tokenizer and the model weights are downloaded on first run (network access
# and disk cache required at import time).
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")

# Load the speech-to-text model as a ready-made ASR pipeline.
# NOTE(review): wav2vec2-base-960h emits uppercase English with no punctuation —
# confirm the downstream QA model copes with that text style.
s2t = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")

# Function to extract structured information using question answering
def extract_structured_info(note):
    """Extract structured patient fields from a free-text clinical note.

    Runs one extractive-QA forward pass per field using the module-level
    ``tokenizer``/``model`` pair.

    Args:
        note: Free-text clinical note (e.g. an ASR transcription).

    Returns:
        dict mapping each field label ("Patient Name", "Age",
        "Medical History", "Physical Examination") to the extracted
        answer string (possibly empty when the model finds no span).
    """
    questions = {
        "Patient Name": "What is the patient's name?",
        "Age": "How old is the patient?",
        "Medical History": "What is the medical history?",
        "Physical Examination": "What does the physical examination reveal?"
    }

    answers = {}

    for key, question in questions.items():
        # Encode the question/note pair. truncation=True keeps the sequence
        # within the model's 512-token limit — a long transcription would
        # otherwise crash the forward pass.
        inputs = tokenizer(question, note, return_tensors="pt", truncation=True)

        # Inference only; no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # Pick the most likely start, then search for the end only at or
        # after the start: a raw argmax over end_logits can precede the
        # start position, which would yield an empty or nonsense span.
        answer_start = int(outputs.start_logits.argmax())
        answer_end = answer_start + int(outputs.end_logits[0, answer_start:].argmax()) + 1

        # decode() is the supported round-trip API; skip_special_tokens
        # drops <s>/</s> if the span brushes against them.
        answer = tokenizer.decode(
            inputs["input_ids"][0][answer_start:answer_end],
            skip_special_tokens=True,
        )

        answers[key] = answer.strip()

    return answers

def process_audio(audio):
    """Transcribe a recorded audio file and extract structured patient fields.

    Args:
        audio: Path to the recorded audio file, or None when nothing was
            recorded.

    Returns:
        A 5-tuple of (transcription, patient name, age, medical history,
        physical examination). When no audio is supplied, placeholder
        values are returned; on any error, every field carries the same
        error message.
    """
    # Guard clause: nothing to do without audio (cannot raise, so it
    # lives outside the try block).
    if audio is None:
        return "No audio provided", "N/A", "N/A", "N/A", "N/A"

    try:
        # Speech-to-text, then field extraction on the transcript.
        transcription = s2t(audio).get("text", "")
        info = extract_structured_info(transcription)

        field_order = ("Patient Name", "Age", "Medical History", "Physical Examination")
        return (transcription,) + tuple(info[field] for field in field_order)

    except Exception as e:
        # Surface the failure in every output box rather than crashing the UI.
        error_message = f"Error: {str(e)}"
        return (error_message,) * 5

# Set up the Gradio interface: one audio input, five text outputs that mirror
# the 5-tuple returned by process_audio (transcription + four extracted fields).
iface = gr.Interface(
    fn=process_audio,
    title="Medical Speech-to-Text with Entity Recognition",
    description="Record an audio file describing patient details. The app transcribes the text, extracts the patient's name, age, medical history, and physical examination.",
    # type="filepath" hands process_audio a path string, which the ASR
    # pipeline accepts directly.
    inputs=gr.Audio(type="filepath", label="Record Audio"),
    outputs=[
        gr.Textbox(label="Transcribed Text"),
        gr.Textbox(label="Patient Name"),
        gr.Textbox(label="Patient Age"),
        gr.Textbox(label="Medical History"),
        gr.Textbox(label="Physical Examination"),
    ],
    # NOTE(review): live=True re-runs the full ASR + 4 QA passes on every input
    # change — consider a submit button if this proves too slow. TODO confirm.
    live=True  # Automatically triggers on new input
)
iface.launch()