|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
import joblib |
|
import torch |
|
from transformers import AutoTokenizer, AutoModel |
|
from xgboost import XGBClassifier |
|
from sklearn.preprocessing import StandardScaler |
|
from sklearn.decomposition import PCA |
|
from sklearn.metrics import precision_recall_curve, roc_curve, confusion_matrix, classification_report |
|
import matplotlib.pyplot as plt |
|
import shap |
|
import plotly.express as px |
|
import streamlit as st |
|
import pandas as pd |
|
import datetime |
|
import json |
|
import requests |
|
from streamlit_lottie import st_lottie |
|
import streamlit.components.v1 as components |
|
from streamlit_navigation_bar import st_navbar |
|
from transformers import AutoTokenizer, AutoModel |
|
import re |
|
from tqdm import tqdm |
|
import torch |
|
import os |
|
from hugchat.login import Login |
|
from hugchat import hugchat |
|
from transformers import pipeline |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
import torch.nn as nn |
|
import time |
|
|
|
|
|
|
|
|
|
# Sidebar navigation: the label picked here decides which page branch runs below.
st.sidebar.title("๐ Menu")

_menu_pages = [
    "๐ Home",
    "๐ Tabular Data",
    "๐ Clinical Text Notes",
    "๐ Ensemble Prediction",
]
page = st.sidebar.radio("Selecione uma opรงรฃo:", _menu_pages)
|
if page=="๐ Home": |
|
|
|
# Landing page: inject the page-level CSS first, then render banner, image and intro copy.
_home_css = """
<style>
.title {
    text-align: center;
    font-size: 36px;
    font-weight: bold;
    color: #2C3E50;
}
.subtitle {
    text-align: center;
    font-size: 22px;
    color: #7F8C8D;
}
.box {
    background-color: #ECF0F1;
    padding: 15px;
    border-radius: 10px;
    text-align: center;
    margin-bottom: 10px;
    font-size: 18px;
}
</style>
"""
st.markdown(_home_css, unsafe_allow_html=True)

# Title and subtitle rely on the .title / .subtitle classes declared above.
for _banner_html in (
    "<h1 class='title'>๐ AI Clinical Readmission Predictor</h1>",
    "<h2 class='subtitle'>Using Machine Learning for Better Patient Outcomes</h2>",
):
    st.markdown(_banner_html, unsafe_allow_html=True)

# Two candidate banner images are declared; only image_2 is displayed.
image_1 ='https://content.presspage.com/uploads/2110/4970f578-5f20-4675-acc2-3b2cda25fa96/1920_ai-machine-learning-cedars-sinai.jpg?10000'
image_2 = 'https://med-tech.world/app/uploads/2024/10/AI-Hospitals.jpg.webp'
st.image(image_2, width=1450)

st.write(
    "This app helps predict patient readmission risk using machine learning models. "
    "Upload data, analyze clinical notes, and see predictions from our ensemble model."
)

st.markdown("---")
st.markdown("<h3 style='text-align: center;'>๐ Explore the App</h3>", unsafe_allow_html=True)
|
|
|
elif page== "๐ Tabular Data": |
|
|
|
|
|
def load_lottie(url, timeout=10):
    """Fetch a Lottie animation definition as parsed JSON.

    Args:
        url: HTTP(S) address of the Lottie JSON document.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON payload, or ``None`` when the request fails, the
        server answers with a non-200 status, or the body is not valid JSON.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # The animation is decorative: a network problem must not break the page.
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:
        # requests raises a ValueError subclass when the body is not JSON.
        return None
|
|
|
|
|
# Decorative header animation; skipped silently when it could not be downloaded.
lottie_hello = load_lottie("https://assets7.lottiefiles.com/packages/lf20_jcikwtux.json")
if lottie_hello:
    st_lottie(lottie_hello, speed=1, loop=True, height=200)


@st.cache_data
def _load_reference_frame(csv_path):
    """Read the reference CSV once and reuse it across Streamlit reruns."""
    return pd.read_csv(csv_path)


# Reference dataset: its columns and value ranges drive the input widgets below.
# Streamlit reruns the script on every widget interaction, hence the cache.
df = _load_reference_frame('/Users/joaopimenta/Downloads/ensemble_test.csv')

st.title('๐ฅ Hospital Readmission Prediction')
st.markdown(
    """
