# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import torch.nn.functional as F | |
# Mock ICD and CPT data (replace with actual API calls or datasets)
def fetch_icd_codes(query):
    """Return a fixed list of sample ICD-10 codes.

    Placeholder for a real lookup: ``query`` is accepted but ignored, and the
    same three entries are returned on every call.

    Args:
        query: Search text; currently unused.

    Returns:
        A newly built list of dicts, each with a ``"code"`` and a
        ``"description"`` key.
    """
    samples = (
        ("R50.9", "Fever, unspecified"),
        ("A00", "Cholera"),
        ("J06.9", "Acute upper respiratory infection, unspecified"),
    )
    return [{"code": code, "description": label} for code, label in samples]
def fetch_cpt_codes(query):
    """Return a fixed list of sample CPT codes.

    Placeholder for a real lookup: ``query`` is accepted but ignored, and the
    same three entries are returned on every call.

    Args:
        query: Search text; currently unused.

    Returns:
        A newly built list of dicts, each with a ``"code"`` and a
        ``"description"`` key.
    """
    samples = (
        ("99213", "Office or other outpatient visit"),
        ("87804", "Infectious agent detection by immunoassay"),
        ("85025", "Complete blood count (CBC)"),
    )
    return [{"code": code, "description": label} for code, label in samples]
# Load tokenizer and model
# The tokenizer and the classifier are both loaded from the same checkpoint,
# once, when the module is first executed.
_CHECKPOINT = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
# NOTE(review): presumably the 1000-way classification head is freshly
# initialised rather than loaded from the checkpoint -- confirm before
# trusting the model's predictions.
model = AutoModelForSequenceClassification.from_pretrained(
    _CHECKPOINT,
    num_labels=1000,  # Adjust num_labels as needed
)
# Prediction function
def _format_code_section(heading, codes, limit=3):
    """Render up to ``limit`` code entries as a numbered list under ``heading``.

    Each entry is a dict; missing ``"code"`` / ``"description"`` keys fall
    back to ``"Unknown"`` / ``"No description"``. The returned text ends with
    a newline.
    """
    lines = [heading]
    lines.extend(
        f"{rank}. {entry.get('code', 'Unknown')}: "
        f"{entry.get('description', 'No description')}"
        for rank, entry in enumerate(codes[:limit], start=1)
    )
    return "\n".join(lines) + "\n"


def predict_codes(text):
    """Build the ICD-10 / CPT recommendation text for a medical summary.

    Args:
        text: The free-text medical summary. ``None``, an empty string, or a
            whitespace-only string is treated as "no input".

    Returns:
        A prompt asking for input when ``text`` is blank; otherwise a string
        with a numbered "Recommended ICD-10 Codes" section followed by a
        numbered "Recommended CPT Codes" section (at most three entries each).
    """
    # Treat None like blank input so a missing value yields the prompt
    # instead of an AttributeError from calling .strip() on None.
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input (truncated to the 512-token limit passed below).
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Run the classifier in inference mode without tracking gradients.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to probabilities over the label dimension.
    probs = F.softmax(logits, dim=1)

    # Clamp k so topk does not raise if the head has fewer than 3 labels.
    # NOTE(review): these top-k predictions are computed but not yet used to
    # choose codes; the codes shown below come from the fetch_* functions.
    top_k = torch.topk(probs, k=min(3, probs.shape[1]))  # noqa: F841

    # Fetch ICD and CPT codes for the summary.
    icd_results = fetch_icd_codes(text)
    cpt_results = fetch_cpt_codes(text)

    # Two numbered sections separated by a blank line.
    return (
        _format_code_section("Recommended ICD-10 Codes:", icd_results)
        + "\n"
        + _format_code_section("Recommended CPT Codes:", cpt_results)
    )
# Create Gradio interface
# Input component: multi-line box for the free-text summary.
summary_box = gr.Textbox(
    lines=5,
    placeholder="Enter medical summary here...",
    label="Medical Summary",
)

# Output component: read-out for the formatted code recommendations.
codes_box = gr.Textbox(
    label="Predicted Codes",
    lines=10,
)

# Sample summaries offered as examples in the UI (one input value per row).
example_summaries = [
    ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
    ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
    ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
]

iface = gr.Interface(
    fn=predict_codes,
    inputs=summary_box,
    outputs=codes_box,
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=example_summaries,
    allow_flagging="never",  # Disable flagging (this option does not control caching)
)

# Launch the interface; share=True asks Gradio for a public share link.
iface.launch(share=True)