|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline |
|
|
|
|
|
# Extractive question-answering checkpoint; the tokenizer and the model must
# come from the same checkpoint so token ids line up.
QA_CHECKPOINT = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(QA_CHECKPOINT)
model = AutoModelForQuestionAnswering.from_pretrained(QA_CHECKPOINT)

# Speech-recognition pipeline used to transcribe the recorded audio.
ASR_CHECKPOINT = "facebook/wav2vec2-base-960h"
s2t = pipeline(task="automatic-speech-recognition", model=ASR_CHECKPOINT)
|
|
|
|
|
def extract_structured_info(note):
    """Extract patient fields from a free-text note using extractive QA.

    One fixed question per field is asked against ``note`` using the
    module-level ``tokenizer`` and ``model``; the highest-scoring start/end
    positions are decoded back to text.

    Args:
        note: Free-text note to search (for example a speech transcription).

    Returns:
        A dict with the keys "Patient Name", "Age", "Medical History" and
        "Physical Examination", each mapped to the extracted answer string.
        A value is the empty string when ``note`` is empty/blank/None or when
        the predicted span contains no usable text.
    """
    questions = {
        "Patient Name": "What is the patient's name?",
        "Age": "How old is the patient?",
        "Medical History": "What is the medical history?",
        "Physical Examination": "What does the physical examination reveal?",
    }

    # Nothing to search in an empty, blank or missing note: skip the model
    # entirely and report every field as unanswered.
    if not note or not note.strip():
        return dict.fromkeys(questions, "")

    answers = {}
    for key, question in questions.items():
        # Truncate only the note (second segment), never the question, so an
        # over-long transcript does not overflow the model's input window.
        # NOTE(review): the limit comes from the tokenizer's configured
        # model_max_length - confirm it is set for the loaded checkpoint.
        inputs = tokenizer(
            question,
            note,
            return_tensors="pt",
            truncation="only_second",
        )

        with torch.no_grad():
            outputs = model(**inputs)

        start = int(outputs.start_logits.argmax())
        end = int(outputs.end_logits.argmax())

        # Start and end are chosen independently, so the span can be
        # inverted; treat that as "no answer" explicitly.
        if end < start:
            answers[key] = ""
            continue

        span_ids = inputs["input_ids"][0][start:end + 1].tolist()
        # Dropping special tokens keeps markers such as "<s>" out of the
        # result; a prediction that points only at such a token (presumably
        # the model's "no answer" signal) therefore becomes "".
        tokens = tokenizer.convert_ids_to_tokens(span_ids, skip_special_tokens=True)
        answers[key] = tokenizer.convert_tokens_to_string(tokens).strip()

    return answers
|
|
|
def process_audio(audio):
    """Transcribe an audio clip and extract patient fields from the transcript.

    Args:
        audio: Value handed over by the audio input component, or None when
            nothing was provided.

    Returns:
        A 5-tuple ``(transcription, patient_name, age, medical_history,
        physical_examination)``. With no audio, the first item is
        "No audio provided" and the rest are "N/A". If anything fails, all
        five items carry the same "Error: ..." message.
    """
    # Guard clause: nothing to transcribe.
    if audio is None:
        return ("No audio provided",) + ("N/A",) * 4

    field_order = ("Patient Name", "Age", "Medical History", "Physical Examination")
    try:
        asr_output = s2t(audio)
        transcription = asr_output.get("text", "")
        structured_info = extract_structured_info(transcription)
        return (transcription, *(structured_info[field] for field in field_order))
    except Exception as e:
        # UI boundary: show the failure in every output box instead of raising.
        error_message = f"Error: {str(e)}"
        return (error_message,) * 5
|
|
|
|
|
# Gradio UI: a single audio input feeds process_audio, and its five return
# values fill the five text boxes below in order.
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Record Audio"),
    outputs=[
        gr.Textbox(label=box_label)
        for box_label in (
            "Transcribed Text",
            "Patient Name",
            "Patient Age",
            "Medical History",
            "Physical Examination",
        )
    ],
    title="Medical Speech-to-Text with Entity Recognition",
    description="Record an audio file describing patient details. The app transcribes the text, extracts the patient's name, age, medical history, and physical examination.",
    # Re-run automatically when the input changes.
    live=True,
)

# Start the web server for the interface.
iface.launch()