<h3 style='text-align: center; color: gray;'>Predict ICU hospital readmission using Artificial Intelligence</h3>
""",
    unsafe_allow_html=True,
)
st.markdown("---")
|
|
|
|
|
def get_age_group(age):
    """Map an age in years to the name of its one-hot age-group column.

    Buckets are contiguous, so fractional ages (e.g. 50.5) land in the bucket
    of their whole-year part instead of falling through to "Below_36" as they
    did when each bucket was tested with closed integer bounds.

    Args:
        age: Age in years (int or float).

    Returns:
        The age-group column name; ages under 36 (and values that compare
        false against every bound, such as NaN) yield "age_group_Below_36".
    """
    # Checked from the oldest bucket down so each test only needs a lower bound.
    if age >= 81:
        return "age_group_81+ (Elderly)"
    if age >= 66:
        return "age_group_66-80 (Senior Adults)"
    if age >= 51:
        return "age_group_51-65 (Older Middle-Aged Adults)"
    if age >= 36:
        return "age_group_36-50 (Middle-Aged Adults)"
    return "age_group_Below_36"
|
|
|
|
|
def get_period(hour):
    """Label an hour of the day as "Morning" (06:00-17:59) or "Night" (otherwise)."""
    is_daytime = 6 <= hour < 18
    if is_daytime:
        return "Morning"
    return "Night"
|
|
|
|
|
def _one_hot_options(prefix):
    """Return the columns of ``df`` whose name starts with *prefix*."""
    return df.columns[df.columns.str.startswith(prefix)]


# Admission characteristics: each selectbox offers the matching one-hot column names.
st.subheader("๐ Select the admission's Characteristics")

admission_type = st.selectbox("๐ Type of Admission", _one_hot_options('admission_type_'))
admission_location = st.selectbox("๐ Admission Location", _one_hot_options('admission_location_'))
discharge_location = st.selectbox("๐ฅ Discharge Location", _one_hot_options('discharge_location_'))
insurance = st.selectbox("๐ฐ Insurance Type", _one_hot_options('insurance_'))

# Patient demographics live in the sidebar.
st.sidebar.subheader("๐ Patient Information")
language = st.sidebar.selectbox("๐ฃ Language", _one_hot_options('language_'))
marital_status = st.sidebar.selectbox("๐ Marital Status", _one_hot_options('marital_status_'))
race = st.sidebar.selectbox("๐ง Race", _one_hot_options('race_'))
sex = st.sidebar.selectbox("โง Sex", ['gender_M', 'gender_F'])
# The label literal was split across two physical lines; it is one string again.
age = st.sidebar.slider("๐ Age", 18, 100, 50)

# Only the hour of each time is used later (to derive the Morning/Night period).
admission_time = st.time_input("โณ Admission Time", value=datetime.time(12, 0))
discharge_time = st.time_input("โณ Discharge Time", value=datetime.time(12, 0))
|
|
|
|
|
def _feature_sliders(features, columns, icon):
    """Render one float slider per feature, cycling through *columns*.

    Each slider spans the observed min/max of ``df[feature]`` and starts at
    the midpoint of that range.

    Returns:
        Dict mapping feature name to the value currently selected.
    """
    values = {}
    for position, feature in enumerate(features):
        low, high = float(df[feature].min()), float(df[feature].max())
        values[feature] = columns[position % len(columns)].slider(
            f"{icon} {feature.replace('_', ' ').title()}",
            low,
            high,
            (low + high) / 2,
        )
    return values


st.subheader("๐ Clinical Values")

general_numerical_features = ['los_days', 'previous_stays', 'n_meds', 'drg_severity',
                              'drg_mortality', 'time_since_last_stay']
lab_numerical_features = ['blood_cells', 'hemoglobin', 'glucose',
                          'creatine', 'plaquete']
numerical_features = general_numerical_features + lab_numerical_features

# Kept (empty) for later code that merges it into the feature vector. The
# 11-column layout that used to be created here was never filled and only
# rendered an empty row, so it is gone.
numeric_inputs = {}

st.subheader("๐ General Hospital Information")
cols = st.columns(3)
general_inputs = _feature_sliders(general_numerical_features, cols, "๐")

st.subheader("๐งช Laboratory Test Results")
lab_cols = st.columns(3)
lab_inputs = _feature_sliders(lab_numerical_features, lab_cols, "๐ฉธ")

# The CCI score slider sits in the sidebar but is stored with the lab inputs.
min_val, max_val = df["cci_score"].min(), df["cci_score"].max()
lab_inputs["cci_score"] = st.sidebar.slider(
    "๐ CCI Score",
    float(min_val),
    float(max_val),
    float((min_val + max_val) / 2),
)
|
|
|
|
|
# Build the single-row model input. Every known column starts at 0 and the
# selected one-hot columns flip to 1. (This used to be done twice in a row;
# the first copy was overwritten before use.)
feature_vector = {col: 0 for col in df.columns}
feature_vector.update({
    admission_type: 1,
    admission_location: 1,
    discharge_location: 1,
    insurance: 1,
    language: 1,
    marital_status: 1,
    race: 1,
    "gender_M": 1 if sex == "gender_M" else 0,
    f"admit_period_{get_period(admission_time.hour)}": 1,
    f"discharge_period_{get_period(discharge_time.hour)}": 1,
})

# At most one age bucket is set; ages below 36 leave all four at 0.
age_group = get_age_group(age)
for group in (
    "age_group_36-50 (Middle-Aged Adults)",
    "age_group_51-65 (Older Middle-Aged Adults)",
    "age_group_66-80 (Senior Adults)",
    "age_group_81+ (Elderly)",
):
    feature_vector[group] = 1 if group == age_group else 0

# Slider values overwrite the zero placeholders of the numeric columns.
feature_vector.update(general_inputs)
feature_vector.update(lab_inputs)

st.markdown("---")

tabular_model_path = "/Users/joaopimenta/Downloads/final_xgboost_model.pkl"


@st.cache_resource
def _load_tabular_model(path):
    """Unpickle the tabular classifier once and reuse it across reruns.

    NOTE: joblib.load unpickles arbitrary objects; only point this at a
    trusted local artifact.
    """
    return joblib.load(path)


tabular_model = _load_tabular_model(tabular_model_path)
print("โ XGBoost Tabular Model loaded successfully!")

# Identifier and bookkeeping columns of the CSV are not model inputs.
expected_columns = [
    col for col in df.columns if col not in ["Unnamed: 0", "subject_id", "hadm_id", "probs"]
]

# Short age-group names are renamed to the long column names.
age_group_mapping = {
    "age_group_36-50": "age_group_36-50 (Middle-Aged Adults)",
    "age_group_51-65": "age_group_51-65 (Older Middle-Aged Adults)",
    "age_group_66-80": "age_group_66-80 (Senior Adults)",
    "age_group_81+": "age_group_81+ (Elderly)",
}
fixed_feature_vector = {age_group_mapping.get(k, k): v for k, v in feature_vector.items()}

# reindex drops extra keys and fixes the column order.
feature_df = pd.DataFrame([fixed_feature_vector]).reindex(columns=expected_columns, fill_value=0)
st.write(feature_df)

prediction_proba = tabular_model.predict_proba(feature_df)[:, 1]
probability = float(prediction_proba[0])
# Stored for the ensemble page, which reads "<model> probability" keys.
st.session_state["XGBoost probability"] = probability
prediction = int(probability >= 0.5)

st.write(f"Raw Prediction Probability: {probability:.4f}")
|
|
|
|
|
# Show the already-computed tabular prediction when the user asks for it.
if st.button("๐ Predict Readmission"):
    with st.spinner("๐ Processing Prediction..."):
        st.subheader("๐ฏ Prediction Results")
        probability_col, verdict_col = st.columns(2)

        probability_col.metric(label="๐งฎ Readmission Probability", value=f"{probability:.2%}")

        with verdict_col:
            if prediction == 1:
                st.error("โ ๏ธ High Risk of Readmission")
            else:
                # This literal was split across two physical lines; rejoined.
                st.success("โ Low Risk of Readmission")
|
|
|
|
|
# Explain the current prediction: SHAP attributions, an LLM-written summary,
# a bar chart of the strongest features and an interactive force plot.
if st.button("๐ Feature Importance for Prediction"):
    st.metric(label="๐งฎ Readmission Probability", value=f"{probability:.2%}")

    explainer = shap.TreeExplainer(tabular_model)
    shap_values = explainer.shap_values(feature_df)

    # feature_df has one row, so shap_values[0] is that row's attributions.
    shap_df = pd.DataFrame({
        "Feature": feature_df.columns,
        "SHAP Value": shap_values[0],
    })
    shap_df["abs_SHAP"] = shap_df["SHAP Value"].abs()
    shap_df = shap_df.sort_values(by="abs_SHAP", ascending=False).head(10)

    # shap_df is already ordered by |SHAP|, so the rows are listed as they are.
    top_factors = "\n".join(
        f"- {feat}: {round(value, 2)} impact"
        for feat, value in zip(shap_df["Feature"], shap_df["SHAP Value"])
    )

    EMAIL = st.secrets["email"]
    PASSWD = st.secrets["passwd"]
    cookie_path_dir = "./cookies_snapshot"

    # The LLM explanation is optional. A failure is reported and the SHAP
    # plots further down are still rendered, instead of stopping the script.
    chatbot = None
    try:
        sign = Login(EMAIL, PASSWD)
        cookies = sign.login(cookie_dir_path=cookie_path_dir, save_cookies=True)
        sign.saveCookiesToDir(cookie_path_dir)
        chatbot = hugchat.ChatBot(cookies=cookies.get_dict())
    except Exception as e:  # UI boundary: show any login problem to the user
        st.error(f"โ Login to HuggingChat failed. Error: {e}")

    st.title("๐ฉบ AI-Powered Patient Readmission Analysis")

    hugging_prompt = f"""
