# Spaces:
# Sleeping
# Sleeping
import warnings

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """Two-headed code classifier on top of a BERT-style encoder.

    The embedding at sequence position 0 of ``last_hidden_state`` is passed
    through dropout and then through two independent linear heads: one that
    scores ICD-10 codes and one that scores CPT codes.

    Args:
        bert_model: encoder module called as
            ``bert_model(input_ids=..., attention_mask=...)``; its output must
            expose ``last_hidden_state``.
        num_icd_codes: number of ICD-10 classes. Defaults to ``len(ICD_CODES)``.
        num_cpt_codes: number of CPT classes. Defaults to ``len(CPT_CODES)``.
        dropout: dropout probability applied to the pooled embedding.
    """

    def __init__(self, bert_model, num_icd_codes=None, num_cpt_codes=None, dropout=0.1):
        super().__init__()
        self.bert = bert_model
        # Read the encoder width from its config instead of hard-coding 768,
        # so BERT variants of other sizes also work; 768 remains the fallback
        # for encoders that do not expose ``config.hidden_size``.
        hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        if num_icd_codes is None:
            num_icd_codes = len(ICD_CODES)
        if num_cpt_codes is None:
            num_cpt_codes = len(CPT_CODES)
        # Attribute names and creation order match the original so existing
        # state_dict keys (and random-init order) are unchanged.
        self.dropout = torch.nn.Dropout(dropout)
        self.icd_classifier = torch.nn.Linear(hidden_size, num_icd_codes)
        self.cpt_classifier = torch.nn.Linear(hidden_size, num_cpt_codes)

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)``, each shaped (batch, num_classes)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking sequence position 0 (presumably the [CLS] token with a
        # BERT tokenizer); the indexing assumes last_hidden_state is
        # (batch, seq, hidden).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)
        return icd_logits, cpt_logits
# Define code dictionaries.
# Keys are classifier output indices: MedicalCodePredictor sizes its heads with
# len(...) and predict_codes maps top-k indices back through these dicts, so
# keys must stay contiguous from 0 and in the same order as the trained heads.
ICD_CODES = {
    0: "I10 - Essential hypertension",
    1: "E11.9 - Type 2 diabetes without complications",
    2: "J44.9 - COPD",
    3: "I25.10 - Atherosclerotic heart disease",
    # M54.5 is no longer a valid code on its own; it was subdivided into
    # M54.50 / M54.51 / M54.59.
    4: "M54.50 - Low back pain, unspecified",
    5: "F41.9 - Anxiety disorder",
    6: "J45.909 - Asthma, unspecified",
    7: "K21.9 - GERD",
    8: "E78.5 - Dyslipidemia",
    9: "M17.9 - Osteoarthritis of knee",
    10: "E10.9 - Type 1 diabetes without complications",
    # R51 was subdivided into R51.0 / R51.9.
    11: "R51.9 - Headache, unspecified",
    12: "R50.9 - Fever, unspecified",
    # R05 was subdivided into R05.1-R05.9.
    13: "R05.9 - Cough, unspecified",
    # Category-level code: billing requires further characters including a
    # 7th-character encounter extension.
    14: "S52.5 - Fracture of lower end of radius",
    15: "A49.9 - Bacterial infection, unspecified",
    16: "R52 - Pain, unspecified",
    # R11 is a category; nausea alone is R11.0.
    17: "R11.0 - Nausea",
    # Requires a 7th-character encounter extension (e.g. S33.5XXA) for billing.
    18: "S33.5 - Sprain of ligaments of lumbar spine",
}
CPT_CODES = {
    0: "99213 - Office visit, established patient",
    1: "99214 - Office visit, established patient, moderate complexity",
    2: "99203 - Office visit, new patient",
    3: "80053 - Comprehensive metabolic panel",
    4: "85025 - Complete blood count",
    5: "93000 - ECG with interpretation",
    6: "71045 - Chest X-ray",
    7: "99395 - Preventive visit, established patient",
    8: "96127 - Brief emotional/behavioral assessment",
    9: "99396 - Preventive visit, age 40-64",
    10: "96372 - Therapeutic injection",
    11: "97110 - Therapeutic exercises",
    12: "10060 - Incision and drainage of abscess",
    13: "76700 - Abdominal ultrasound",
    14: "87500 - Infectious agent detection",
    15: "72100 - X-ray of lower spine",
    16: "72148 - MRI of lumbar spine",
}
# Load models
def load_models(model_name="emilyalsentzer/Bio_ClinicalBERT"):
    """Load the tokenizer and build a MedicalCodePredictor around the encoder.

    Args:
        model_name: Hugging Face model id (or local path) passed to both
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Returns:
        ``(tokenizer, model)``, with ``model`` already switched to eval mode.

    Warns:
        RuntimeWarning: always. The ICD/CPT heads are newly constructed
        linear layers and no fine-tuned weights are loaded here.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = MedicalCodePredictor(base_model)
    # This app only runs inference. Disable dropout here rather than relying
    # on every caller to remember model.eval().
    model.eval()
    # Nothing loads trained weights into icd_classifier / cpt_classifier, so
    # their scores are random. Make that visible instead of silently serving
    # them as "Confidence" values.
    warnings.warn(
        "MedicalCodePredictor classification heads are randomly initialized; "
        "predicted codes and confidences are not meaningful until fine-tuned "
        "weights are loaded.",
        RuntimeWarning,
        stacklevel=2,
    )
    return tokenizer, model
# Prediction function
def _format_top_codes(probs, labels, k):
    """Return ``k`` numbered "label (Confidence: p)" lines for row 0 of *probs*."""
    top = torch.topk(probs, k=k)
    return "".join(
        f"{rank}. {labels[idx.item()]} (Confidence: {prob.item():.2f})\n"
        for rank, (prob, idx) in enumerate(zip(top.values[0], top.indices[0]), start=1)
    )


def predict_codes(text, top_k=3):
    """Predict the most likely ICD-10 and CPT codes for a free-text summary.

    Uses the module-level ``tokenizer``, ``model``, ``ICD_CODES`` and
    ``CPT_CODES``.

    Args:
        text: medical summary entered by the user. ``None`` or
            whitespace-only input yields a prompt message.
        top_k: how many codes to list per code system. It is capped at the
            number of classes in each head.

    Returns:
        A human-readable, multi-line string listing the top ICD-10 codes and
        then the top CPT codes, each with its softmax probability.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    model.eval()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs["input_ids"], inputs["attention_mask"])

    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)

    # torch.topk raises if k exceeds the number of classes, so cap k per head.
    icd_k = min(top_k, icd_probs.size(1))
    cpt_k = min(top_k, cpt_probs.size(1))

    return (
        "Recommended ICD-10 Codes:\n"
        + _format_top_codes(icd_probs, ICD_CODES, icd_k)
        + "\nRecommended CPT Codes:\n"
        + _format_top_codes(cpt_probs, CPT_CODES, cpt_k)
    )
# Load models globally.
# predict_codes reads these module-level names, so they are created at import time.
tokenizer, model = load_models()

# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
)

# Launch the server only when this file is run as a script (e.g. `python app.py`).
# Importing the module then just builds the interface and does not start a server.
# NOTE(review): share=True requests a public share link when run locally;
# confirm that is intended for an app that accepts clinical text.
if __name__ == "__main__":
    iface.launch(share=True)