import streamlit as st |
import numpy as np |
import pandas as pd |
import joblib |
import torch |
from transformers import AutoTokenizer, AutoModel |
from xgboost import XGBClassifier |
from sklearn.preprocessing import StandardScaler |
from sklearn.decomposition import PCA |
from sklearn.metrics import precision_recall_curve, roc_curve, confusion_matrix, classification_report |
import matplotlib.pyplot as plt |
import shap |
import plotly.express as px |
import streamlit as st |
import pandas as pd |
import datetime |
import json |
import requests |
from streamlit_lottie import st_lottie |
import streamlit.components.v1 as components |
from streamlit_navigation_bar import st_navbar |
from transformers import AutoTokenizer, AutoModel |
import re |
from tqdm import tqdm |
import torch |
import os |
from hugchat.login import Login |
from hugchat import hugchat |
from transformers import pipeline |
from transformers import AutoTokenizer, AutoModelForTokenClassification |
import torch.nn as nn |
import time |
# Sidebar navigation. The selected label decides which page the if/elif chain
# below renders. NOTE(review): the emoji prefixes in these labels are
# mis-encoded in this file; the page branches match on the plain-text suffix.
st.sidebar.title("๐ Menu")
page = st.sidebar.radio(
    "Selecione uma opção:",
    ["๐ Home", "๐ Tabular Data", "๐ Clinical Text Notes", "๐ Ensemble Prediction"],
)
if page.endswith("Home"):
    # Landing page: static title, banner image and a short description of the app.
    # Pages are matched on the plain-text suffix of the radio label. The emoji
    # prefixes are mis-encoded in this file, so comparing against a second
    # hand-typed copy of the full label is fragile: a mismatch silently renders
    # an empty page.
    st.markdown(
        """
        <style>
        .title {
            text-align: center;
            font-size: 36px;
            font-weight: bold;
            color: #2C3E50;
        }
        .subtitle {
            text-align: center;
            font-size: 22px;
            color: #7F8C8D;
        }
        .box {
            background-color: #ECF0F1;
            padding: 15px;
            border-radius: 10px;
            text-align: center;
            margin-bottom: 10px;
            font-size: 18px;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
    st.markdown("<h1 class='title'>๐ AI Clinical Readmission Predictor</h1>", unsafe_allow_html=True)
    st.markdown("<h2 class='subtitle'>Using Machine Learning for Better Patient Outcomes</h2>", unsafe_allow_html=True)

    banner_url = 'https://med-tech.world/app/uploads/2024/10/AI-Hospitals.jpg.webp'
    st.image(banner_url, width=1450)

    st.write("This app helps predict patient readmission risk using machine learning models. "
             "Upload data, analyze clinical notes, and see predictions from our ensemble model.")
    st.markdown("---")
    st.markdown("<h3 style='text-align: center;'>๐ Explore the App</h3>", unsafe_allow_html=True)
elif page== "๐ Tabular Data": |
def load_lottie(url): |
response = requests.get(url) |
if response.status_code != 200: |
return None |
return response.json() |
lottie_hello = load_lottie("https://assets7.lottiefiles.com/packages/lf20_jcikwtux.json") |
if lottie_hello: |
st_lottie(lottie_hello, speed=1, loop=True, height=200) |
df = pd.read_csv('/Users/joaopimenta/Downloads/ensemble_test.csv') |
st.title('๐ฅ Hospital Readmission Prediction') |
st.markdown(""" |
<h3 style='text-align: center; color: gray;'>Predict ICU hospital readmission using Artificial Intelligence</h3> |
""", unsafe_allow_html=True) |
st.markdown("---") |
def get_age_group(age): |
"""Classify age into predefined groups with correct column names.""" |
if 36 <= age <= 50: |
return "age_group_36-50 (Middle-Aged Adults)" |
elif 51 <= age <= 65: |
return "age_group_51-65 (Older Middle-Aged Adults)" |
elif 66 <= age <= 80: |
return "age_group_66-80 (Senior Adults)" |
elif age >= 81: |
return "age_group_81+ (Elderly)" |
return "age_group_Below_36" |
def get_period(hour): |
"""Determine admission/discharge period.""" |
return "Morning" if 6 <= hour < 18 else "Night" |
st.subheader("๐ Select the admission's Characteristics") |
admission_type = st.selectbox("๐ Type of Admission", df.columns[df.columns.str.startswith('admission_type_')]) |
admission_location = st.selectbox("๐ Admission Location", df.columns[df.columns.str.startswith('admission_location_')]) |
discharge_location = st.selectbox("๐ฅ Discharge Location", df.columns[df.columns.str.startswith('discharge_location_')]) |
insurance = st.selectbox("๐ฐ Insurance Type", df.columns[df.columns.str.startswith('insurance_')]) |
st.sidebar.subheader("๐ Patient Information") |
language = st.sidebar.selectbox("๐ฃ Language", df.columns[df.columns.str.startswith('language_')]) |
marital_status = st.sidebar.selectbox("๐ Marital Status", df.columns[df.columns.str.startswith('marital_status_')]) |
race = st.sidebar.selectbox("๐ง Race", df.columns[df.columns.str.startswith('race_')]) |
sex = st.sidebar.selectbox("โง Sex", ['gender_M', 'gender_F']) |
age = st.sidebar.slider("๐
Age", 18, 100, 50) |
admission_time = st.time_input("โณ Admission Time", value=datetime.time(12, 0)) |
discharge_time = st.time_input("โณ Discharge Time", value=datetime.time(12, 0)) |
st.subheader("๐ Clinical Values") |
numerical_features = ['los_days', 'previous_stays', 'n_meds', 'drg_severity', 'drg_mortality', 'time_since_last_stay', |
'blood_cells', 'hemoglobin', 'glucose', 'creatine', 'plaquete'] |
numeric_inputs = {} |
cols = st.columns(len(numerical_features)) |
st.subheader("๐ General Hosptal Information") |
general_numerical_features = ['los_days', 'previous_stays', 'n_meds', 'drg_severity', |
'drg_mortality', 'time_since_last_stay'] |
general_inputs = {} |
cols = st.columns(3) |
for i, feature in enumerate(general_numerical_features): |
col_index = i % 3 |
min_val, max_val = df[feature].min(), df[feature].max() |
with cols[col_index]: |
general_inputs[feature] = st.slider( |
f"๐ {feature.replace('_', ' ').title()}", |
float(min_val), |
float(max_val), |
float((min_val + max_val) / 2) |
) |
st.subheader("๐งช Laboratory Test Results") |
lab_numerical_features = ['blood_cells', 'hemoglobin', 'glucose', |
'creatine', 'plaquete'] |
lab_inputs = {} |
lab_cols = st.columns(3) |
for i, feature in enumerate(lab_numerical_features): |
col_index = i % 3 |
min_val, max_val = df[feature].min(), df[feature].max() |
with lab_cols[col_index]: |
lab_inputs[feature] = st.slider( |
f"๐ฉธ {feature.replace('_', ' ').title()}", |
float(min_val), |
float(max_val), |
float((min_val + max_val) / 2) |
) |
min_val, max_val = df["cci_score"].min(), df["cci_score"].max() |
lab_inputs["cci_score"] = st.sidebar.slider( |
f"๐ CCI Score", |
float(min_val), |
float(max_val), |
float((min_val + max_val) / 2) |
) |
feature_vector = {col: 0 for col in df.columns} |
feature_vector.update({ |
admission_type: 1, |
admission_location: 1, |
discharge_location: 1, |
insurance: 1, |
language: 1, |
marital_status: 1, |
race: 1, |
"gender_M": 1 if sex == "gender_M" else 0, |
f"admit_period_{get_period(admission_time.hour)}": 1, |
f"discharge_period_{get_period(discharge_time.hour)}": 1 |
}) |
age_group = get_age_group(age) |
for group in [ |
"age_group_36-50 (Middle-Aged Adults)", |
"age_group_51-65 (Older Middle-Aged Adults)", |
"age_group_66-80 (Senior Adults)", |
"age_group_81+ (Elderly)" |
]: |
feature_vector[group] = 1 if group == age_group else 0 |
feature_vector.update(numeric_inputs) |
st.markdown("---") |
tabular_model_path = "/Users/joaopimenta/Downloads/final_xgboost_model.pkl" |
tabular_model = joblib.load(tabular_model_path) |
XGBoost Tabular Model loaded successfully!") |
expected_columns = [ |
col for col in df.columns if col not in ["Unnamed: 0", "subject_id", "hadm_id", "probs"] |
] |
age_group_mapping = { |
"age_group_36-50": "age_group_36-50 (Middle-Aged Adults)", |
"age_group_51-65": "age_group_51-65 (Older Middle-Aged Adults)", |
"age_group_66-80": "age_group_66-80 (Senior Adults)", |
"age_group_81+": "age_group_81+ (Elderly)", |
} |
feature_vector = {col: 0 for col in df.columns} |
feature_vector.update({ |
admission_type: 1, |
admission_location: 1, |
discharge_location: 1, |
insurance: 1, |
language: 1, |
marital_status: 1, |
race: 1, |
"gender_M": 1 if sex == "gender_M" else 0, |
f"admit_period_{get_period(admission_time.hour)}": 1, |
f"discharge_period_{get_period(discharge_time.hour)}": 1 |
}) |
age_group = get_age_group(age) |
for group in [ |
"age_group_36-50 (Middle-Aged Adults)", |
"age_group_51-65 (Older Middle-Aged Adults)", |
"age_group_66-80 (Senior Adults)", |
"age_group_81+ (Elderly)" |
]: |
feature_vector[group] = 1 if group == age_group else 0 |
feature_vector.update(general_inputs) |
feature_vector.update(lab_inputs) |
fixed_feature_vector = {age_group_mapping.get(k, k): v for k, v in feature_vector.items()} |
feature_df = pd.DataFrame([fixed_feature_vector]).reindex(columns=expected_columns, fill_value=0) |
st.write(feature_df) |
prediction_proba = tabular_model.predict_proba(feature_df)[:, 1] |
probability = float(prediction_proba[0]) |
st.session_state["XGBoost probability"] = probability |
prediction = (prediction_proba >= 0.5).astype(int) |
import shap |
import matplotlib.pyplot as plt |
import streamlit.components.v1 as components |
st.write(f"Raw Prediction Probability: {probability:.4f}") |
if st.button("๐ Predict Readmission"): |
with st.spinner("๐ Processing Prediction..."): |
st.subheader("๐ฏ Prediction Results") |
col1, col2 = st.columns(2) |
with col1: |
st.metric(label="๐งฎ Readmission Probability", value=f"{probability:.2%}") |
with col2: |
if prediction == 1: |
st.error("โ ๏ธ High Risk of Readmission") |
else: |
Low Risk of Readmission") |
if st.button("๐ Feature Importance for Prediction"): |
st.metric(label="๐งฎ Readmission Probability", value=f"{probability:.2%}") |
explainer = shap.TreeExplainer(tabular_model) |
shap_values = explainer.shap_values(feature_df) |
shap_df = pd.DataFrame({ |
"Feature": feature_df.columns, |
"SHAP Value": shap_values[0] |
}) |
shap_df["abs_SHAP"] = shap_df["SHAP Value"].abs() |
shap_df = shap_df.sort_values(by="abs_SHAP", ascending=False).head(10) |
top_features = sorted(zip(shap_df['Feature'], shap_df['SHAP Value']), key=lambda x: abs(x[1]), reverse=True) |
top_factors = "\n".join([f"- {feat}: {round(value, 2)} impact" for feat, value in top_features]) |
EMAIL = st.secrets["email"] |
PASSWD = st.secrets["passwd"] |
cookie_path_dir = "./cookies_snapshot" |
try: |
sign = Login(EMAIL, PASSWD) |
cookies = sign.login(cookie_dir_path=cookie_path_dir, save_cookies=True) |
sign.saveCookiesToDir(cookie_path_dir) |
except Exception as e: |
st.error(f"โ Login to HuggingChat failed. Error: {e}") |
st.stop() |
chatbot = hugchat.ChatBot(cookies=cookies.get_dict()) |
st.title("๐ฉบ AI-Powered Patient Readmission Analysis") |
hugging_prompt = f""" |
A hospital AI model predicts patient readmission based on the following feature impacts: |
{top_factors} |
Can you explain why the model made this decision? Specifically, what were the key characteristics of the patient or their admission that influenced the modelโs prediction the most? |
""" |
with st.spinner("๐ค Analyzing..."): |
try: |
response = chatbot.chat(hugging_prompt) |
ai_output = response |
with st.chat_message("assistant"): |
st.markdown(f"**๐ก AI Explanation:**\n\n{ai_output}") |
except Exception as e: |
st.error(f"โ ๏ธ Error retrieving response: {e}") |
st.stop() |
with st.expander("๐ Click to see detailed feature impacts"): |
st.markdown(f"```{top_factors}```") |
fig, ax = plt.subplots(figsize=(8, 6)) |
shap.bar_plot(shap_df["SHAP Value"].values, shap_df["Feature"].values) |
st.pyplot(fig) |
st.subheader("๐ฏ SHAP Force Plot (How Features Affected the Prediction)") |
force_plot = shap.force_plot( |
explainer.expected_value, shap_values[0], feature_df.iloc[0], matplotlib=False |
) |
shap_html = f"<head>{shap.getjs()}</head><body>{force_plot.html()}</body>" |
components.html(shap_html, height=400) |
elif page == "๐ Clinical Text Notes": |
st.subheader("๐ Clinical Text Note") |
def clean_text(text): |
"""Cleans input text by removing non-ASCII characters, extra spaces, and unwanted symbols.""" |
text = re.sub(r"[^\x20-\x7E]", " ", text) |
text = re.sub(r"_{2,}", "", text) |
text = re.sub(r"\s+", " ", text) |
text = re.sub(r"[^\w\s.,:;*%()\[\]-]", "", text) |
return text.lower().strip() |
import re |
def extract_fields(text): |
"""Extracts key fields from clinical notes using regex patterns.""" |
patterns = { |
"Discharge Medications": r"Discharge Medications[:\-]?\s*(.+?)\s+(?:Discharge Disposition|Discharge Condition|Discharge Instructions|Followup Instructions|$)", |
"Discharge Diagnosis": r"Discharge Diagnosis[:\-]?\s*(.+?)\s+(?:Discharge Condition|Discharge Medications|Discharge Instructions|Followup Instructions|$)", |
"Discharge Instructions": r"Discharge Instructions[:\-]?\s*(.*?)\s+(?:Followup Instructions|Discharge Disposition|Discharge Condition|$)", |
"History of Present Illness": r"History of Present Illness[:\-]?\s*(.+?)\s+(?:Past Medical History|Social History|Family History|Physical Exam|$)", |
"Past Medical History": r"Past Medical History[:\-]?\s*(.+?)\s+(?:Social History|Family History|Physical Exam|$)" |
} |
extracted_data = {} |
for field, pattern in patterns.items(): |
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE) |
if match: |
extracted_data[field] = match.group(1).strip() |
return extracted_data |
def extract_features(texts, model, tokenizer, device, batch_size=8): |
"""Extracts CLS token embeddings from the Clinical-Longformer model.""" |
all_features = [] |
for i in range(0, len(texts), batch_size): |
batch_texts = texts[i:i+batch_size] |
inputs = tokenizer(batch_texts, return_tensors="pt", truncation=True, padding=True, max_length=4096).to(device) |
global_attention_mask = torch.zeros_like(inputs["input_ids"]).to(device) |
global_attention_mask[:, 0] = 1 |
with torch.no_grad(): |
outputs = model(**inputs, global_attention_mask=global_attention_mask) |
all_features.append(outputs.last_hidden_state[:, 0, :]) |
return torch.cat(all_features, dim=0) |
def extract_entities(text, pipe, entity_group): |
"""Extracts specific entities from the clinical note using a NER pipeline.""" |
entities = pipe(text) |
return [ent['word'] for ent in entities if ent['entity_group'] == entity_group] or ["No relevant entities found"] |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
@st.cache_resource() |
def load_models(): |
"""Loads transformer models for text processing and NER.""" |
longformer_tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-Longformer") |
longformer_model = AutoModel.from_pretrained("yikuan8/Clinical-Longformer").to(device).eval() |
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all") |
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all") |
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") |
return longformer_tokenizer, longformer_model, ner_pipe |
longformer_tokenizer, longformer_model, ner_pipe = load_models() |
clinical_note = st.text_area("โ๏ธ Enter Clinical Note", placeholder="Write the clinical note here...") |
if clinical_note: |
cleaned_note = clean_text(clinical_note) |
extracted_data = extract_fields(cleaned_note) |
st.write("### Extracted Fields") |
st.write(extracted_data) |
with st.spinner("๐ Extracting embeddings..."): |
embeddings = extract_features([cleaned_note], longformer_model, longformer_tokenizer, device) |
class RobustMLPClassifier(nn.Module): |
def __init__(self, input_dim, hidden_dims=[256, 128, 64], dropout=0.3, activation=nn.ReLU()): |
super(RobustMLPClassifier, self).__init__() |
layers = [] |
current_dim = input_dim |
for h in hidden_dims: |
layers.append(nn.Linear(current_dim, h)) |
layers.append(nn.BatchNorm1d(h)) |
layers.append(activation) |
layers.append(nn.Dropout(dropout)) |
current_dim = h |
layers.append(nn.Linear(current_dim, 1)) |
self.net = nn.Sequential(*layers) |
def forward(self, x): |
return self.net(x) |
mlp_model_path = "/Users/joaopimenta/Downloads/Capstone/best_mlp_model_full.pth" |
pca_path = "/Users/joaopimenta/Downloads/Capstone/best_pca_model.pkl" |
best_mlp_model = torch.load(mlp_model_path) |
best_mlp_model.to(device) |
best_mlp_model.eval() |
pca = joblib.load(pca_path) |
def predict_readmission(texts): |
"""Predicts hospital readmission probability using Clinical-Longformer embeddings and MLP.""" |
embeddings = extract_features(texts, longformer_model, longformer_tokenizer, device) |
embeddings_pca = pca.transform(embeddings.cpu().numpy()) |
inputs = torch.FloatTensor(embeddings_pca).to(device) |
with torch.no_grad(): |
logits = best_mlp_model(inputs) |
probabilities = torch.sigmoid(logits).cpu().numpy() |
return probabilities |
with st.spinner("๐ Identifying medical entities..."): |
extracted_data["Extracted Medications"] = extract_entities( |
extracted_data.get("Discharge Medications", ""), ner_pipe, "Medication" |
) |
extracted_data["Extracted Diseases"] = extract_entities( |
extracted_data.get("Discharge Diagnosis", ""), ner_pipe, "Disease_disorder" |
) |
extracted_data["Extracted Diseases (Past Medical History)"] = extract_entities( |
extracted_data.get("Past Medical History", ""), ner_pipe, "Disease_disorder" |
) |
extracted_data["Extracted Diseases (History of Present Illness)"] = extract_entities( |
extracted_data.get("History of Present Illness", ""), ner_pipe, "Disease_disorder" |
) |
extracted_data["Extracted Symptoms"] = extract_entities( |
extracted_data.get("Review of Systems", "") + " " + extracted_data.get("History of Present Illness", ""), |
ner_pipe, "Sign_symptom" |
) |
def clean_entities(entities): |
"""Reconstruct fragmented tokens and remove duplicates.""" |
cleaned = [] |
temp = "" |
for entity in entities: |
if entity.startswith("##"): |
temp += entity.replace("##", "") |
else: |
if temp: |
cleaned.append(temp) |
temp = entity |
if temp: |
cleaned.append(temp) |
cleaned = [word for word in cleaned if len(word) > 2 and not re.match(r"^[\W_]+$", word)] |
return sorted(set(cleaned)) |
diseases_cleaned = clean_entities( |
extracted_data.get("Extracted Diseases", []) + |
extracted_data.get("Extracted Diseases (Past Medical History)", []) + |
extracted_data.get("Extracted Diseases (History of Present Illness)", []) |
) |
medications_cleaned = clean_entities(extracted_data.get("Extracted Medications", [])) |
extracted_data["Extracted Medications Cleaned"] = medications_cleaned |
symptoms_cleaned = clean_entities(extracted_data.get("Extracted Symptoms", [])) |
def display_list(title, items, icon="๐"): |
"""Display extracted medical entities in an expandable list.""" |
with st.expander(f"**{title} ({len(items)})**"): |
if items: |
for item in items: |
st.markdown(f"- {icon} **{item}**") |
else: |
st.markdown("_No information available._") |
st.markdown("## ๐ฅ **Patient Medical Analysis**") |
st.markdown("---") |
col1, col2, col3 = st.columns(3) |
num_medications = len(medications_cleaned ) |
col1.metric(label="๐ Total Medications", value=num_medications) |
num_diseases = len(diseases_cleaned) |
col2.metric(label="๐ฆ Total Diseases", value=num_diseases) |
num_symptoms = len(symptoms_cleaned) |
col3.metric(label="๐ค Total Symptoms", value=num_symptoms) |
st.markdown("---") |
col1, col2 = st.columns(2) |
with col1: |
st.markdown("### ๐ **Medications**") |
display_list("Medication List", medications_cleaned , icon="๐") |
with col2: |
st.markdown("### ๐ฆ **Diseases**") |
display_list("Disease List", diseases_cleaned, icon="๐ฆ ") |
st.markdown("### ๐ค **Symptoms**") |
display_list("Symptoms List", symptoms_cleaned, icon="๐ค") |
st.markdown("---") |
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-Longformer") |
def count_tokens(text): |
tokens = tokenizer.tokenize(text) |
return len(tokens) |
def trunced_text(nr): |
return 1 if nr > 4096 else 0 |
disease_synonyms = { |
"Pneumonia": ["pneumonia", "pneumonitis"], |
"Diabetes": ["diabetes", "diabetes mellitus", "dm"], |
"CHF": ["CHF", "congestive heart failure", "heart failure"], |
"Septicemia": ["septicemia", "sepsis", "blood infection"], |
"Cirrhosis": ["cirrhosis", "liver cirrhosis", "hepatic cirrhosis"], |
"COPD": ["COPD", "chronic obstructive pulmonary disease"], |
"Renal_Failure": ["renal failure", "kidney failure", "chronic kidney disease", "CKD"] |
} |
extracted_data = extract_fields(clinical_note) |
number_of_tokens = count_tokens(clinical_note) |
number_of_tokens_med = count_tokens(extracted_data.get("Discharge Medications", "")) |
number_of_tokens_dis = count_tokens(extracted_data.get("Discharge Diagnosis", "")) |
trunced = trunced_text(number_of_tokens) |
full_diagnosis_text = extracted_data.get("Discharge Diagnosis", "").lower() |
def check_disease_presence(disease_list, text): |
return int(any(re.search(rf"\b{synonym}\b", text, re.IGNORECASE) for synonym in disease_list)) |
disease_flags = {disease: check_disease_presence(synonyms, full_diagnosis_text) |
for disease, synonyms in disease_synonyms.items()} |
disease_flags["total_conditions"] = sum(disease_flags.values()) |
df = pd.DataFrame([{ |
'number_of_tokens_dis': number_of_tokens_dis, |
'number_of_tokens': number_of_tokens, |
'number_of_tokens_med': number_of_tokens_med, |
'diagnostic_count': num_diseases, |
'total_conditions': disease_flags["total_conditions"], |
'trunced': trunced, |
**{disease: disease_flags[disease] for disease in disease_synonyms.keys()} |
}]) |
light_path = '/Users/joaopimenta/Downloads/best_lgbm_model.pkl' |
light_model = joblib.load(light_path) |
model_features = light_model.feature_name_ |
missing_features = [feat for feat in model_features if feat not in df.columns] |
if missing_features: |
st.write(f"โ ๏ธ Warning: Missing features in df: {missing_features}") |
for feat in missing_features: |
df[feat] = 0 |
df = df[model_features] |
X = df.to_numpy() |
light_probability = light_model.predict_proba(X)[:, 1] |
st.session_state["lightgbm probability"] = light_probability |
if st.button("๐ Predict Readmission"): |
with st.spinner("๐ Extracting embeddings and predicting readmission..."): |
readmission_prob = predict_readmission([cleaned_note])[0][0] |
st.session_state["MLP probability"] = readmission_prob |
prediction = 1 if readmission_prob > 0.5 else 0 |
st.subheader("๐ฏ Prediction Results") |
col1, col2 = st.columns(2) |
with col1: |
st.metric(label="๐งฎ Readmission Probability", value=f"{readmission_prob:.2%}") |
with col2: |
if prediction == 1: |
st.error("โ ๏ธ High Risk of Readmission") |
else: |
Low Risk of Readmission") |
st.markdown(f""" |
<div style="text-align:center; padding: 20px; background-color: #f8f9fa; border-radius: 10px;"> |
<h3>๐ Readmission Probability</h3> |
<h2 style="color: {'red' if readmission_prob > 0.5 else 'green'};">{readmission_prob:.2%}</h2> |
</div> |
""", unsafe_allow_html=True) |
elif page == "๐ Ensemble Prediction": |
ensemble_model = joblib.load("/Users/joaopimenta/Downloads/best_ensemble_model.pkl") |
models = ["XGBoost", "lightgbm", "MLP"] |
probabilities = [] |
for model in models: |
key = f"{model} probability" |
if key in st.session_state: |
try: |
prob = float(st.session_state[key]) |
probabilities.append(prob) |
except ValueError: |
st.error(f"โ ๏ธ Invalid probability value for {model}: {st.session_state[key]}") |
probabilities.append(None) |
else: |
probabilities.append(None) |
if None not in probabilities: |
st.write("### ๐ณ๏ธ Voting Process in Progress...") |
progress_bar = st.progress(0) |
voting_display = st.empty() |
votes = [] |
for i, (model, prob) in enumerate(zip(models, probabilities)): |
time.sleep(1) |
for _ in range(3): |
voting_display.markdown(f"โณ {model} is deciding...") |
time.sleep(0.5) |
voting_display.markdown("") |
time.sleep(0.5) |
if prob < 0.33: |
vote = "๐ข Low" |
elif prob < 0.46: |
vote = "๐ก Medium" |
else: |
vote = "๐ด High" |
votes.append(vote) |
**{model} voted: {vote}**") |
progress_bar.progress((i + 1) / len(models)) |
time.sleep(1) |
progress_bar.empty() |
final_df = pd.DataFrame([probabilities], columns=['probs', 'probs_lgb', 'probs_mlp']) |
final_df = final_df.astype(float) |
final_probability = ensemble_model.predict_proba(final_df)[:, 1][0] |
final_prediction = 1 if final_probability >= 0.25 else 0 |
st.markdown("---") |
if final_prediction == 1: |
st.markdown(f""" |
<div style="text-align: center; background-color: #ffdddd; padding: 15px; border-radius: 10px;"> |
<h2>๐จ <b>Final Prediction: 1</b> (Readmission Likely) </h2> |
<h3>๐ Probability: {final_probability:.2f} (Threshold: 0.25)</h3> |
</div> |
""", unsafe_allow_html=True) |
else: |
st.markdown(f""" |
<div style="text-align: center; background-color: #ddffdd; padding: 15px; border-radius: 10px;"> |
<b>Final Prediction: 0</b> (No Readmission Risk) </h2> |
<h3>๐ Probability: {final_probability:.2f} (Threshold: 0.25)</h3> |
</div> |
""", unsafe_allow_html=True) |
st.write("### โ๏ธ Model Contribution to Final Decision") |
fig, ax = plt.subplots() |
ax.bar(models, probabilities, color=["blue", "green", "red"]) |
ax.set_ylabel("Probability") |
ax.set_title("Model Prediction Probabilities") |
st.pyplot(fig) |
st.write("### ๐ Voting Breakdown:") |
for model, vote in zip(models, votes): |
st.write(f"๐น {model}: **{vote}** (Prob: {probabilities[models.index(model)]:.2f})") |
else: |
st.warning("โ ๏ธ Some model predictions are missing. Please run all models before voting.") |