A hospital AI model predicts patient readmission based on the following feature impacts:
{top_factors}

Can you explain why the model made this decision? Specifically, what were the key characteristics of the patient or their admission that influenced the modelโs prediction the most?
"""

    if chatbot is not None:
        with st.spinner("๐ค Analyzing..."):
            try:
                ai_output = chatbot.chat(hugging_prompt)
                with st.chat_message("assistant"):
                    st.markdown(f"**๐ก AI Explanation:**\n\n{ai_output}")
            except Exception as e:  # UI boundary: report and continue with the plots
                st.error(f"โ ๏ธ Error retrieving response: {e}")

    with st.expander("๐ Click to see detailed feature impacts"):
        # The fence markers need their own lines, otherwise the first factor
        # is read as the fence's language tag.
        st.markdown(f"```\n{top_factors}\n```")

    # Names are passed as feature_names (not as `features`) so the bars carry
    # the real column names. max_display is set so no selected row is hidden
    # by the plot's own default limit.
    fig, ax = plt.subplots(figsize=(8, 6))
    shap.bar_plot(
        shap_df["SHAP Value"].values,
        feature_names=shap_df["Feature"].values,
        max_display=len(shap_df),
        show=False,
    )
    st.pyplot(fig)

    st.subheader("๐ฏ SHAP Force Plot (How Features Affected the Prediction)")
    force_plot = shap.force_plot(
        explainer.expected_value, shap_values[0], feature_df.iloc[0], matplotlib=False
    )
    # The force plot is HTML/JS; SHAP's JS bundle must be embedded with it.
    shap_html = f"<head>{shap.getjs()}</head><body>{force_plot.html()}</body>"
    components.html(shap_html, height=400)
|
|
|
elif page == "๐ Clinical Text Notes": |
|
|
|
st.subheader("๐ Clinical Text Note") |
|
|
|
|
|
|
|
def clean_text(text):
    """Normalise a clinical note to lower-case printable ASCII.

    Applied in order: non-printable/non-ASCII characters become spaces, runs of
    two or more underscores are dropped, whitespace runs collapse to a single
    space, and any character outside word characters, whitespace and
    ``.,:;*%()[]-`` is removed. The result is lower-cased and stripped.
    """
    substitutions = (
        (r"[^\x20-\x7E]", " "),
        (r"_{2,}", ""),
        (r"\s+", " "),
        (r"[^\w\s.,:;*%()\[\]-]", ""),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.lower().strip()
|
|
|
|
|
import re |
|
|
|
def extract_fields(text):
    """Extract named sections from a clinical note with regex patterns.

    Each section runs from its header to the first of its listed follow-up
    headers, or to the end of the text. Whitespace before that terminator is
    optional, so a section that is the last thing in the note is captured too
    (the former mandatory ``\\s+`` made such sections fail to match).

    Args:
        text: Note text; header matching is case-insensitive.

    Returns:
        Dict mapping section name to its stripped content, containing only the
        sections that were found.
    """
    patterns = {
        "Discharge Medications": r"Discharge Medications[:\-]?\s*(.+?)\s*(?:Discharge Disposition|Discharge Condition|Discharge Instructions|Followup Instructions|$)",
        "Discharge Diagnosis": r"Discharge Diagnosis[:\-]?\s*(.+?)\s*(?:Discharge Condition|Discharge Medications|Discharge Instructions|Followup Instructions|$)",
        "Discharge Instructions": r"Discharge Instructions[:\-]?\s*(.*?)\s*(?:Followup Instructions|Discharge Disposition|Discharge Condition|$)",
        "History of Present Illness": r"History of Present Illness[:\-]?\s*(.+?)\s*(?:Past Medical History|Social History|Family History|Physical Exam|$)",
        "Past Medical History": r"Past Medical History[:\-]?\s*(.+?)\s*(?:Social History|Family History|Physical Exam|$)",
    }

    extracted_data = {}
    for field, pattern in patterns.items():
        # DOTALL lets a section span several lines.
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            extracted_data[field] = match.group(1).strip()
    return extracted_data
|
|
|
def extract_features(texts, model, tokenizer, device, batch_size=8):
    """Return the first-token hidden state of every text, computed in batches.

    Texts are tokenised with padding and truncation to 4096 tokens. Global
    attention is set on position 0 only, and that position's last hidden state
    is kept as the text's feature vector.

    Returns:
        A tensor with one row per input text (batches concatenated on dim 0).
    """
    cls_batches = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        encoded = tokenizer(
            chunk,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=4096,
        ).to(device)

        # Global attention mask: 1 on the first token, 0 elsewhere.
        global_mask = torch.zeros_like(encoded["input_ids"]).to(device)
        global_mask[:, 0] = 1

        with torch.no_grad():
            hidden = model(**encoded, global_attention_mask=global_mask).last_hidden_state

        cls_batches.append(hidden[:, 0, :])

    return torch.cat(cls_batches, dim=0)
|
|
|
|
|
def extract_entities(text, pipe, entity_group):
    """Collect the words a NER pipeline tags with *entity_group*.

    Args:
        text: Text to analyse.
        pipe: Callable returning a list of dicts with "word" and
            "entity_group" keys.
        entity_group: Entity label to keep.

    Returns:
        The matching words, or ``["No relevant entities found"]`` when there
        are none. Empty or whitespace-only text returns that placeholder
        without invoking the pipeline at all.
    """
    placeholder = ["No relevant entities found"]

    # Missing sections arrive as "" (or " "); skip the model call for those.
    if not text or not text.strip():
        return placeholder

    matches = [ent['word'] for ent in pipe(text) if ent['entity_group'] == entity_group]
    return matches or placeholder
|
|
|
|
|
# Run on the GPU when one is available, otherwise on the CPU.
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_accelerator)


@st.cache_resource()
def load_models():
    """Load the Clinical-Longformer encoder and the biomedical NER pipeline.

    Returns:
        A ``(longformer_tokenizer, longformer_model, ner_pipe)`` tuple. The
        Longformer model is moved to ``device`` and put in eval mode.
    """
    longformer_name = "yikuan8/Clinical-Longformer"
    ner_name = "d4data/biomedical-ner-all"

    lf_tokenizer = AutoTokenizer.from_pretrained(longformer_name)
    lf_model = AutoModel.from_pretrained(longformer_name).to(device).eval()

    tagger_tokenizer = AutoTokenizer.from_pretrained(ner_name)
    tagger_model = AutoModelForTokenClassification.from_pretrained(ner_name)
    tagger = pipeline(
        "ner",
        model=tagger_model,
        tokenizer=tagger_tokenizer,
        aggregation_strategy="simple",
    )

    return lf_tokenizer, lf_model, tagger


longformer_tokenizer, longformer_model, ner_pipe = load_models()
|
|
|
|
|
clinical_note = st.text_area("โ๏ธ Enter Clinical Note", placeholder="Write the clinical note here...") |
|
|
|
if clinical_note: |
|
# Normalise the note once; the cleaned text feeds both the section extraction
# below and the readmission model.
cleaned_note = clean_text(clinical_note)

extracted_data = extract_fields(cleaned_note)
st.write("### Extracted Fields")
st.write(extracted_data)

# The Longformer embeddings were computed here on every rerun, but the result
# was never read: predict_readmission() extracts its own embeddings when the
# predict button is pressed. The wasted forward pass has been removed.
|
|
|
|
|
|
|
class RobustMLPClassifier(nn.Module):
    """Feed-forward binary classifier producing one logit per sample.

    Every hidden layer is Linear -> BatchNorm1d -> activation -> Dropout, and
    a final Linear layer maps to a single output. The layer stack is stored as
    ``self.net``.

    Args:
        input_dim: Size of each input vector.
        hidden_dims: Widths of the hidden layers, in order.
        dropout: Dropout probability applied after every hidden activation.
        activation: Activation module; the same instance is reused for every
            hidden layer.
    """

    # hidden_dims defaults to a tuple so the default cannot be mutated and
    # shared between instances (it used to be a list).
    def __init__(self, input_dim, hidden_dims=(256, 128, 64), dropout=0.3, activation=nn.ReLU()):
        super().__init__()
        layers = []
        current_dim = input_dim

        for h in hidden_dims:
            layers.extend([
                nn.Linear(current_dim, h),
                nn.BatchNorm1d(h),
                activation,
                nn.Dropout(dropout),
            ])
            current_dim = h

        layers.append(nn.Linear(current_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw logits for the batch *x*."""
        return self.net(x)
