# AutoRCM / app_Backup.py
# mohanjebaraj's picture
# Rename app.py to app_Backup.py
# 73d468b verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """Encoder plus two linear heads that score ICD-10 and CPT codes.

    The wrapped encoder produces per-token hidden states; the hidden state at
    sequence position 0 is passed through dropout and then through two
    independent linear layers, one sized to ``ICD_CODES`` and one sized to
    ``CPT_CODES``.
    """

    def __init__(self, bert_model):
        """Wrap ``bert_model`` and create the two classification heads.

        Args:
            bert_model: Encoder module. It is called with ``input_ids`` and
                ``attention_mask`` keyword arguments and must return an object
                exposing ``last_hidden_state``.
        """
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.1)
        # Size the heads from the encoder's own configuration instead of
        # assuming a width; 768 remains the fallback when the encoder does not
        # expose ``config.hidden_size``.
        encoder_config = getattr(bert_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)
        self.icd_classifier = torch.nn.Linear(hidden_size, len(ICD_CODES))
        self.cpt_classifier = torch.nn.Linear(hidden_size, len(CPT_CODES))

    def forward(self, input_ids, attention_mask):
        """Compute ICD and CPT logits for a tokenized batch.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Attention mask tensor forwarded to the encoder.

        Returns:
            Tuple ``(icd_logits, cpt_logits)``; the last dimension of each has
            one entry per code in ``ICD_CODES`` / ``CPT_CODES`` respectively.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first token of every sequence as the
        # sequence summary (the [CLS] position for BERT-style tokenizers).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)
        return icd_logits, cpt_logits
# Define code dictionaries
# Class index -> "ICD-10 code - description". The position of each entry in
# the tuple is the class index, so the order must not change.
ICD_CODES = dict(enumerate((
    "I10 - Essential hypertension",
    "E11.9 - Type 2 diabetes without complications",
    "J44.9 - COPD",
    "I25.10 - Atherosclerotic heart disease",
    "M54.5 - Low back pain",
    "F41.9 - Anxiety disorder",
    "J45.909 - Asthma, unspecified",
    "K21.9 - GERD",
    "E78.5 - Dyslipidemia",
    "M17.9 - Osteoarthritis of knee",
    "E10.9 - Type 1 diabetes without complications",
    "R51 - Headache",
    "R50.9 - Fever, unspecified",
    "R05 - Cough",
    "S52.5 - Fracture of forearm",
    "A49.9 - Bacterial infection, unspecified",
    "R52 - Pain, unspecified",
    "R11 - Nausea",
    "S33.5 - Sprain and strain of lumbar spine",
)))
# Class index -> "CPT code - description". The position of each entry in the
# tuple is the class index, so the order must not change.
CPT_CODES = dict(enumerate((
    "99213 - Office visit, established patient",
    "99214 - Office visit, established patient, moderate complexity",
    "99203 - Office visit, new patient",
    "80053 - Comprehensive metabolic panel",
    "85025 - Complete blood count",
    "93000 - ECG with interpretation",
    "71045 - Chest X-ray",
    "99395 - Preventive visit, established patient",
    "96127 - Brief emotional/behavioral assessment",
    "99396 - Preventive visit, age 40-64",
    "96372 - Therapeutic injection",
    "97110 - Therapeutic exercises",
    "10060 - Incision and drainage of abscess",
    "76700 - Abdominal ultrasound",
    "87500 - Infectious agent detection",
    "72100 - X-ray of lower spine",
    "72148 - MRI of lumbar spine",
)))
# Load models
@torch.no_grad()
def load_models():
    """Load the Bio_ClinicalBERT tokenizer and build the code predictor.

    The tokenizer and the encoder are both fetched from the
    ``emilyalsentzer/Bio_ClinicalBERT`` checkpoint. The encoder is wrapped in
    a new ``MedicalCodePredictor``; this function loads no weights for the
    predictor's classification heads.

    Returns:
        Tuple ``(tokenizer, model)``.
    """
    checkpoint = "emilyalsentzer/Bio_ClinicalBERT"
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    predictor = MedicalCodePredictor(AutoModel.from_pretrained(checkpoint))
    return text_tokenizer, predictor
# Prediction function
def predict_codes(text):
    """Return the top ICD-10 and CPT code suggestions for a medical summary.

    Args:
        text: Free-text medical summary. ``None``, empty, or whitespace-only
            input is answered with a prompt message instead of being scored.

    Returns:
        A multi-line string with a numbered ICD-10 section followed by a
        numbered CPT section (up to three entries each, with softmax
        confidences), or ``"Please enter a medical summary."`` when there is
        nothing to score.
    """
    # Guard against None as well as blank strings so ``.strip()`` never fails.
    if text is None or not text.strip():
        return "Please enter a medical summary."

    # Tokenize the summary, truncating anything beyond 512 tokens.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # eval() switches dropout off; no_grad() skips autograd bookkeeping that
    # is never used when only predicting.
    model.eval()
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs["input_ids"], inputs["attention_mask"])

    icd_section = _format_top_codes("Recommended ICD-10 Codes:", icd_logits, ICD_CODES)
    cpt_section = _format_top_codes("Recommended CPT Codes:", cpt_logits, CPT_CODES)
    # A blank line separates the two sections.
    return icd_section + "\n" + cpt_section


def _format_top_codes(heading, logits, code_table, k=3):
    """Render the ``k`` most probable codes of the first batch row as text.

    Args:
        heading: Title line placed above the numbered entries.
        logits: Tensor of raw scores; softmax is taken over dimension 1.
        code_table: Mapping from class index to its display string.
        k: Maximum number of entries to list; clamped to the class count.

    Returns:
        The heading and one numbered line per code, each terminated by a
        newline.
    """
    probs = F.softmax(logits, dim=1)
    top = torch.topk(probs, k=min(k, probs.shape[1]))
    confidences = top.values[0].tolist()
    class_ids = top.indices[0].tolist()
    lines = [heading]
    lines.extend(
        f"{rank}. {code_table[class_id]} (Confidence: {confidence:.2f})"
        for rank, (confidence, class_id) in enumerate(zip(confidences, class_ids), start=1)
    )
    return "\n".join(lines) + "\n"
# Load models globally
tokenizer, model = load_models()

# Build the Gradio UI: one free-text input, one text output, and a few
# clickable example summaries.
summary_input = gr.Textbox(
    label="Medical Summary",
    placeholder="Enter medical summary here...",
    lines=5,
)
codes_output = gr.Textbox(label="Predicted Codes", lines=8)
example_summaries = [
    ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
    ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
    ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
]
iface = gr.Interface(
    fn=predict_codes,
    inputs=summary_input,
    outputs=codes_output,
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=example_summaries,
)

# Start serving the app; share=True additionally requests a public share link.
iface.launch(share=True)