"""Streamlit app that predicts hospital readmission risk.

Pages (selected from the sidebar radio):

* Home - landing page.
* Tabular Data - XGBoost prediction from structured admission data, with a
  SHAP explanation and a HuggingChat-generated narrative.
* Clinical Text Notes - entity extraction plus LightGBM and
  Clinical-Longformer/MLP predictions from a free-text note.
* Ensemble Prediction - combines the three probabilities that the other
  pages stored in ``st.session_state``.

NOTE(review): the copy of this file that was reviewed had lost its line
structure, its inline HTML and its emoji encoding.  Nesting was reconstructed
from the statement order; the HTML wrappers below are minimal stand-ins -
restore the original styling if it is available.
"""

import datetime
import json
import os
import re
import time

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import shap
import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn as nn
from hugchat import hugchat
from hugchat.login import Login
from sklearn.decomposition import PCA
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_recall_curve,
    roc_curve,
)
from sklearn.preprocessing import StandardScaler
from streamlit_lottie import st_lottie
from streamlit_navigation_bar import st_navbar
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)
from xgboost import XGBClassifier

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Directory holding the dataset and the trained models.  The default keeps the
# original hard-coded location; set READMISSION_APP_DIR to run elsewhere.
BASE_DIR = os.environ.get("READMISSION_APP_DIR", "/Users/joaopimenta/Downloads")
TABULAR_DATA_PATH = os.path.join(BASE_DIR, "ensemble_test.csv")
XGBOOST_MODEL_PATH = os.path.join(BASE_DIR, "final_xgboost_model.pkl")
LGBM_MODEL_PATH = os.path.join(BASE_DIR, "best_lgbm_model.pkl")
ENSEMBLE_MODEL_PATH = os.path.join(BASE_DIR, "best_ensemble_model.pkl")
MLP_MODEL_PATH = os.path.join(BASE_DIR, "Capstone", "best_mlp_model_full.pth")
PCA_MODEL_PATH = os.path.join(BASE_DIR, "Capstone", "best_pca_model.pkl")

LOTTIE_URL = "https://assets7.lottiefiles.com/packages/lf20_jcikwtux.json"
HOME_IMAGE_URL = "https://med-tech.world/app/uploads/2024/10/AI-Hospitals.jpg.webp"

LONGFORMER_NAME = "yikuan8/Clinical-Longformer"
NER_MODEL_NAME = "d4data/biomedical-ner-all"
MAX_TOKENS = 4096

# Page names are defined once so the radio options and the dispatch at the
# bottom of the file can never drift apart.
PAGE_HOME = "🏠 Home"
PAGE_TABULAR = "📊 Tabular Data"
PAGE_CLINICAL = "📝 Clinical Text Notes"
PAGE_ENSEMBLE = "🔀 Ensemble Prediction"

# Session-state keys shared between the model pages and the ensemble page.
ENSEMBLE_MODELS = ["XGBoost", "lightgbm", "MLP"]
ENSEMBLE_COLUMNS = ["probs", "probs_lgb", "probs_mlp"]
ENSEMBLE_THRESHOLD = 0.25

GENERAL_NUMERICAL_FEATURES = [
    "los_days", "previous_stays", "n_meds",
    "drg_severity", "drg_mortality", "time_since_last_stay",
]
LAB_NUMERICAL_FEATURES = [
    "blood_cells", "hemoglobin", "glucose", "creatine", "plaquete",
]
# Columns of the CSV that are identifiers/outputs, not model inputs.
NON_FEATURE_COLUMNS = ("Unnamed: 0", "subject_id", "hadm_id", "probs")

AGE_GROUP_COLUMNS = (
    "age_group_36-50 (Middle-Aged Adults)",
    "age_group_51-65 (Older Middle-Aged Adults)",
    "age_group_66-80 (Senior Adults)",
    "age_group_81+ (Elderly)",
)
# Short names are rewritten to the long column names before reindexing.
AGE_GROUP_ALIASES = {
    "age_group_36-50": "age_group_36-50 (Middle-Aged Adults)",
    "age_group_51-65": "age_group_51-65 (Older Middle-Aged Adults)",
    "age_group_66-80": "age_group_66-80 (Senior Adults)",
    "age_group_81+": "age_group_81+ (Elderly)",
}