|
|
|
|
|
mlp_model_path = "/Users/joaopimenta/Downloads/Capstone/best_mlp_model_full.pth"
pca_path = "/Users/joaopimenta/Downloads/Capstone/best_pca_model.pkl"

# torch.load is called on a whole-model checkpoint (the result is used directly
# as a module below), so:
#   * weights_only=False is required on PyTorch versions where weights-only
#     loading is the default and rejects pickled modules;
#   * map_location=device lets a checkpoint saved on a GPU load on a CPU host.
# NOTE: this unpickles arbitrary code; only load trusted local artifacts.
best_mlp_model = torch.load(mlp_model_path, map_location=device, weights_only=False)
best_mlp_model.to(device)
best_mlp_model.eval()  # inference mode for dropout / batch-norm layers

pca = joblib.load(pca_path)
|
|
|
def predict_readmission(texts):
    """Score notes with the embedding -> PCA -> MLP pipeline.

    Args:
        texts: List of note strings.

    Returns:
        NumPy array of sigmoid probabilities as produced by the MLP.
    """
    cls_vectors = extract_features(texts, longformer_model, longformer_tokenizer, device)

    # PCA runs on NumPy arrays on the CPU; the reduced vectors go back to `device`.
    reduced = pca.transform(cls_vectors.cpu().numpy())
    mlp_input = torch.FloatTensor(reduced).to(device)

    with torch.no_grad():
        scores = torch.sigmoid(best_mlp_model(mlp_input))

    return scores.cpu().numpy()
