# Source: Hugging Face Space by chrisaldikaraharja — "Update app.py", commit 866ef54 (verified).
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
# Extractive question-answering model (RoBERTa fine-tuned on SQuAD 2.0),
# used to pull individual patient fields out of the transcribed note.
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
# Speech-to-text pipeline (Wav2Vec2) that transcribes the recorded audio file.
s2t = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
# Function to extract structured information using question answering
def extract_structured_info(note):
    """Extract structured patient fields from a clinical note via extractive QA.

    Runs one question per field ("Patient Name", "Age", "Medical History",
    "Physical Examination") against *note* using the module-level RoBERTa
    SQuAD2 model and returns a dict mapping each field name to the extracted
    answer string ("N/A" when no usable span is found).

    Args:
        note: Free-text clinical note (e.g. an ASR transcription).

    Returns:
        dict[str, str]: field name -> extracted answer.
    """
    questions = {
        "Patient Name": "What is the patient's name?",
        "Age": "How old is the patient?",
        "Medical History": "What is the medical history?",
        "Physical Examination": "What does the physical examination reveal?"
    }
    answers = {}
    for key, question in questions.items():
        # Truncate long notes so the encoded (question, note) pair fits the
        # model's 512-token limit — without this a long transcription raises
        # a runtime error inside the model.
        inputs = tokenizer(
            question, note,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs)
        # Most likely start/end token positions for the answer span.
        answer_start = outputs.start_logits.argmax()
        answer_end = outputs.end_logits.argmax() + 1
        if answer_end <= answer_start:
            # Degenerate span (end predicted before start): no confident answer.
            answers[key] = "N/A"
            continue
        # decode(skip_special_tokens=True) avoids leaking <s>/</s> markers
        # into the answer text.
        answer = tokenizer.decode(
            inputs["input_ids"][0][answer_start:answer_end],
            skip_special_tokens=True,
        )
        answers[key] = answer.strip() or "N/A"
    return answers
def process_audio(audio):
    """Transcribe a recorded audio file and extract structured patient fields.

    Args:
        audio: Filesystem path to the recording (Gradio `type="filepath"`),
            or None when nothing was recorded.

    Returns:
        A 5-tuple of strings: (transcription, patient name, age,
        medical history, physical examination). On failure, every slot
        carries the same "Error: ..." message so each output box shows it.
    """
    try:
        # Guard clause: nothing recorded yet.
        if audio is None:
            return "No audio provided", "N/A", "N/A", "N/A", "N/A"
        # Speech -> text.
        transcript = s2t(audio).get("text", "")
        # Text -> structured fields.
        fields = extract_structured_info(transcript)
        return (
            transcript,
            fields["Patient Name"],
            fields["Age"],
            fields["Medical History"],
            fields["Physical Examination"],
        )
    except Exception as exc:
        # Surface the failure in every output field rather than crashing the UI.
        message = f"Error: {str(exc)}"
        return (message,) * 5
# Set up Gradio Interface with structured outputs
# One textbox per value returned by process_audio, in the same order.
output_fields = [
    gr.Textbox(label="Transcribed Text"),
    gr.Textbox(label="Patient Name"),
    gr.Textbox(label="Patient Age"),
    gr.Textbox(label="Medical History"),
    gr.Textbox(label="Physical Examination"),
]

# Gradio UI: record audio -> transcription + extracted patient fields.
iface = gr.Interface(
    fn=process_audio,
    title="Medical Speech-to-Text with Entity Recognition",
    description="Record an audio file describing patient details. The app transcribes the text, extracts the patient's name, age, medical history, and physical examination.",
    inputs=gr.Audio(type="filepath", label="Record Audio"),
    outputs=output_fields,
    live=True,  # re-run automatically whenever the input changes
)
iface.launch()