# Canonical disease name -> synonyms searched for in the discharge diagnosis.
DISEASE_SYNONYMS = {
    "Pneumonia": ["pneumonia", "pneumonitis"],
    "Diabetes": ["diabetes", "diabetes mellitus", "dm"],
    "CHF": ["CHF", "congestive heart failure", "heart failure"],
    "Septicemia": ["septicemia", "sepsis", "blood infection"],
    "Cirrhosis": ["cirrhosis", "liver cirrhosis", "hepatic cirrhosis"],
    "COPD": ["COPD", "chronic obstructive pulmonary disease"],
    "Renal_Failure": ["renal failure", "kidney failure", "chronic kidney disease", "CKD"],
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------------------------------------------------------
# Model definition (must live at module level so the pickled MLP can be loaded)
# ---------------------------------------------------------------------------

class RobustMLPClassifier(nn.Module):
    """Fully connected binary classifier emitting a single logit.

    Each hidden layer is Linear -> BatchNorm1d -> activation -> Dropout.
    """

    def __init__(self, input_dim, hidden_dims=(256, 128, 64), dropout=0.3, activation=nn.ReLU()):
        super().__init__()
        layers = []
        current_dim = input_dim
        for h in hidden_dims:
            layers.append(nn.Linear(current_dim, h))
            layers.append(nn.BatchNorm1d(h))
            layers.append(activation)
            layers.append(nn.Dropout(dropout))
            current_dim = h
        layers.append(nn.Linear(current_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw logit(s) for ``x``."""
        return self.net(x)


# ---------------------------------------------------------------------------
# Cached loaders (Streamlit reruns the script on every interaction)
# ---------------------------------------------------------------------------

@st.cache_data()
def load_tabular_dataset(path):
    """Read the reference CSV used for widget options and slider ranges."""
    return pd.read_csv(path)


@st.cache_resource()
def load_pickled_model(path):
    """Load a joblib-pickled estimator once per server process.

    NOTE: joblib/pickle can execute arbitrary code - only load trusted files.
    """
    return joblib.load(path)


@st.cache_resource()
def load_mlp_model(path):
    """Load the fully pickled MLP onto ``DEVICE`` in eval mode.

    ``map_location`` lets a GPU-saved model load on a CPU-only host and
    ``weights_only=False`` is required because the file stores the whole
    module, not just a state dict.  Only load trusted files.
    """
    model = torch.load(path, map_location=DEVICE, weights_only=False)
    model.to(DEVICE)
    model.eval()
    return model


@st.cache_resource()
def load_models():
    """Load the Clinical-Longformer encoder and the biomedical NER pipeline.

    Returns:
        ``(longformer_tokenizer, longformer_model, ner_pipe)``.
    """
    longformer_tokenizer = AutoTokenizer.from_pretrained(LONGFORMER_NAME)
    longformer_model = AutoModel.from_pretrained(LONGFORMER_NAME).to(DEVICE).eval()
    ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
    ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
    ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
    return longformer_tokenizer, longformer_model, ner_pipe


# ---------------------------------------------------------------------------
# Generic helpers
# ---------------------------------------------------------------------------

def centered_html(tag, text):
    """Wrap ``text`` in a centered HTML element for ``st.markdown``."""
    return f"<{tag} style='text-align: center;'>{text}</{tag}>"


def load_lottie(url):
    """Fetch a Lottie animation as a dict, or ``None`` if it is unavailable."""
    try:
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:  # body was not valid JSON
        return None


def show_prediction_result(probability, prediction):
    """Render the probability metric next to a high/low risk banner."""
    st.subheader("🎯 Prediction Results")
    col1, col2 = st.columns(2)
    with col1:
        st.metric(label="🧮 Readmission Probability", value=f"{probability:.2%}")
    with col2:
        if prediction == 1:
            st.error("⚠️ High Risk of Readmission")
        else:
            st.success("✅ Low Risk of Readmission")


# ---------------------------------------------------------------------------
# Tabular-page helpers
# ---------------------------------------------------------------------------

def get_age_group(age):
    """Classify age into predefined groups with correct column names."""
    if 36 <= age <= 50:
        return "age_group_36-50 (Middle-Aged Adults)"
    elif 51 <= age <= 65:
        return "age_group_51-65 (Older Middle-Aged Adults)"
    elif 66 <= age <= 80:
        return "age_group_66-80 (Senior Adults)"
    elif age >= 81:
        return "age_group_81+ (Elderly)"
    return "age_group_Below_36"


def get_period(hour):
    """Determine admission/discharge period."""
    return "Morning" if 6 <= hour < 18 else "Night"


def one_hot_options(df, prefix):
    """Return the dataset columns whose name starts with ``prefix``."""
    return df.columns[df.columns.str.startswith(prefix)]


def midpoint_slider(container, label, series):
    """Slider spanning the observed range of ``series``, starting at its midpoint."""
    low, high = float(series.min()), float(series.max())
    return container.slider(label, low, high, (low + high) / 2)


def build_tabular_features(df, selected_columns, sex, age, admission_time, discharge_time, numeric_values):
    """Assemble the single-row model input from the widget values.

    Args:
        df: reference dataset; its columns define the feature space.
        selected_columns: one-hot column names chosen in the select boxes.
        sex: ``"gender_M"`` or ``"gender_F"``.
        age: patient age in years.
        admission_time / discharge_time: ``datetime.time`` values.
        numeric_values: mapping of numeric feature name -> value.

    Returns:
        A one-row DataFrame whose columns are the dataset columns minus
        ``NON_FEATURE_COLUMNS``, in dataset order.
    """
    features = {col: 0 for col in df.columns}
    features.update({col: 1 for col in selected_columns})
    features["gender_M"] = 1 if sex == "gender_M" else 0
    # Also flag the selected column itself so gender_F is honoured when the
    # model has it; reindex below drops it otherwise.
    features[sex] = 1
    features[f"admit_period_{get_period(admission_time.hour)}"] = 1
    features[f"discharge_period_{get_period(discharge_time.hour)}"] = 1

    age_group = get_age_group(age)
    for group in AGE_GROUP_COLUMNS:
        features[group] = 1 if group == age_group else 0

    features.update(numeric_values)

    expected_columns = [col for col in df.columns if col not in NON_FEATURE_COLUMNS]
    renamed = {AGE_GROUP_ALIASES.get(k, k): v for k, v in features.items()}
    return pd.DataFrame([renamed]).reindex(columns=expected_columns, fill_value=0)


def explain_with_huggingchat(top_factors):
    """Ask HuggingChat to narrate the SHAP impacts and render the answer.

    Stops the Streamlit run if login or the chat request fails.
    """
    email = st.secrets["email"]
    passwd = st.secrets["passwd"]
    cookie_path_dir = "./cookies_snapshot"

    try:
        sign = Login(email, passwd)
        cookies = sign.login(cookie_dir_path=cookie_path_dir, save_cookies=True)
        sign.saveCookiesToDir(cookie_path_dir)
    except Exception as e:  # UI boundary: report and halt this run
        st.error(f"❌ Login to HuggingChat failed. Error: {e}")
        st.stop()

    chatbot = hugchat.ChatBot(cookies=cookies.get_dict())

    st.title("🩺 AI-Powered Patient Readmission Analysis")

    hugging_prompt = f"""
    A hospital AI model predicts patient readmission based on the following feature impacts:
    {top_factors}
    Can you explain why the model made this decision? Specifically, what were the key characteristics of the patient or their admission that influenced the model’s prediction the most?
    """

    with st.spinner("🤖 Analyzing..."):
        try:
            ai_output = chatbot.chat(hugging_prompt)
            with st.chat_message("assistant"):
                st.markdown(f"**💡 AI Explanation:**\n\n{ai_output}")
        except Exception as e:  # UI boundary: report and halt this run
            st.error(f"⚠️ Error retrieving response: {e}")
            st.stop()


def show_feature_importance(tabular_model, feature_df, probability):
    """Render SHAP-based explanations for the single row in ``feature_df``."""
    st.metric(label="🧮 Readmission Probability", value=f"{probability:.2%}")

    explainer = shap.TreeExplainer(tabular_model)
    shap_values = explainer.shap_values(feature_df)

    # Keep the ten features with the largest absolute impact.
    shap_df = pd.DataFrame({
        "Feature": feature_df.columns,
        "SHAP Value": shap_values[0],
    })
    shap_df["abs_SHAP"] = shap_df["SHAP Value"].abs()
    shap_df = shap_df.sort_values(by="abs_SHAP", ascending=False).head(10)

    top_factors = "\n".join(
        f"- {feat}: {round(value, 2)} impact"
        for feat, value in zip(shap_df["Feature"], shap_df["SHAP Value"])
    )

    explain_with_huggingchat(top_factors)

    with st.expander("📜 Click to see detailed feature impacts"):
        # Newlines keep the first feature from being parsed as the fence's
        # language tag.
        st.markdown(f"```\n{top_factors}\n```")

    # Draw the bar chart on our own axes: shap.bar_plot opens a new figure,
    # which left the figure handed to st.pyplot empty.
    fig, ax = plt.subplots(figsize=(8, 6))
    ordered = shap_df.iloc[::-1]  # largest impact at the top
    colors = ["#ff0051" if value > 0 else "#008bfb" for value in ordered["SHAP Value"]]
    ax.barh(ordered["Feature"], ordered["SHAP Value"], color=colors)
    ax.axvline(0, color="black", linewidth=0.8)
    ax.set_xlabel("SHAP value (impact on model output)")
    fig.tight_layout()
    st.pyplot(fig)

    st.subheader("🎯 SHAP Force Plot (How Features Affected the Prediction)")
    force_plot = shap.force_plot(
        explainer.expected_value,
        shap_values[0],
        feature_df.iloc[0],
        matplotlib=False,
    )
    shap_html = f"{shap.getjs()}{force_plot.html()}"
    components.html(shap_html, height=400)


# ---------------------------------------------------------------------------
# Clinical-note helpers
# ---------------------------------------------------------------------------

def clean_text(text):
    """Cleans input text by removing non-ASCII characters, extra spaces, and unwanted symbols."""
    text = re.sub(r"[^\x20-\x7E]", " ", text)
    text = re.sub(r"_{2,}", "", text)
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"[^\w\s.,:;*%()\[\]-]", "", text)
    return text.lower().strip()


def extract_fields(text):
    """Extracts key fields from clinical notes using regex patterns.

    Each section runs until the next known heading or the end of the text.
    Returns a dict containing only the sections that were found.
    """
    # ``\s*$`` (rather than ``\s+$``) lets a section that ends the note match
    # even when the text has no trailing whitespace.
    patterns = {
        "Discharge Medications": r"Discharge Medications[:\-]?\s*(.+?)(?:\s+(?:Discharge Disposition|Discharge Condition|Discharge Instructions|Followup Instructions)|\s*$)",
        "Discharge Diagnosis": r"Discharge Diagnosis[:\-]?\s*(.+?)(?:\s+(?:Discharge Condition|Discharge Medications|Discharge Instructions|Followup Instructions)|\s*$)",
        "Discharge Instructions": r"Discharge Instructions[:\-]?\s*(.*?)(?:\s+(?:Followup Instructions|Discharge Disposition|Discharge Condition)|\s*$)",
        "History of Present Illness": r"History of Present Illness[:\-]?\s*(.+?)(?:\s+(?:Past Medical History|Social History|Family History|Physical Exam)|\s*$)",
        "Past Medical History": r"Past Medical History[:\-]?\s*(.+?)(?:\s+(?:Social History|Family History|Physical Exam)|\s*$)",
    }
    extracted_data = {}
    for field, pattern in patterns.items():
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            extracted_data[field] = match.group(1).strip()
    return extracted_data


def extract_features(texts, model, tokenizer, device, batch_size=8):
    """Extracts CLS token embeddings from the Clinical-Longformer model.

    Returns a tensor with one row per input text.
    """
    all_features = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MAX_TOKENS,
        ).to(device)
        # Global attention on the first (CLS) token only.
        global_attention_mask = torch.zeros_like(inputs["input_ids"])
        global_attention_mask[:, 0] = 1
        with torch.no_grad():
            outputs = model(**inputs, global_attention_mask=global_attention_mask)
        all_features.append(outputs.last_hidden_state[:, 0, :])
    return torch.cat(all_features, dim=0)


def extract_entities(text, pipe, entity_group):
    """Extracts specific entities from the clinical note using a NER pipeline.

    Returns an empty list when ``text`` is blank or nothing matches, so that
    counts and lists downstream never include a placeholder entry.
    """
    if not text or not text.strip():
        return []
    return [ent["word"] for ent in pipe(text) if ent["entity_group"] == entity_group]


def clean_entities(entities):
    """Reconstruct fragmented tokens and remove duplicates."""
    cleaned = []
    temp = ""
    for entity in entities:
        if entity.startswith("##"):  # word-piece continuation
            temp += entity.replace("##", "")
        else:
            if temp:
                cleaned.append(temp)
            temp = entity
    if temp:
        cleaned.append(temp)
    # Drop very short words and punctuation-only tokens.
    cleaned = [word for word in cleaned if len(word) > 2 and not re.match(r"^[\W_]+$", word)]
    return sorted(set(cleaned))


def display_list(title, items, icon="📌"):
    """Display extracted medical entities in an expandable list."""
    with st.expander(f"**{title} ({len(items)})**"):
        if items:
            for item in items:
                st.markdown(f"- {icon} **{item}**")
        else:
            st.markdown("_No information available._")


def count_tokens(text):
    """Number of Clinical-Longformer tokens in ``text``."""
    tokenizer, _, _ = load_models()  # cached; reuses the already-loaded tokenizer
    return len(tokenizer.tokenize(text))


def trunced_text(nr):
    """Return 1 when ``nr`` tokens exceed the model limit, else 0."""
    return 1 if nr > MAX_TOKENS else 0


def check_disease_presence(disease_list, text):
    """Return 1 if any synonym occurs in ``text`` as a whole word, else 0."""
    return int(any(
        re.search(rf"\b{re.escape(synonym)}\b", text, re.IGNORECASE)
        for synonym in disease_list
    ))


def predict_readmission(texts):
    """Predicts hospital readmission probability using Clinical-Longformer embeddings and MLP.

    Returns the sigmoid of the MLP logits as a NumPy array.
    """
    longformer_tokenizer, longformer_model, _ = load_models()
    mlp_model = load_mlp_model(MLP_MODEL_PATH)
    pca = load_pickled_model(PCA_MODEL_PATH)

    embeddings = extract_features(texts, longformer_model, longformer_tokenizer, DEVICE)
    embeddings_pca = pca.transform(embeddings.cpu().numpy())
    inputs = torch.FloatTensor(embeddings_pca).to(DEVICE)
    with torch.no_grad():
        logits = mlp_model(inputs)
        probabilities = torch.sigmoid(logits).cpu().numpy()
    return probabilities


def build_note_features(clinical_note, num_diseases):
    """Build the single-row feature frame consumed by the LightGBM model."""
    # Token counts are taken on the raw note, not the cleaned one.
    raw_fields = extract_fields(clinical_note)
    number_of_tokens = count_tokens(clinical_note)
    number_of_tokens_med = count_tokens(raw_fields.get("Discharge Medications", ""))
    number_of_tokens_dis = count_tokens(raw_fields.get("Discharge Diagnosis", ""))

    full_diagnosis_text = raw_fields.get("Discharge Diagnosis", "").lower()
    disease_flags = {
        disease: check_disease_presence(synonyms, full_diagnosis_text)
        for disease, synonyms in DISEASE_SYNONYMS.items()
    }
    total_conditions = sum(disease_flags.values())

    return pd.DataFrame([{
        "number_of_tokens_dis": number_of_tokens_dis,
        "number_of_tokens": number_of_tokens,
        "number_of_tokens_med": number_of_tokens_med,
        "diagnostic_count": num_diseases,
        "total_conditions": total_conditions,
        "trunced": trunced_text(number_of_tokens),
        **disease_flags,
    }])


# ---------------------------------------------------------------------------
# Ensemble helpers
# ---------------------------------------------------------------------------

def risk_label(prob):
    """Map a probability to the low/medium/high vote label."""
    if prob < 0.33:
        return "🟢 Low"
    if prob < 0.46:
        return "🟡 Medium"
    return "🔴 High"


def stored_probabilities():
    """Read each model's probability from session state (``None`` if absent/invalid)."""
    probabilities = []
    for model in ENSEMBLE_MODELS:
        key = f"{model} probability"
        if key not in st.session_state:
            probabilities.append(None)
            continue
        try:
            probabilities.append(float(st.session_state[key]))
        except (TypeError, ValueError):
            st.error(f"⚠️ Invalid probability value for {model}: {st.session_state[key]}")
            probabilities.append(None)
    return probabilities


# ---------------------------------------------------------------------------
# Pages
# ---------------------------------------------------------------------------

def render_home():
    """Landing page."""
    st.markdown(centered_html("h1", "📊 AI Clinical Readmission Predictor"), unsafe_allow_html=True)
    st.markdown(
        centered_html("h3", "Using Machine Learning for Better Patient Outcomes"),
        unsafe_allow_html=True,
    )
    st.image(HOME_IMAGE_URL, width=1450)
    st.write(
        "This app helps predict patient readmission risk using machine learning models. "
        "Upload data, analyze clinical notes, and see predictions from our ensemble model."
    )
    st.markdown("---")
    st.markdown(centered_html("h3", "🚀 Explore the App"), unsafe_allow_html=True)


def render_tabular_page():
    """Structured-data page: collect inputs, predict with XGBoost, explain."""
    lottie_hello = load_lottie(LOTTIE_URL)
    if lottie_hello:
        st_lottie(lottie_hello, speed=1, loop=True, height=200)

    df = load_tabular_dataset(TABULAR_DATA_PATH)

    st.title("🏥 Hospital Readmission Prediction")
    st.markdown(
        centered_html("h4", "Predict ICU hospital readmission using Artificial Intelligence"),
        unsafe_allow_html=True,
    )
    st.markdown("---")

    # Admission characteristics (main area).
    st.subheader("📌 Select the admission's Characteristics")
    admission_type = st.selectbox("🛑 Type of Admission", one_hot_options(df, "admission_type_"))
    admission_location = st.selectbox("📍 Admission Location", one_hot_options(df, "admission_location_"))
    discharge_location = st.selectbox("🏥 Discharge Location", one_hot_options(df, "discharge_location_"))
    insurance = st.selectbox("💰 Insurance Type", one_hot_options(df, "insurance_"))

    # Patient information (sidebar).
    st.sidebar.subheader("📊 Patient Information")
    language = st.sidebar.selectbox("🗣 Language", one_hot_options(df, "language_"))
    marital_status = st.sidebar.selectbox("💍 Marital Status", one_hot_options(df, "marital_status_"))
    race = st.sidebar.selectbox("🧑 Race", one_hot_options(df, "race_"))
    sex = st.sidebar.selectbox("⚧ Sex", ["gender_M", "gender_F"])
    age = st.sidebar.slider("📅 Age", 18, 100, 50)

    admission_time = st.time_input("⏳ Admission Time", value=datetime.time(12, 0))
    discharge_time = st.time_input("⏳ Discharge Time", value=datetime.time(12, 0))

    # Numeric inputs, laid out three per row.
    st.subheader("📈 Clinical Values")
    numeric_values = {}

    st.subheader("📊 General Hospital Information")
    general_cols = st.columns(3)
    for i, feature in enumerate(GENERAL_NUMERICAL_FEATURES):
        numeric_values[feature] = midpoint_slider(
            general_cols[i % 3], f"📌 {feature.replace('_', ' ').title()}", df[feature]
        )

    st.subheader("🧪 Laboratory Test Results")
    lab_cols = st.columns(3)
    for i, feature in enumerate(LAB_NUMERICAL_FEATURES):
        numeric_values[feature] = midpoint_slider(
            lab_cols[i % 3], f"🩸 {feature.replace('_', ' ').title()}", df[feature]
        )

    numeric_values["cci_score"] = midpoint_slider(st.sidebar, "📌 CCI Score", df["cci_score"])

    st.markdown("---")

    tabular_model = load_pickled_model(XGBOOST_MODEL_PATH)
    feature_df = build_tabular_features(
        df,
        [admission_type, admission_location, discharge_location,
         insurance, language, marital_status, race],
        sex, age, admission_time, discharge_time, numeric_values,
    )
    st.write(feature_df)

    # Predict eagerly so the ensemble page can use the value without a click.
    probability = float(tabular_model.predict_proba(feature_df)[:, 1][0])
    st.session_state["XGBoost probability"] = probability
    prediction = int(probability >= 0.5)

    st.write(f"Raw Prediction Probability: {probability:.4f}")

    if st.button("🚀 Predict Readmission"):
        with st.spinner("🔍 Processing Prediction..."):
            show_prediction_result(probability, prediction)

    if st.button("🔍 Feature Importance for Prediction"):
        show_feature_importance(tabular_model, feature_df, probability)


def render_clinical_page():
    """Free-text page: extract entities, predict with LightGBM and the MLP."""
    st.subheader("📝 Clinical Text Note")

    _, _, ner_pipe = load_models()

    clinical_note = st.text_area("✍️ Enter Clinical Note", placeholder="Write the clinical note here...")
    if not clinical_note:
        return

    cleaned_note = clean_text(clinical_note)

    extracted_data = extract_fields(cleaned_note)
    st.write("### Extracted Fields")
    st.write(extracted_data)

    with st.spinner("🔍 Identifying medical entities..."):
        extracted_data["Extracted Medications"] = extract_entities(
            extracted_data.get("Discharge Medications", ""), ner_pipe, "Medication"
        )
        extracted_data["Extracted Diseases"] = extract_entities(
            extracted_data.get("Discharge Diagnosis", ""), ner_pipe, "Disease_disorder"
        )
        extracted_data["Extracted Diseases (Past Medical History)"] = extract_entities(
            extracted_data.get("Past Medical History", ""), ner_pipe, "Disease_disorder"
        )
        extracted_data["Extracted Diseases (History of Present Illness)"] = extract_entities(
            extracted_data.get("History of Present Illness", ""), ner_pipe, "Disease_disorder"
        )
        # Symptom extraction also covers "History of Present Illness".
        extracted_data["Extracted Symptoms"] = extract_entities(
            extracted_data.get("Review of Systems", "") + " " + extracted_data.get("History of Present Illness", ""),
            ner_pipe,
            "Sign_symptom",
        )

    diseases_cleaned = clean_entities(
        extracted_data.get("Extracted Diseases", [])
        + extracted_data.get("Extracted Diseases (Past Medical History)", [])
        + extracted_data.get("Extracted Diseases (History of Present Illness)", [])
    )
    medications_cleaned = clean_entities(extracted_data.get("Extracted Medications", []))
    extracted_data["Extracted Medications Cleaned"] = medications_cleaned
    symptoms_cleaned = clean_entities(extracted_data.get("Extracted Symptoms", []))

    st.markdown("## 🏥 **Patient Medical Analysis**")
    st.markdown("---")

    num_medications = len(medications_cleaned)
    num_diseases = len(diseases_cleaned)
    num_symptoms = len(symptoms_cleaned)
    col1, col2, col3 = st.columns(3)
    col1.metric(label="💊 Total Medications", value=num_medications)
    col2.metric(label="🦠 Total Diseases", value=num_diseases)
    col3.metric(label="🤒 Total Symptoms", value=num_symptoms)

    st.markdown("---")

    col1, col2 = st.columns(2)
    with col1:
        st.markdown("### 💊 **Medications**")
        display_list("Medication List", medications_cleaned, icon="💊")
    with col2:
        st.markdown("### 🦠 **Diseases**")
        display_list("Disease List", diseases_cleaned, icon="🦠")

    st.markdown("### 🤒 **Symptoms**")
    display_list("Symptoms List", symptoms_cleaned, icon="🤒")

    st.markdown("---")

    # LightGBM prediction from note-level summary features.
    note_features = build_note_features(clinical_note, num_diseases)
    light_model = load_pickled_model(LGBM_MODEL_PATH)
    model_features = light_model.feature_name_

    missing_features = [feat for feat in model_features if feat not in note_features.columns]
    if missing_features:
        st.write(f"⚠️ Warning: Missing features in df: {missing_features}")
        for feat in missing_features:
            note_features[feat] = 0  # default for absent binary indicators

    X = note_features[model_features].to_numpy()
    # Store a plain float so the ensemble page does not depend on array-to-
    # scalar conversion.
    st.session_state["lightgbm probability"] = float(light_model.predict_proba(X)[:, 1][0])

    if st.button("🚀 Predict Readmission"):
        with st.spinner("🔍 Extracting embeddings and predicting readmission..."):
            readmission_prob = float(predict_readmission([cleaned_note])[0][0])
            st.session_state["MLP probability"] = readmission_prob
            prediction = 1 if readmission_prob > 0.5 else 0

        show_prediction_result(readmission_prob, prediction)

        st.markdown(
            centered_html("h3", "📊 Readmission Probability")
            + centered_html("h1", f"{readmission_prob:.2%}"),
            unsafe_allow_html=True,
        )


def render_ensemble_page():
    """Combine the stored model probabilities with the ensemble model."""
    ensemble_model = load_pickled_model(ENSEMBLE_MODEL_PATH)
    probabilities = stored_probabilities()

    if None in probabilities:
        st.warning("⚠️ Some model predictions are missing. Please run all models before voting.")
        return

    st.write("### 🗳️ Voting Process in Progress...")
    progress_bar = st.progress(0)
    voting_display = st.empty()
    votes = []

    for i, (model, prob) in enumerate(zip(ENSEMBLE_MODELS, probabilities)):
        time.sleep(1)  # deliberate pause for the voting animation

        # Blink the "deciding" message three times.
        for _ in range(3):
            voting_display.markdown(f"⏳ {model} is deciding...")
            time.sleep(0.5)
            voting_display.markdown("")
            time.sleep(0.5)

        vote = risk_label(prob)
        votes.append(vote)
        voting_display.markdown(f"✅ **{model} voted: {vote}**")
        progress_bar.progress((i + 1) / len(ENSEMBLE_MODELS))
        time.sleep(1)

    progress_bar.empty()

    final_df = pd.DataFrame([probabilities], columns=ENSEMBLE_COLUMNS).astype(float)
    final_probability = ensemble_model.predict_proba(final_df)[:, 1][0]
    final_prediction = 1 if final_probability >= ENSEMBLE_THRESHOLD else 0

    st.markdown("---")
    if final_prediction == 1:
        headline = "🚨 Final Prediction: 1 (Readmission Likely)"
    else:
        headline = "✅ Final Prediction: 0 (No Readmission Risk)"
    st.markdown(
        centered_html("h2", headline)
        + centered_html("h4", f"🔍 Probability: {final_probability:.2f} (Threshold: 0.25)"),
        unsafe_allow_html=True,
    )

    st.write("### ⚖️ Model Contribution to Final Decision")
    fig, ax = plt.subplots()
    ax.bar(ENSEMBLE_MODELS, probabilities, color=["blue", "green", "red"])
    ax.set_ylabel("Probability")
    ax.set_title("Model Prediction Probabilities")
    st.pyplot(fig)

    st.write("### 📊 Voting Breakdown:")
    for model, vote, prob in zip(ENSEMBLE_MODELS, votes, probabilities):
        st.write(f"🔹 {model}: **{vote}** (Prob: {prob:.2f})")


# ---------------------------------------------------------------------------
# Sidebar menu and page dispatch
# ---------------------------------------------------------------------------

st.sidebar.title("📌 Menu")
page = st.sidebar.radio(
    "Selecione uma opção:",
    [PAGE_HOME, PAGE_TABULAR, PAGE_CLINICAL, PAGE_ENSEMBLE],
)

if page == PAGE_HOME:
    render_home()
elif page == PAGE_TABULAR:
    render_tabular_page()
elif page == PAGE_CLINICAL:
    render_clinical_page()
elif page == PAGE_ENSEMBLE:
    render_ensemble_page()