|
|
|
|
|
# (result key, source section, entity label) for every single-section NER pass.
_ner_targets = (
    ("Extracted Medications", "Discharge Medications", "Medication"),
    ("Extracted Diseases", "Discharge Diagnosis", "Disease_disorder"),
    ("Extracted Diseases (Past Medical History)", "Past Medical History", "Disease_disorder"),
    ("Extracted Diseases (History of Present Illness)", "History of Present Illness", "Disease_disorder"),
)

with st.spinner("๐ Identifying medical entities..."):
    for _result_key, _section, _label in _ner_targets:
        extracted_data[_result_key] = extract_entities(
            extracted_data.get(_section, ""), ner_pipe, _label
        )

    # Symptoms are searched in two sections joined by a space; a section that
    # is absent from extracted_data contributes an empty string.
    _symptom_text = (
        extracted_data.get("Review of Systems", "")
        + " "
        + extracted_data.get("History of Present Illness", "")
    )
    extracted_data["Extracted Symptoms"] = extract_entities(
        _symptom_text, ner_pipe, "Sign_symptom"
    )
|
|
|
|
|
def clean_entities(entities):
    """Merge word-piece fragments, drop noise and return unique sorted names.

    Items starting with ``##`` are glued onto the preceding item. The
    ``"No relevant entities found"`` placeholder is discarded, so it is no
    longer counted or displayed as if it were a real entity. Results of two
    characters or fewer, or made only of punctuation/underscores, are dropped.

    Args:
        entities: List of entity strings, possibly containing fragments.

    Returns:
        Sorted list of unique cleaned entity strings.
    """
    placeholder = "No relevant entities found"

    merged = []
    can_attach = False  # True while merged[-1] may still receive "##" fragments
    for entity in entities:
        if entity == placeholder:
            # A placeholder marks an empty extraction; fragments after it must
            # not be glued onto an entity from a different source list.
            can_attach = False
            continue

        if entity.startswith("##"):
            fragment = entity.replace("##", "")
            if can_attach:
                merged[-1] += fragment
                continue
            merged.append(fragment)
        else:
            merged.append(entity)
        can_attach = True

    kept = {
        word for word in merged
        if len(word) > 2 and not re.match(r"^[\W_]+$", word)
    }
    return sorted(kept)
