Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch.nn.functional as F | |
# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """Two-head classifier on top of a BERT-style encoder.

    The first-token embedding of the encoder's ``last_hidden_state`` is passed
    through dropout and then scored by two independent linear heads: one for
    ICD-10 codes and one for CPT codes.

    Args:
        bert_model: Encoder called as
            ``bert_model(input_ids=..., attention_mask=...)``. Its output must expose
            ``last_hidden_state``. If it has ``config.hidden_size``, that value sets
            the head input width; otherwise 768 is used.
        num_icd_codes: Output size of the ICD head. Defaults to ``len(ICD_CODES)``.
        num_cpt_codes: Output size of the CPT head. Defaults to ``len(CPT_CODES)``.
        dropout: Dropout probability applied to the pooled embedding.
    """

    def __init__(self, bert_model, num_icd_codes=None, num_cpt_codes=None, dropout=0.1):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(dropout)
        # Read the embedding width from the encoder config rather than hard-coding
        # 768, so encoders other than BERT-base do not cause shape mismatches.
        config = getattr(bert_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        # Resolved at construction time: the code tables are module-level globals
        # that are defined after this class.
        if num_icd_codes is None:
            num_icd_codes = len(ICD_CODES)
        if num_cpt_codes is None:
            num_cpt_codes = len(CPT_CODES)
        self.icd_classifier = torch.nn.Linear(hidden_size, num_icd_codes)
        self.cpt_classifier = torch.nn.Linear(hidden_size, num_cpt_codes)

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)``, each shaped ``(batch, num_codes)``."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the first token of every sequence. For BERT tokenizers
        # this is the [CLS] token. Result shape: (batch, hidden).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)
        return icd_logits, cpt_logits
# Define code dictionaries.
# Each table maps a classifier output index to a "CODE - description" label.
# predict_codes looks up the top-k head indices here, so entry i must describe
# output unit i of the matching head. The tables are built from ordered lists
# so the indices always stay contiguous and in order.
# NOTE(review): several ICD-10 labels look outdated or non-billable under the
# current ICD-10-CM code set:
#   - M54.5, R51 and R05 have since been subdivided.
#   - S52.5 and S33.5 need further characters.
# Verify the code set before relying on these outputs.
ICD_CODES = dict(enumerate([
    "I10 - Essential hypertension",
    "E11.9 - Type 2 diabetes without complications",
    "J44.9 - COPD",
    "I25.10 - Atherosclerotic heart disease",
    "M54.5 - Low back pain",
    "F41.9 - Anxiety disorder",
    "J45.909 - Asthma, unspecified",
    "K21.9 - GERD",
    "E78.5 - Dyslipidemia",
    "M17.9 - Osteoarthritis of knee",
    "E10.9 - Type 1 diabetes without complications",
    "R51 - Headache",
    "R50.9 - Fever, unspecified",
    "R05 - Cough",
    "S52.5 - Fracture of forearm",
    "A49.9 - Bacterial infection, unspecified",
    "R52 - Pain, unspecified",
    "R11 - Nausea",
    "S33.5 - Sprain and strain of lumbar spine",
]))

CPT_CODES = dict(enumerate([
    "99213 - Office visit, established patient",
    "99214 - Office visit, established patient, moderate complexity",
    "99203 - Office visit, new patient",
    "80053 - Comprehensive metabolic panel",
    "85025 - Complete blood count",
    "93000 - ECG with interpretation",
    "71045 - Chest X-ray",
    "99395 - Preventive visit, established patient",
    "96127 - Brief emotional/behavioral assessment",
    "99396 - Preventive visit, age 40-64",
    "96372 - Therapeutic injection",
    "97110 - Therapeutic exercises",
    "10060 - Incision and drainage of abscess",
    "76700 - Abdominal ultrasound",
    "87500 - Infectious agent detection",
    "72100 - X-ray of lower spine",
    "72148 - MRI of lumbar spine",
]))
# Load models
def load_models(model_name="emilyalsentzer/Bio_ClinicalBERT", checkpoint_path=None):
    """Load the tokenizer and build the code-prediction model.

    Args:
        model_name: Hugging Face hub id (or local directory) used for both the
            tokenizer and the encoder.
        checkpoint_path: Optional path to a ``state_dict`` for the whole
            ``MedicalCodePredictor`` (encoder plus ICD/CPT heads). If omitted, the
            heads keep their fresh random initialization, so predicted codes and
            confidences carry no meaning.

    Returns:
        ``(tokenizer, model)``, with the model already switched to eval mode.
    """
    import warnings

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = MedicalCodePredictor(base_model)
    if checkpoint_path is not None:
        # weights_only=True refuses to unpickle arbitrary objects from the file.
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    else:
        # The linear heads are created fresh and nothing trains them, so make the
        # situation visible instead of silently serving random outputs.
        warnings.warn(
            "No checkpoint_path given: the ICD/CPT classifier heads are randomly "
            "initialized, so predicted codes and confidences are arbitrary.",
            RuntimeWarning,
            stacklevel=2,
        )
    # Inference-only use: disable dropout from the start.
    model.eval()
    return tokenizer, model
# Prediction function
def predict_codes(text, *, top_k=3):
    """Return formatted top ICD-10 and CPT code suggestions for a medical summary.

    Uses the module-level ``tokenizer`` and ``model`` together with the
    ``ICD_CODES`` / ``CPT_CODES`` lookup tables.

    Args:
        text: Free-text medical summary (the Gradio textbox value). ``None`` or
            whitespace-only input is treated as empty.
        top_k: Number of suggestions per code system. It is clamped to the
            number of codes available.

    Returns:
        A multi-line, human-readable string. For empty input, a prompt message
        is returned instead.
    """
    if text is None or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input; texts longer than 512 tokens are truncated.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    model.eval()
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs["input_ids"], inputs["attention_mask"])

    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)

    icd_section = _format_top_codes("Recommended ICD-10 Codes:", icd_probs, ICD_CODES, top_k)
    cpt_section = _format_top_codes("Recommended CPT Codes:", cpt_probs, CPT_CODES, top_k)
    return icd_section + "\n" + cpt_section


def _format_top_codes(header, probs, code_table, top_k):
    """Render the ``top_k`` most probable entries of ``probs[0]`` as a numbered list."""
    # Clamp so that topk never asks for more entries than the head has.
    k = max(0, min(top_k, probs.shape[-1]))
    top = torch.topk(probs[0], k=k)
    lines = [header]
    for rank, (prob, idx) in enumerate(zip(top.values, top.indices), start=1):
        lines.append(f"{rank}. {code_table[idx.item()]} (Confidence: {prob.item():.2f})")
    return "\n".join(lines) + "\n"
# Load models globally; predict_codes reads these module-level names.
tokenizer, model = load_models()

# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
)

# Launch the interface only when run as a script, so importing this module
# does not start a server.
if __name__ == "__main__":
    import os

    # A public share link would route clinical text through a third-party tunnel.
    # Require an explicit opt-in (GRADIO_SHARE=1/true/yes) instead of always
    # enabling it. Share links are not supported on Hugging Face Spaces anyway.
    share = os.environ.get("GRADIO_SHARE", "").strip().lower() in {"1", "true", "yes"}
    iface.launch(share=share)