|
|
|
|
|
# Diseases come from three sections; pool them before de-duplicating.
_disease_sources = (
    "Extracted Diseases",
    "Extracted Diseases (Past Medical History)",
    "Extracted Diseases (History of Present Illness)",
)
diseases_cleaned = clean_entities(
    [name for key in _disease_sources for name in extracted_data.get(key, [])]
)

medications_cleaned = clean_entities(extracted_data.get("Extracted Medications", []))
extracted_data["Extracted Medications Cleaned"] = medications_cleaned

symptoms_cleaned = clean_entities(extracted_data.get("Extracted Symptoms", []))
|
|
|
|
|
def display_list(title, items, icon="๐"):
    """Render *items* as a bulleted list inside an expander titled with their count.

    An empty list shows a "no information" note instead of bullets.
    """
    with st.expander(f"**{title} ({len(items)})**"):
        if not items:
            st.markdown("_No information available._")
            return
        for item in items:
            st.markdown(f"- {icon} **{item}**")
|
|
|
|
|
|
|
st.markdown("## ๐ฅ **Patient Medical Analysis**")
st.markdown("---")

# Headline counters, one per column.
num_medications = len(medications_cleaned)
num_diseases = len(diseases_cleaned)
num_symptoms = len(symptoms_cleaned)

_summary_metrics = (
    ("๐ Total Medications", num_medications),
    ("๐ฆ Total Diseases", num_diseases),
    ("๐ค Total Symptoms", num_symptoms),
)
for _metric_col, (_metric_label, _metric_value) in zip(st.columns(3), _summary_metrics):
    _metric_col.metric(label=_metric_label, value=_metric_value)

st.markdown("---")

# Medications and diseases side by side, symptoms underneath.
col1, col2 = st.columns(2)

with col1:
    st.markdown("### ๐ **Medications**")
    display_list("Medication List", medications_cleaned, icon="๐")

with col2:
    st.markdown("### ๐ฆ **Diseases**")
    display_list("Disease List", diseases_cleaned, icon="๐ฆ ")

st.markdown("### ๐ค **Symptoms**")
display_list("Symptoms List", symptoms_cleaned, icon="๐ค")

st.markdown("---")
|
|
|
|
|
# Reuse the Clinical-Longformer tokenizer that load_models() already returned,
# instead of calling from_pretrained for the same checkpoint on every rerun.
tokenizer = longformer_tokenizer


def count_tokens(text):
    """Return how many tokens ``tokenizer`` splits *text* into."""
    return len(tokenizer.tokenize(text))
|
|
|
def trunced_text(nr):
    """Return 1 when *nr* exceeds 4096, else 0."""
    return int(nr > 4096)
|
|
|
|
|
# Spellings searched for in the discharge diagnosis, keyed by flag name.
disease_synonyms = {
    "Pneumonia": ["pneumonia", "pneumonitis"],
    "Diabetes": ["diabetes", "diabetes mellitus", "dm"],
    "CHF": ["CHF", "congestive heart failure", "heart failure"],
    "Septicemia": ["septicemia", "sepsis", "blood infection"],
    "Cirrhosis": ["cirrhosis", "liver cirrhosis", "hepatic cirrhosis"],
    "COPD": ["COPD", "chronic obstructive pulmonary disease"],
    "Renal_Failure": ["renal failure", "kidney failure", "chronic kidney disease", "CKD"],
}

# Sections are re-extracted from the raw (uncleaned) note for the token statistics.
extracted_data = extract_fields(clinical_note)
_medication_section = extracted_data.get("Discharge Medications", "")
_diagnosis_section = extracted_data.get("Discharge Diagnosis", "")

number_of_tokens = count_tokens(clinical_note)
number_of_tokens_med = count_tokens(_medication_section)
number_of_tokens_dis = count_tokens(_diagnosis_section)
trunced = trunced_text(number_of_tokens)

full_diagnosis_text = _diagnosis_section.lower()
|
|
|
|
|
def check_disease_presence(disease_list, text):
    """Return 1 if any synonym occurs in *text* as a whole word, else 0.

    Matching is case-insensitive. Synonyms are escaped before being placed in
    the pattern, so characters such as ``.`` or ``(`` are matched literally
    instead of being interpreted as regex syntax.

    Args:
        disease_list: Iterable of synonym strings.
        text: Text to search.
    """
    return int(any(
        re.search(rf"\b{re.escape(synonym)}\b", text, re.IGNORECASE)
        for synonym in disease_list
    ))
|
|
|
|
|
# One 0/1 flag per condition, based on the discharge diagnosis text.
disease_flags = {}
for disease, synonyms in disease_synonyms.items():
    disease_flags[disease] = check_disease_presence(synonyms, full_diagnosis_text)

# Summed before the key itself is added, so it counts only the condition flags.
disease_flags["total_conditions"] = sum(disease_flags.values())

# Single-row frame holding the note statistics followed by the condition flags.
_feature_row = {
    'number_of_tokens_dis': number_of_tokens_dis,
    'number_of_tokens': number_of_tokens,
    'number_of_tokens_med': number_of_tokens_med,
    'diagnostic_count': num_diseases,
    'total_conditions': disease_flags["total_conditions"],
    'trunced': trunced,
}
_feature_row.update({disease: disease_flags[disease] for disease in disease_synonyms.keys()})
df = pd.DataFrame([_feature_row])
|
|
|
|
|
|
|
|
|
|
|
light_path = '/Users/joaopimenta/Downloads/best_lgbm_model.pkl'
# NOTE: joblib.load unpickles arbitrary objects; only load trusted local artifacts.
light_model = joblib.load(light_path)

# The model defines which columns it expects and in which order.
model_features = light_model.feature_name_

# Columns the model expects but the feature row lacks are filled with 0.
missing_features = [feat for feat in model_features if feat not in df.columns]
if missing_features:
    st.write(f"โ ๏ธ Warning: Missing features in df: {missing_features}")
    for feat in missing_features:
        df[feat] = 0

df = df[model_features]
X = df.to_numpy()

light_probability = light_model.predict_proba(X)[:, 1]

# Store a plain float rather than the 1-element array: the ensemble page calls
# float() on this value, and converting a non-scalar array that way is
# deprecated in NumPy.
st.session_state["lightgbm probability"] = float(light_probability[0])
|
|
|
|
|
|
|
|
|
|
|
# Run the (expensive) text model only when the user asks for a prediction.
if st.button("๐ Predict Readmission"):
    with st.spinner("๐ Extracting embeddings and predicting readmission..."):
        # One note in, so the first entry of the first row is its probability.
        # Converted to a plain float before it goes into session state.
        readmission_prob = float(predict_readmission([cleaned_note])[0][0])
        st.session_state["MLP probability"] = readmission_prob
        prediction = 1 if readmission_prob > 0.5 else 0

    st.subheader("๐ฏ Prediction Results")
    probability_col, verdict_col = st.columns(2)

    probability_col.metric(label="๐งฎ Readmission Probability", value=f"{readmission_prob:.2%}")

    with verdict_col:
        if prediction == 1:
            st.error("โ ๏ธ High Risk of Readmission")
        else:
            # This literal was split across two physical lines; rejoined.
            st.success("โ Low Risk of Readmission")

    # Large coloured summary card: red above 50 %, green otherwise.
    _card_color = 'red' if readmission_prob > 0.5 else 'green'
    st.markdown(f"""
<div style="text-align:center; padding: 20px; background-color: #f8f9fa; border-radius: 10px;">
<h3>๐ Readmission Probability</h3>
<h2 style="color: {_card_color};">{readmission_prob:.2%}</h2>
</div>
""", unsafe_allow_html=True)
|
|
|
elif page == "๐ Ensemble Prediction": |
|
|
|
|
|
# NOTE: joblib.load unpickles arbitrary objects; only load trusted local artifacts.
ensemble_model = joblib.load("/Users/joaopimenta/Downloads/best_ensemble_model.pkl")

# Base models whose probabilities the other pages leave in session state under
# "<name> probability".
models = ["XGBoost", "lightgbm", "MLP"]

# Ensemble probability at or above this value is reported as "readmission likely".
decision_threshold = 0.25

probabilities = []
for model in models:
    key = f"{model} probability"
    if key not in st.session_state:
        probabilities.append(None)
        continue
    try:
        probabilities.append(float(st.session_state[key]))
    except (TypeError, ValueError):
        # float() raises TypeError (not ValueError) for e.g. multi-element arrays.
        st.error(f"โ ๏ธ Invalid probability value for {model}: {st.session_state[key]}")
        probabilities.append(None)

if None not in probabilities:
    st.write("### ๐ณ๏ธ Voting Process in Progress...")

    progress_bar = st.progress(0)
    voting_display = st.empty()

    votes = []
    for i, (model, prob) in enumerate(zip(models, probabilities)):
        time.sleep(1)

        # Blinking "deciding" message; the sleeps only pace the animation.
        for _ in range(3):
            voting_display.markdown(f"โณ {model} is deciding...")
            time.sleep(0.5)
            voting_display.markdown("")
            time.sleep(0.5)

        if prob < 0.33:
            vote = "๐ข Low"
        elif prob < 0.46:
            vote = "๐ก Medium"
        else:
            vote = "๐ด High"

        votes.append(vote)
        voting_display.markdown(f"โ **{model} voted: {vote}**")
        progress_bar.progress((i + 1) / len(models))

    time.sleep(1)
    progress_bar.empty()

    # The ensemble model is fed the three probabilities under these column names.
    final_df = pd.DataFrame([probabilities], columns=['probs', 'probs_lgb', 'probs_mlp'])
    final_df = final_df.astype(float)

    final_probability = ensemble_model.predict_proba(final_df)[:, 1][0]
    final_prediction = 1 if final_probability >= decision_threshold else 0

    st.markdown("---")
    if final_prediction == 1:
        card_color = "#ffdddd"
        headline = "๐จ <b>Final Prediction: 1</b> (Readmission Likely) "
    else:
        card_color = "#ddffdd"
        headline = "โ <b>Final Prediction: 0</b> (No Readmission Risk) "
    st.markdown(f"""
<div style="text-align: center; background-color: {card_color}; padding: 15px; border-radius: 10px;">
<h2>{headline}</h2>
<h3>๐ Probability: {final_probability:.2f} (Threshold: {decision_threshold})</h3>
</div>
""", unsafe_allow_html=True)

    st.write("### โ๏ธ Model Contribution to Final Decision")
    fig, ax = plt.subplots()
    ax.bar(models, probabilities, color=["blue", "green", "red"])
    ax.set_ylabel("Probability")
    ax.set_title("Model Prediction Probabilities")
    st.pyplot(fig)

    st.write("### ๐ Voting Breakdown:")
    for model, vote, prob in zip(models, votes, probabilities):
        st.write(f"๐น {model}: **{vote}** (Prob: {prob:.2f})")

else:
    st.warning("โ ๏ธ Some model predictions are missing. Please run all models before voting.")